From c7b87ed0b8a52f68abb5dcbb01eb68a18e43e3a8 Mon Sep 17 00:00:00 2001 From: Casey Jao Date: Thu, 29 Jun 2023 15:07:29 -0400 Subject: [PATCH] Mem (2/3): migrate core to new DAL --- covalent/__init__.py | 6 +- covalent/_api/__init__.py | 19 + covalent/_api/apiclient.py | 105 +++ covalent/_dispatcher_plugins/local.py | 521 ++++++++--- covalent/_results_manager/results_manager.py | 495 +++++++--- covalent/_results_manager/wait.py | 2 + covalent/_workflow/transport_graph_ops.py | 204 ----- covalent/triggers/base.py | 41 +- covalent_dispatcher/.gitignore | 1 - covalent_dispatcher/__init__.py | 2 +- covalent_dispatcher/_cli/service.py | 1 + covalent_dispatcher/_core/__init__.py | 3 +- covalent_dispatcher/_core/data_manager.py | 468 ++++------ .../_core/data_modules/asset_manager.py | 93 ++ .../_core/data_modules/dispatch.py | 70 ++ .../_core/data_modules/electron.py | 114 +++ .../_core/data_modules/graph.py | 106 +++ .../_core/data_modules/importer.py | 152 +++ .../_core/data_modules/job_manager.py | 8 +- .../_core/data_modules/lattice.py | 42 + .../_core/data_modules/utils.py | 35 + covalent_dispatcher/_core/dispatcher.py | 617 ++++++++----- .../_core/dispatcher_modules/__init__.py | 19 + .../_core/dispatcher_modules/caches.py | 105 +++ .../_core/dispatcher_modules/store.py | 70 ++ covalent_dispatcher/_core/execution.py | 57 +- covalent_dispatcher/_core/runner.py | 279 ++---- .../_core/runner_modules/cancel.py | 150 +++ .../_core/runner_modules/executor_proxy.py | 59 +- .../_core/runner_modules/jobs.py | 130 +++ .../_core/runner_modules/utils.py | 47 + covalent_dispatcher/_db/load.py | 227 ----- covalent_dispatcher/_service/app.py | 398 ++++---- covalent_dispatcher/_service/assets.py | 436 +++++++++ .../_service}/heartbeat.py | 50 - covalent_dispatcher/_service/models.py | 117 +++ covalent_dispatcher/entry_point.py | 102 ++- covalent_ui/api/main.py | 2 +- covalent_ui/api/v1/data_layer/electron_dal.py | 3 +- covalent_ui/api/v1/data_layer/lattice_dal.py | 8 +- 
.../api/v1/database/schema/electron.py | 4 +- .../api/v1/database/schema/lattices.py | 10 +- covalent_ui/api/v1/models/lattices_model.py | 4 +- .../v1/routes/end_points/electron_routes.py | 84 +- .../api/v1/routes/end_points/lattice_route.py | 25 +- covalent_ui/api/v1/routes/routes.py | 7 +- covalent_ui/api/v1/utils/file_handle.py | 9 + tests/__init__.py | 19 + tests/covalent_dispatcher_tests/__init__.py | 19 + .../_cli/__init__.py | 19 + .../_core/__init__.py | 19 + .../_core/data_manager_test.py | 648 ++++++------- .../asset_manager_db_integration_test.py | 169 ++++ .../_core/data_modules/dispatch_test.py | 73 ++ .../_core/data_modules/graph_test.py | 98 ++ .../_core/data_modules/importer_test.py | 124 +++ .../_core/data_modules/job_manager_test.py | 10 +- .../_core/data_modules/lattice_query_test.py | 45 + .../_core/dispatcher_db_integration_test.py | 326 +++++++ .../_core/dispatcher_test.py | 866 +++++++++++------- .../_core/execution_test.py | 325 ++----- .../_core/runner_db_integration_test.py | 127 +++ .../_core/runner_modules/cancel_test.py | 106 +++ .../_core/runner_modules/jobs_test.py | 103 +++ .../_core/runner_test.py | 284 ++---- .../_db/load_test.py | 163 ---- .../_service/app_test.py | 264 ++++-- .../_service/assets_test.py | 735 +++++++++++++++ .../entry_point_test.py | 115 ++- tests/covalent_tests/__init__.py | 19 + .../dispatcher_plugins/__init__.py | 19 + .../dispatcher_plugins/local_test.py | 568 ++++++++++-- .../results_manager_tests/__init__.py | 19 + .../results_manager_test.py | 327 +++++-- tests/covalent_tests/triggers/base_test.py | 49 +- .../workflow/transport_graph_ops_test.py | 252 ----- tests/functional_tests/__init__.py | 19 + tests/functional_tests/file_transfer_test.py | 4 +- tests/functional_tests/local_executor_test.py | 32 + .../functional_tests/results_manager_test.py | 81 ++ tests/functional_tests/triggers_test.py | 2 +- .../workflow_cancellation_test.py | 10 +- tests/functional_tests/workflow_stack_test.py | 40 +- 
tests/load_tests/locustfiles/basic.py | 8 +- tests/load_tests/workflows/horizontal.py | 20 + tests/stress_tests/benchmarks/__init__.py | 19 + 86 files changed, 7990 insertions(+), 3662 deletions(-) create mode 100644 covalent/_api/__init__.py create mode 100644 covalent/_api/apiclient.py delete mode 100644 covalent/_workflow/transport_graph_ops.py create mode 100644 covalent_dispatcher/_core/data_modules/asset_manager.py create mode 100644 covalent_dispatcher/_core/data_modules/dispatch.py create mode 100644 covalent_dispatcher/_core/data_modules/electron.py create mode 100644 covalent_dispatcher/_core/data_modules/graph.py create mode 100644 covalent_dispatcher/_core/data_modules/importer.py create mode 100644 covalent_dispatcher/_core/data_modules/lattice.py create mode 100644 covalent_dispatcher/_core/data_modules/utils.py create mode 100644 covalent_dispatcher/_core/dispatcher_modules/__init__.py create mode 100644 covalent_dispatcher/_core/dispatcher_modules/caches.py create mode 100644 covalent_dispatcher/_core/dispatcher_modules/store.py create mode 100644 covalent_dispatcher/_core/runner_modules/cancel.py create mode 100644 covalent_dispatcher/_core/runner_modules/jobs.py create mode 100644 covalent_dispatcher/_core/runner_modules/utils.py delete mode 100644 covalent_dispatcher/_db/load.py create mode 100644 covalent_dispatcher/_service/assets.py rename {covalent_ui => covalent_dispatcher/_service}/heartbeat.py (60%) create mode 100644 covalent_dispatcher/_service/models.py create mode 100644 tests/covalent_dispatcher_tests/_core/data_modules/asset_manager_db_integration_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/data_modules/dispatch_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/data_modules/graph_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/data_modules/importer_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/data_modules/lattice_query_test.py create mode 100644 
tests/covalent_dispatcher_tests/_core/dispatcher_db_integration_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/runner_db_integration_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/runner_modules/cancel_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/runner_modules/jobs_test.py delete mode 100644 tests/covalent_dispatcher_tests/_db/load_test.py create mode 100644 tests/covalent_dispatcher_tests/_service/assets_test.py delete mode 100644 tests/covalent_tests/workflow/transport_graph_ops_test.py create mode 100644 tests/functional_tests/results_manager_test.py diff --git a/covalent/__init__.py b/covalent/__init__.py index 613487d1ce..439fc96bde 100644 --- a/covalent/__init__.py +++ b/covalent/__init__.py @@ -29,7 +29,11 @@ from ._dispatcher_plugins import local_redispatch as redispatch # nopycln: import from ._dispatcher_plugins import stop_triggers # nopycln: import from ._file_transfer import strategies as fs_strategies # nopycln: import -from ._results_manager.results_manager import cancel, get_result, sync # nopycln: import +from ._results_manager.results_manager import ( # nopycln: import + cancel, + get_result, + get_result_manager, +) from ._shared_files.config import get_config, reload_config, set_config # nopycln: import from ._shared_files.util_classes import RESULT_STATUS as status # nopycln: import from ._workflow import ( # nopycln: import diff --git a/covalent/_api/__init__.py b/covalent/_api/__init__.py new file mode 100644 index 0000000000..9d1b05526a --- /dev/null +++ b/covalent/_api/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. 
Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. diff --git a/covalent/_api/apiclient.py b/covalent/_api/apiclient.py new file mode 100644 index 0000000000..aae20d0fa7 --- /dev/null +++ b/covalent/_api/apiclient.py @@ -0,0 +1,105 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. 
+ + +"""API client""" + +import json +import os +from typing import Dict + +import requests +from requests.adapters import HTTPAdapter + + +class CovalentAPIClient: + """Thin wrapper around Requests to centralize error handling.""" + + def __init__(self, dispatcher_addr: str, adapter: HTTPAdapter = None, auto_raise: bool = True): + self.dispatcher_addr = dispatcher_addr + self.adapter = adapter + self.auto_raise = auto_raise + + def get(self, endpoint: str, **kwargs): + headers = CovalentAPIClient.get_extra_headers() + url = self.dispatcher_addr + endpoint + try: + with requests.Session() as session: + if self.adapter: + session.mount("http://", self.adapter) + + r = session.get(url, headers=headers, **kwargs) + + if self.auto_raise: + r.raise_for_status() + + except requests.exceptions.ConnectionError: + message = f"The Covalent server cannot be reached at {url}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." + print(message) + raise + + return r + + def put(self, endpoint: str, **kwargs): + headers = CovalentAPIClient.get_extra_headers() + url = self.dispatcher_addr + endpoint + try: + with requests.Session() as session: + if self.adapter: + session.mount("http://", self.adapter) + + r = session.put(url, headers=headers, **kwargs) + + if self.auto_raise: + r.raise_for_status() + except requests.exceptions.ConnectionError: + message = f"The Covalent server cannot be reached at {url}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." 
+ print(message) + raise + + return r + + def post(self, endpoint: str, **kwargs): + headers = CovalentAPIClient.get_extra_headers() + url = self.dispatcher_addr + endpoint + try: + with requests.Session() as session: + if self.adapter: + session.mount("http://", self.adapter) + + r = session.post(url, headers=headers, **kwargs) + + if self.auto_raise: + r.raise_for_status() + except requests.exceptions.ConnectionError: + message = f"The Covalent server cannot be reached at {url}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." + print(message) + raise + + return r + + @classmethod + def get_extra_headers(headers: Dict) -> Dict: + # This is expected to be a JSONified dictionary + data = os.environ.get("COVALENT_EXTRA_HEADERS") + if data: + return json.loads(data) + else: + return {} diff --git a/covalent/_dispatcher_plugins/local.py b/covalent/_dispatcher_plugins/local.py index 02d6dc177c..0b9c498d72 100644 --- a/covalent/_dispatcher_plugins/local.py +++ b/covalent/_dispatcher_plugins/local.py @@ -18,18 +18,28 @@ # # Relief from the License may be granted by purchasing a commercial license. 
-import json +import tempfile from copy import deepcopy from functools import wraps +from pathlib import Path from typing import Callable, Dict, List, Optional, Union -import requests +from furl import furl -from .._results_manager import wait +from .._api.apiclient import CovalentAPIClient as APIClient from .._results_manager.result import Result -from .._results_manager.results_manager import get_result +from .._results_manager.results_manager import get_result, get_result_manager +from .._serialize.result import ( + extract_assets, + merge_response_manifest, + serialize_result, + strip_local_uris, +) from .._shared_files import logger from .._shared_files.config import get_config +from .._shared_files.schemas.asset import AssetSchema +from .._shared_files.schemas.result import ResultSchema +from .._shared_files.utils import copy_file_locally, format_server_url from .._workflow.lattice import Lattice from ..triggers import BaseTrigger from .base import BaseDispatcher @@ -37,36 +47,55 @@ app_log = logger.app_log log_stack_info = logger.log_stack_info +dispatch_cache_dir = Path(get_config("sdk.dispatch_cache_dir")) +dispatch_cache_dir.mkdir(parents=True, exist_ok=True) -def get_redispatch_request_body( + +def get_redispatch_request_body_v2( dispatch_id: str, - new_args: Optional[List] = None, - new_kwargs: Optional[Dict] = None, - replace_electrons: Optional[Dict[str, Callable]] = None, - reuse_previous_results: bool = False, -) -> Dict: - """Get request body for re-dispatching a workflow.""" - if new_args is None: - new_args = [] - if new_kwargs is None: - new_kwargs = {} + staging_dir: str, + new_args: List, + new_kwargs: Dict, + replace_electrons: Optional[Dict[str, Callable]], + dispatcher_addr: str = None, +) -> ResultSchema: + rm = get_result_manager(dispatch_id, dispatcher_addr=dispatcher_addr, wait=True) + manifest = ResultSchema.parse_obj(rm._manifest) + + # If no changes to inputs or electron, just retry the dispatch + if not new_args and not new_kwargs 
and not replace_electrons: + manifest.reset_metadata() + app_log.debug("Resubmitting manifest only") + return manifest + + # In all other cases we need to rebuild the graph + rm.download_lattice_asset("workflow_function") + rm.download_lattice_asset("workflow_function_string") + rm.load_lattice_asset("workflow_function") + rm.load_lattice_asset("workflow_function_string") + if replace_electrons is None: replace_electrons = {} - if new_args or new_kwargs: - res = get_result(dispatch_id) - lat = res.lattice - lat.build_graph(*new_args, **new_kwargs) - json_lattice = lat.serialize_to_json() - else: - json_lattice = None - updates = {k: v.electron_object.as_transportable_dict for k, v in replace_electrons.items()} - return { - "json_lattice": json_lattice, - "dispatch_id": dispatch_id, - "electron_updates": updates, - "reuse_previous_results": reuse_previous_results, - } + lat = rm.result_object.lattice + + if replace_electrons: + lat._replace_electrons = replace_electrons + + # If lattice inputs are not supplied, retrieve them from the previous dispatch + if not new_args and not new_kwargs: + rm.download_lattice_asset("inputs") + rm.load_lattice_asset("inputs") + res_obj = rm.result_object + inputs = res_obj.inputs.get_deserialized() + new_args = inputs["args"] + new_kargs = inputs["kwargs"] + + lat.build_graph(*new_args, **new_kwargs) + if replace_electrons: + del lat.__dict__["_replace_electrons"] + + return serialize_result(Result(lat), staging_dir) class LocalDispatcher(BaseDispatcher): @@ -79,6 +108,7 @@ class LocalDispatcher(BaseDispatcher): def dispatch( orig_lattice: Lattice, dispatcher_addr: str = None, + *, disable_run: bool = False, ) -> Callable: """ @@ -90,20 +120,23 @@ def dispatch( Args: orig_lattice: The lattice/workflow to send to the dispatcher server. - dispatcher_addr: The address of the dispatcher server. If None then defaults to the address set in Covalent's config. 
- disable_run: Whether to disable running the workflow and rather just save it on Covalent's server for later execution + dispatcher_addr: The address of the dispatcher server. If None then then defaults to the address set in Covalent's config. Returns: Wrapper function which takes the inputs of the workflow as arguments """ - if dispatcher_addr is None: - dispatcher_addr = ( - "http://" - + get_config("dispatcher.address") - + ":" - + str(get_config("dispatcher.port")) - ) + multistage = get_config("sdk.multistage_dispatch") == "true" + + # Extract triggers here + if "triggers" in orig_lattice.metadata: + triggers_data = orig_lattice.metadata.pop("triggers") + else: + triggers_data = None + + if not disable_run: + # Determine whether to disable first run based on trigger_data + disable_run = triggers_data is not None @wraps(orig_lattice) def wrapper(*args, **kwargs) -> str: @@ -119,8 +152,61 @@ def wrapper(*args, **kwargs) -> str: The dispatch id of the workflow. """ - # To access the disable_run passed to the dispatch function - nonlocal disable_run + if multistage: + dispatch_id = LocalDispatcher.register(orig_lattice, dispatcher_addr)( + *args, **kwargs + ) + else: + dispatch_id = LocalDispatcher.submit(orig_lattice, dispatcher_addr)( + *args, **kwargs + ) + + if triggers_data: + LocalDispatcher.register_triggers(triggers_data, dispatch_id) + + if not disable_run: + return LocalDispatcher.start(dispatch_id, dispatcher_addr) + else: + return dispatch_id + + return wrapper + + @staticmethod + def submit( + orig_lattice: Lattice, + dispatcher_addr: str = None, + ) -> Callable: + """ + Wrapping the dispatching functionality to allow input passing + and server address specification. + + Afterwards, send the lattice to the dispatcher server and return + the assigned dispatch id. + + Args: + orig_lattice: The lattice/workflow to send to the dispatcher server. + dispatcher_addr: The address of the dispatcher server. 
If None then then defaults to the address set in Covalent's config. + + Returns: + Wrapper function which takes the inputs of the workflow as arguments + """ + + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + @wraps(orig_lattice) + def wrapper(*args, **kwargs) -> str: + """ + Send the lattice to the dispatcher server and return + the assigned dispatch id. + + Args: + *args: The inputs of the workflow. + **kwargs: The keyword arguments of the workflow. + + Returns: + The dispatch id of the workflow. + """ if not isinstance(orig_lattice, Lattice): message = f"Dispatcher expected a Lattice, received {type(orig_lattice)} instead." @@ -133,42 +219,40 @@ def wrapper(*args, **kwargs) -> str: # Serialize the transport graph to JSON json_lattice = lattice.serialize_to_json() + endpoint = "/api/v1/dispatch/submit" + r = APIClient(dispatcher_addr).post(endpoint, data=json_lattice) + r.raise_for_status() + return r.content.decode("utf-8").strip().replace('"', "") - # Extract triggers here - json_lattice = json.loads(json_lattice) - triggers_data = json_lattice["metadata"].pop("triggers") - - if not disable_run: - # Determine whether to disable first run based on trigger_data - disable_run = triggers_data is not None - - json_lattice = json.dumps(json_lattice) + return wrapper - submit_dispatch_url = f"{dispatcher_addr}/api/submit" + @staticmethod + def start( + dispatch_id: str, + dispatcher_addr: str = None, + ) -> Callable: + """ + Wrapping the dispatching functionality to allow input passing + and server address specification. - lattice_dispatch_id = None - try: - r = requests.post( - submit_dispatch_url, - data=json_lattice, - params={"disable_run": disable_run}, - timeout=5, - ) - r.raise_for_status() - lattice_dispatch_id = r.content.decode("utf-8").strip().replace('"', "") - except requests.exceptions.ConnectionError: - message = f"The Covalent server cannot be reached at {dispatcher_addr}. 
Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." - print(message) - return + Afterwards, send the lattice to the dispatcher server and return + the assigned dispatch id. - if not disable_run or triggers_data is None: - return lattice_dispatch_id + Args: + orig_lattice: The lattice/workflow to send to the dispatcher server. + dispatcher_addr: The address of the dispatcher server. If None then then defaults to the address set in Covalent's config. - LocalDispatcher.register_triggers(triggers_data, lattice_dispatch_id) + Returns: + Wrapper function which takes the inputs of the workflow as arguments + """ - return lattice_dispatch_id + if dispatcher_addr is None: + dispatcher_addr = format_server_url() - return wrapper + endpoint = f"/api/v1/dispatch/start/{dispatch_id}" + r = APIClient(dispatcher_addr).put(endpoint) + r.raise_for_status() + return r.content.decode("utf-8").strip().replace('"', "") @staticmethod def dispatch_sync( @@ -184,19 +268,14 @@ def dispatch_sync( Args: orig_lattice: The lattice/workflow to send to the dispatcher server. - dispatcher_addr: The address of the dispatcher server. If None then defaults to the address set in Covalent's config. + dispatcher_addr: The address of the dispatcher server. If None then then defaults to the address set in Covalent's config. Returns: Wrapper function which takes the inputs of the workflow as arguments. 
""" if dispatcher_addr is None: - dispatcher_addr = ( - "http://" - + get_config("dispatcher.address") - + ":" - + str(get_config("dispatcher.port")) - ) + dispatcher_addr = format_server_url() @wraps(lattice) def wrapper(*args, **kwargs) -> Result: @@ -214,7 +293,7 @@ def wrapper(*args, **kwargs) -> Result: return get_result( LocalDispatcher.dispatch(lattice, dispatcher_addr)(*args, **kwargs), - wait=wait.EXTREME, + wait=True, ) return wrapper @@ -225,7 +304,6 @@ def redispatch( dispatcher_addr: str = None, replace_electrons: Dict[str, Callable] = None, reuse_previous_results: bool = False, - is_pending: bool = False, ) -> Callable: """ Wrapping the dispatching functionality to allow input passing and server address specification. @@ -241,45 +319,17 @@ def redispatch( """ if dispatcher_addr is None: - dispatcher_addr = ( - "http://" - + get_config("dispatcher.address") - + ":" - + str(get_config("dispatcher.port")) - ) + dispatcher_addr = format_server_url() if replace_electrons is None: replace_electrons = {} - def func(*new_args, **new_kwargs): - """ - Prepare the redispatch request body and redispatch the workflow. - - Args: - *args: The inputs of the workflow. - **kwargs: The keyword arguments of the workflow. - - Returns: - The result of the executed workflow. - - """ - body = get_redispatch_request_body( - dispatch_id, new_args, new_kwargs, replace_electrons, reuse_previous_results - ) - redispatch_url = f"{dispatcher_addr}/api/redispatch" - try: - r = requests.post( - redispatch_url, json=body, params={"is_pending": is_pending}, timeout=5 - ) - r.raise_for_status() - except requests.exceptions.ConnectionError: - message = f"The Covalent server cannot be reached at {dispatcher_addr}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." 
- print(message) - return - - return r.content.decode("utf-8").strip().replace('"', "") - - return func + return LocalDispatcher.register_redispatch( + dispatch_id=dispatch_id, + dispatcher_addr=dispatcher_addr, + replace_electrons=replace_electrons, + reuse_previous_results=reuse_previous_results, + ) @staticmethod def register_triggers(triggers_data: List[Dict], dispatch_id: str) -> None: @@ -328,9 +378,252 @@ def stop_triggers( if isinstance(dispatch_ids, str): dispatch_ids = [dispatch_ids] - r = requests.post(stop_triggers_url, json=dispatch_ids) + endpoint = "/api/triggers/stop_observe" + r = APIClient(triggers_server_addr).post(endpoint, json=dispatch_ids) r.raise_for_status() app_log.debug("Triggers for following dispatch_ids have stopped observing:") for d_id in dispatch_ids: app_log.debug(d_id) + + @staticmethod + def register( + orig_lattice: Lattice, + dispatcher_addr: str = None, + ) -> Callable: + """ + Wrapping the dispatching functionality to allow input passing + and server address specification. + + Afterwards, send the lattice to the dispatcher server and return + the assigned dispatch id. + + Args: + orig_lattice: The lattice/workflow to send to the dispatcher server. + dispatcher_addr: The address of the dispatcher server. If None then then defaults to the address set in Covalent's config. + + Returns: + Wrapper function which takes the inputs of the workflow as arguments + """ + + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + @wraps(orig_lattice) + def wrapper(*args, **kwargs) -> str: + """ + Send the lattice to the dispatcher server and return + the assigned dispatch id. + + Args: + *args: The inputs of the workflow. + **kwargs: The keyword arguments of the workflow. + + Returns: + The dispatch id of the workflow. + """ + + if not isinstance(orig_lattice, Lattice): + message = f"Dispatcher expected a Lattice, received {type(orig_lattice)} instead." 
+ app_log.error(message) + raise TypeError(message) + + lattice = deepcopy(orig_lattice) + + lattice.build_graph(*args, **kwargs) + + with tempfile.TemporaryDirectory() as tmp_dir: + manifest = LocalDispatcher.prepare_manifest(lattice, tmp_dir) + LocalDispatcher.register_manifest(manifest, dispatcher_addr) + + dispatch_id = manifest.metadata.dispatch_id + + path = dispatch_cache_dir / f"{dispatch_id}" + + with open(path, "w") as f: + f.write(manifest.json()) + + LocalDispatcher.upload_assets(manifest) + + return dispatch_id + + return wrapper + + @staticmethod + def register_redispatch( + dispatch_id: str, + dispatcher_addr: str = None, + replace_electrons: Dict[str, Callable] = None, + reuse_previous_results: bool = False, + ) -> Callable: + """ + Wrapping the dispatching functionality to allow input passing and server address specification. + + Args: + dispatch_id: The dispatch id of the workflow to re-dispatch. + dispatcher_addr: The address of the dispatcher server. If None then then defaults to the address set in Covalent's config. + replace_electrons: A dictionary of electron names and the new electron to replace them with. + reuse_previous_results: Boolean value whether to reuse the results from the previous dispatch. + + Returns: + Wrapper function which takes the inputs of the workflow as arguments. + """ + + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + def func(*new_args, **new_kwargs): + """ + Prepare the redispatch request body and redispatch the workflow. + + Args: + *args: The inputs of the workflow. + **kwargs: The keyword arguments of the workflow. + + Returns: + The result of the executed workflow. 
+ """ + + with tempfile.TemporaryDirectory() as staging_dir: + manifest = get_redispatch_request_body_v2( + dispatch_id=dispatch_id, + staging_dir=staging_dir, + new_args=new_args, + new_kwargs=new_kwargs, + replace_electrons=replace_electrons, + dispatcher_addr=dispatcher_addr, + ) + + LocalDispatcher.register_derived_manifest( + manifest, + dispatch_id, + reuse_previous_results=reuse_previous_results, + dispatcher_addr=dispatcher_addr, + ) + + redispatch_id = manifest.metadata.dispatch_id + + path = dispatch_cache_dir / f"{redispatch_id}" + + with open(path, "w") as f: + f.write(manifest.json()) + + LocalDispatcher.upload_assets(manifest) + + return LocalDispatcher.start(redispatch_id, dispatcher_addr) + + return func + + @staticmethod + def prepare_manifest(lattice, storage_path) -> ResultSchema: + """Prepare a built-out lattice for submission""" + + result_object = Result(lattice) + return serialize_result(result_object, storage_path) + + @staticmethod + def register_manifest( + manifest: ResultSchema, + dispatcher_addr: Optional[str] = None, + parent_dispatch_id: Optional[str] = None, + push_assets: bool = True, + ) -> ResultSchema: + """Submits a manifest for registration. + + Returns: + Dictionary representation of manifest with asset remote_uris filled in + + Side effect: + If push_assets is False, the server will + automatically pull the task assets from the submitted asset URIs. 
+ """ + + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + if push_assets: + stripped = strip_local_uris(manifest) + else: + stripped = manifest + + endpoint = "/api/v1/dispatch/register" + + if parent_dispatch_id: + endpoint = f"{endpoint}?parent_dispatch_id={parent_dispatch_id}" + + r = APIClient(dispatcher_addr).post(endpoint, data=stripped.json()) + r.raise_for_status() + + parsed_resp = ResultSchema.parse_obj(r.json()) + + return merge_response_manifest(manifest, parsed_resp) + + @staticmethod + def register_derived_manifest( + manifest: ResultSchema, + dispatch_id: str, + reuse_previous_results: bool = False, + dispatcher_addr: Optional[str] = None, + ) -> ResultSchema: + """Submits a derived manifest for registration. + + Returns: + Dictionary representation of manifest with asset remote_uris filled in + + """ + + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + # We don't yet support pulling assets for redispatch + stripped = strip_local_uris(manifest) + + endpoint = f"/api/v1/dispatch/register/{dispatch_id}" + + params = {"reuse_previous_results": reuse_previous_results} + r = APIClient(dispatcher_addr).post(endpoint, data=stripped.json(), params=params) + r.raise_for_status() + + parsed_resp = ResultSchema.parse_obj(r.json()) + + return merge_response_manifest(manifest, parsed_resp) + + @staticmethod + def upload_assets(manifest: ResultSchema): + assets = extract_assets(manifest) + LocalDispatcher._upload(assets) + + @staticmethod + def _upload(assets: List[AssetSchema]): + local_scheme_prefix = "file://" + total = len(assets) + for i, asset in enumerate(assets): + if not asset.remote_uri: + app_log.debug(f"Skipping asset {i+1} out of {total}") + continue + if asset.remote_uri.startswith(local_scheme_prefix): + copy_file_locally(asset.uri, asset.remote_uri) + else: + _upload_asset(asset.uri, asset.remote_uri) + app_log.debug(f"uploaded {i+1} out of {total} assets.") + + +def _upload_asset(local_uri, 
remote_uri): + scheme_prefix = "file://" + if local_uri.startswith(scheme_prefix): + local_path = local_uri[len(scheme_prefix) :] + else: + local_path = local_uri + + with open(local_path, "rb") as f: + files = {"asset_file": f} + app_log.debug(f"uploading to {remote_uri}") + f = furl(remote_uri) + scheme = f.scheme + host = f.host + port = f.port + dispatcher_addr = f"{scheme}://{host}:{port}" + endpoint = str(f.path) + api_client = APIClient(dispatcher_addr) + r = api_client.post(endpoint, files=files) + r.raise_for_status() diff --git a/covalent/_results_manager/results_manager.py b/covalent/_results_manager/results_manager.py index 2a85777794..a4df3df11e 100644 --- a/covalent/_results_manager/results_manager.py +++ b/covalent/_results_manager/results_manager.py @@ -19,19 +19,32 @@ # Relief from the License may be granted by purchasing a commercial license. -import codecs +from __future__ import annotations + import contextlib import os -from typing import Dict, List, Optional, Union +from pathlib import Path +from typing import Dict, List, Optional -import cloudpickle as pickle -import requests +from furl import furl from requests.adapters import HTTPAdapter from urllib3.util import Retry +from .._api.apiclient import CovalentAPIClient +from .._serialize.common import load_asset +from .._serialize.electron import ASSET_FILENAME_MAP as ELECTRON_ASSET_FILENAMES +from .._serialize.electron import ASSET_TYPES as ELECTRON_ASSET_TYPES +from .._serialize.lattice import ASSET_FILENAME_MAP as LATTICE_ASSET_FILENAMES +from .._serialize.lattice import ASSET_TYPES as LATTICE_ASSET_TYPES +from .._serialize.result import ASSET_FILENAME_MAP as RESULT_ASSET_FILENAMES +from .._serialize.result import ASSET_TYPES as RESULT_ASSET_TYPES +from .._serialize.result import deserialize_result from .._shared_files import logger from .._shared_files.config import get_config from .._shared_files.exceptions import MissingLatticeRecordError +from .._shared_files.schemas.asset import 
def _delete_result(
    dispatch_id: str,
    results_dir: Optional[str] = None,
    remove_parent_directory: bool = False,
) -> None:
    """
    Internal function to delete the result.

    Deletion is best-effort: a missing dispatch folder is skipped and
    filesystem errors while removing directories are ignored, so this
    function does not raise when there is nothing to delete.

    Args:
        dispatch_id: The dispatch id of the result.
        results_dir: The directory where the results are stored in dispatch id named folders.
            If None, the COVALENT_DATA_DIR environment variable is used when set,
            otherwise the "dispatcher.results_dir" config value.
        remove_parent_directory: Status of whether to delete the parent directory when removing the result.

    Returns:
        None
    """

    if results_dir is None:
        results_dir = os.environ.get("COVALENT_DATA_DIR") or get_config("dispatcher.results_dir")

    import shutil

    result_folder_path = os.path.join(results_dir, f"{dispatch_id}")

    if os.path.exists(result_folder_path):
        shutil.rmtree(result_folder_path, ignore_errors=True)

    # os.rmdir only succeeds on an empty directory, so this drops the parent
    # only when no other dispatch folders remain in it.
    with contextlib.suppress(OSError):
        os.rmdir(results_dir)

    # Explicit request: remove the parent and everything still inside it.
    if remove_parent_directory:
        shutil.rmtree(results_dir, ignore_errors=True)
+ + Returns: + Cancellation response + """ + + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + if task_ids is None: + task_ids = [] + + api_client = CovalentAPIClient(dispatcher_addr) + endpoint = "/api/v1/dispatch/cancel" + + if isinstance(task_ids, int): + task_ids = [task_ids] + + r = api_client.post(endpoint, json={"dispatch_id": dispatch_id, "task_ids": task_ids}) + return r.content.decode("utf-8").strip().replace('"', "") + + +# Multi-part + + +def _get_result_export_from_dispatcher( dispatch_id: str, wait: bool = False, - dispatcher_addr: str = None, status_only: bool = False, + dispatcher_addr: str = None, ) -> Dict: """ Internal function to get the results of a dispatch from the server without checking if it is ready to read. @@ -92,8 +158,8 @@ def _get_result_from_dispatcher( Args: dispatch_id: The dispatch id of the result. wait: Controls how long the method waits for the server to return a result. If False, the method will not wait and will return the current status of the workflow. If True, the method will wait for the result to finish and keep retrying for sys.maxsize. - dispatcher_addr: Dispatcher server address, if None then defaults to the address set in Covalent's config. status_only: If true, only returns result status, not the full result object, default is False. + dispatcher_addr: Dispatcher server address, defaults to the address set in covalent.config. Returns: The result object from the server. 
# Function to download default assets
def _get_default_assets(rm: ResultManager):
    """Download and load every asset whose key is not in DEFERRED_KEYS.

    Result-level assets are handled first, then lattice-level assets, then
    the per-node assets of every node in the transport graph. Between the
    lattice and node stages the loaded lattice metadata is attached to the
    transport graph and the lattice's "doc" entry is moved to ``__doc__``.
    """
    for asset_key in RESULT_ASSET_TYPES:
        if asset_key in DEFERRED_KEYS:
            continue
        rm.download_result_asset(asset_key)
        rm.load_result_asset(asset_key)

    for asset_key in LATTICE_ASSET_TYPES:
        if asset_key in DEFERRED_KEYS:
            continue
        rm.download_lattice_asset(asset_key)
        rm.load_lattice_asset(asset_key)

    lattice = rm.result_object.lattice
    graph = lattice.transport_graph

    graph.lattice_metadata = lattice.metadata
    # "doc" is popped from the instance dict, so a lattice without that
    # entry raises KeyError here.
    lattice.__doc__ = lattice.__dict__.pop("doc")

    for asset_key in ELECTRON_ASSET_TYPES:
        if asset_key in DEFERRED_KEYS:
            continue
        for node_id in graph._graph.nodes:
            rm.download_node_asset(node_id, asset_key)
            rm.load_node_asset(node_id, asset_key)
def download_asset(remote_uri: str, local_path: str, chunk_size: int = 1024 * 1024):
    """Fetch the asset at ``remote_uri`` and store it at ``local_path``.

    Args:
        remote_uri: Where the asset lives. A ``file://`` URI is copied with
            ``copy_file_locally``; any other URI is split into a server
            address (scheme, host, port) and an endpoint path, and fetched
            with a streamed GET.
        local_path: Filesystem path the asset is written to.
        chunk_size: Number of bytes requested per streamed chunk.
    """
    # Match on the full "file://" prefix rather than the bare word "file",
    # which would also capture any other URI that merely begins with those
    # letters.
    local_scheme_prefix = "file://"
    if remote_uri.startswith(local_scheme_prefix):
        copy_file_locally(remote_uri, f"file://{local_path}")
        return

    parsed_uri = furl(remote_uri)
    dispatcher_addr = f"{parsed_uri.scheme}://{parsed_uri.host}:{parsed_uri.port}"
    endpoint = str(parsed_uri.path)

    api_client = CovalentAPIClient(dispatcher_addr)
    response = api_client.get(endpoint, stream=True)

    # The output handle gets its own name; the original reused `f` for both
    # the parsed URL and the file.
    with open(local_path, "wb") as asset_file:
        for chunk in response.iter_content(chunk_size=chunk_size):
            asset_file.write(chunk)
class ResultManager:
    """Pairs a deserialized result object with the manifest describing its assets.

    The manifest is held as a plain dict and handed to the module-level
    download/load helpers together with ``results_dir``.
    """

    def __init__(self, manifest: ResultSchema, results_dir: str):
        self.result_object = deserialize_result(manifest)
        self._manifest = manifest.dict()
        self._results_dir = results_dir

    def save(self, path: Optional[str] = None):
        """Write the manifest as JSON to ``path`` (default: ``<results_dir>/manifest.json``)."""
        target = path or os.path.join(self._results_dir, "manifest.json")
        with open(target, "w") as manifest_file:
            manifest_file.write(ResultSchema.parse_obj(self._manifest).json())

    @staticmethod
    def load(path: str, results_dir: str) -> "ResultManager":
        """Build a manager from a manifest JSON file previously written to ``path``."""
        with open(path, "r") as manifest_file:
            raw_manifest = manifest_file.read()

        parsed = ResultSchema.parse_raw(raw_manifest)
        return ResultManager(parsed, results_dir)

    def download_result_asset(self, key: str):
        """Download the result-level asset ``key`` into the results directory."""
        _download_result_asset(self._manifest, self._results_dir, key)

    def download_lattice_asset(self, key: str):
        """Download the lattice-level asset ``key`` into the results directory."""
        _download_lattice_asset(self._manifest, self._results_dir, key)

    def download_node_asset(self, node_id: int, key: str):
        """Download asset ``key`` of node ``node_id`` into the results directory."""
        _download_node_asset(self._manifest, self._results_dir, node_id, key)

    def load_result_asset(self, key: str):
        """Load result asset ``key`` and store it on the result object as ``_<key>``."""
        loaded = _load_result_asset(self._manifest, key)
        self.result_object.__dict__[f"_{key}"] = loaded

    def load_lattice_asset(self, key: str):
        """Load lattice asset ``key`` into the lattice metadata or its attributes."""
        loaded = _load_lattice_asset(self._manifest, key)
        lattice = self.result_object.lattice
        if key not in SDK_LAT_META_KEYS:
            lattice.__dict__[key] = loaded
            return
        lattice.metadata[key] = loaded

    def load_node_asset(self, node_id: int, key: str):
        """Load asset ``key`` of node ``node_id`` into the transport graph."""
        loaded = _load_node_asset(self._manifest, node_id, key)
        graph = self.result_object.lattice.transport_graph
        if key not in SDK_NODE_META_KEYS:
            graph.set_node_value(node_id, key, loaded)
            return
        # Metadata keys are written into the node's existing "metadata" value.
        graph.get_node_value(node_id, "metadata")[key] = loaded

    @staticmethod
    def from_dispatch_id(
        dispatch_id: str,
        results_dir: str,
        wait: bool = False,
        dispatcher_addr: str = None,
    ) -> "ResultManager":
        """Fetch the export for ``dispatch_id`` and prepare ``results_dir`` for downloads.

        Creates ``results_dir`` (with parents) and one ``node_<id>``
        subdirectory per node of the transport graph.
        """
        export = _get_result_export_from_dispatcher(
            dispatch_id, wait, status_only=False, dispatcher_addr=dispatcher_addr
        )
        manifest = ResultSchema.parse_obj(export["result_export"])

        # Order the manifest's node entries by id.
        # NOTE(review): the module-level node helpers index this list by
        # node_id, which presumably relies on ids being 0..n-1 -- confirm.
        manifest.lattice.transport_graph.nodes.sort(key=lambda node: node.id)

        manager = ResultManager(manifest, results_dir)
        fetched_result = manager.result_object
        fetched_result._results_dir = results_dir

        Path(results_dir).mkdir(parents=True, exist_ok=True)
        for node_id in fetched_result.lattice.transport_graph._graph.nodes:
            Path(results_dir + f"/node_{node_id}").mkdir(exist_ok=True)

        return manager
def _get_result_multistage(
    dispatch_id: str,
    wait: bool = False,
    dispatcher_addr: str = None,
    status_only: bool = False,
    results_dir: Optional[str] = None,
    *,
    workflow_output: bool = True,
    intermediate_outputs: bool = True,
    sublattice_results: bool = True,
) -> Result:
    """
    Get the results of a dispatch from the server, downloading assets in stages.

    Args:
        dispatch_id: The dispatch id of the result.
        wait: Controls how long the method waits for the server to return a result. If False, the method will not wait and will return the current status of the workflow. If True, the method will wait for the result to finish and keep retrying for sys.maxsize.
        dispatcher_addr: Dispatcher server address, passed through to the export request.
        status_only: If true, the raw export returned by the server is returned as-is and no assets are downloaded.
        results_dir: Directory to download assets into. If not given, the result manager picks its default location for the dispatch.
        workflow_output: Whether to download and load the workflow-level "result" asset.
        intermediate_outputs: Whether to download and load the "output" asset of every node.
        sublattice_results: Whether to recursively fetch the result of every node that has a "sub_dispatch_id".

    Returns:
        The Result object, or the server's export when ``status_only`` is true.

    Raises:
        MissingLatticeRecordError: If the dispatch id is not known to the server.
    """

    try:
        if status_only:
            return _get_result_export_from_dispatcher(
                dispatch_id=dispatch_id,
                wait=wait,
                status_only=status_only,
                dispatcher_addr=dispatcher_addr,
            )

        rm = get_result_manager(dispatch_id, results_dir, wait, dispatcher_addr)
        _get_default_assets(rm)

        if workflow_output:
            rm.download_result_asset("result")
            rm.load_result_asset("result")

        tg = rm.result_object.lattice.transport_graph

        if intermediate_outputs:
            for node_id in tg._graph.nodes:
                rm.download_node_asset(node_id, "output")
                rm.load_node_asset(node_id, "output")

        # Fetch sublattice result objects recursively
        for node_id in tg._graph.nodes:
            sub_dispatch_id = tg.get_node_value(node_id, "sub_dispatch_id")
            if sublattice_results and sub_dispatch_id:
                # An explicit results_dir must not be shared with the
                # sublattice: its assets use the same node_<id>/<filename>
                # layout and would overwrite the parent's downloaded files.
                # Give each sublattice its own subdirectory instead. When no
                # directory was given, the default is already per-dispatch.
                sub_results_dir = (
                    os.path.join(results_dir, sub_dispatch_id) if results_dir else None
                )
                sub_result = _get_result_multistage(
                    sub_dispatch_id,
                    wait,
                    dispatcher_addr,
                    status_only,
                    results_dir=sub_results_dir,
                    workflow_output=workflow_output,
                    intermediate_outputs=intermediate_outputs,
                    sublattice_results=sublattice_results,
                )
                tg.set_node_value(node_id, "sublattice_result", sub_result)
            else:
                tg.set_node_value(node_id, "sublattice_result", None)

    except MissingLatticeRecordError:
        app_log.warning(
            f"Dispatch ID {dispatch_id} was not found in the database. Incorrect dispatch id."
        )

        raise

    return rm.result_object
+ + Returns: + The Result object from the Covalent server + + """ + + return _get_result_multistage( + dispatch_id=dispatch_id, + wait=wait, + dispatcher_addr=dispatcher_addr, + status_only=status_only, + results_dir=results_dir, + workflow_output=workflow_output, + intermediate_outputs=intermediate_outputs, + sublattice_results=sublattice_results, + ) diff --git a/covalent/_results_manager/wait.py b/covalent/_results_manager/wait.py index cbbb830227..1afc1c6836 100644 --- a/covalent/_results_manager/wait.py +++ b/covalent/_results_manager/wait.py @@ -1,3 +1,5 @@ +# Copyright 2021 Agnostiq Inc. +# # This file is part of Covalent. # # Licensed under the GNU Affero General Public License 3.0 (the "License"). diff --git a/covalent/_workflow/transport_graph_ops.py b/covalent/_workflow/transport_graph_ops.py deleted file mode 100644 index ea895505d7..0000000000 --- a/covalent/_workflow/transport_graph_ops.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright 2023 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the GNU Affero General Public License 3.0 (the "License"). -# A copy of the License may be obtained with this software package or at -# -# https://www.gnu.org/licenses/agpl-3.0.en.html -# -# Use of this file is prohibited except in compliance with the License. Any -# modifications or derivative works of this file must retain this copyright -# notice, and modified files must contain a notice indicating that they have -# been altered from the originals. -# -# Covalent is distributed in the hope that it will be useful, but WITHOUT -# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or -# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. -# -# Relief from the License may be granted by purchasing a commercial license. 
- -"""Module for transport graph operations.""" - -from collections import deque -from typing import Callable, List - -import networkx as nx - -from .._shared_files import logger -from .transport import _TransportGraph - -app_log = logger.app_log - - -class TransportGraphOps: - def __init__(self, tg): - self.tg = tg - self._status_map = {1: True, -1: False} - - @staticmethod - def _flag_successors(A: nx.MultiDiGraph, node_statuses: dict, starting_node: int): - """Flag all successors of a node (including the node itself).""" - nodes_to_invalidate = [starting_node] - for node, successors in nx.bfs_successors(A, starting_node): - nodes_to_invalidate.extend(iter(successors)) - for node in nodes_to_invalidate: - node_statuses[node] = -1 - - @staticmethod - def is_same_node(A: nx.MultiDiGraph, B: nx.MultiDiGraph, node: int) -> bool: - """Check if the node attributes are the same in both graphs.""" - return A.nodes[node] == B.nodes[node] - - @staticmethod - def is_same_edge_attributes( - A: nx.MultiDiGraph, B: nx.MultiDiGraph, parent: int, node: int - ) -> bool: - """Check if the edge attributes are the same in both graphs.""" - return A.adj[parent][node] == B.adj[parent][node] - - def copy_nodes_from(self, tg: _TransportGraph, nodes): - """Copy nodes from the transport graph in the argument.""" - for n in nodes: - for k, v in tg._graph.nodes[n].items(): - self.tg.set_node_value(n, k, v) - - @staticmethod - def _cmp_name_and_pval(A: nx.MultiDiGraph, B: nx.MultiDiGraph, node: int) -> bool: - """Default node comparison function for diffing transport graphs.""" - name_A = A.nodes[node]["name"] - name_B = B.nodes[node]["name"] - - if name_A != name_B: - return False - - val_A = A.nodes[node].get("value", None) - val_B = B.nodes[node].get("value", None) - - return val_A == val_B - - def _max_cbms( - self, - A: nx.MultiDiGraph, - B: nx.MultiDiGraph, - node_cmp: Callable = None, - edge_cmp: Callable = None, - ): - """Computes a "maximum backward-maximal common subgraph" (cbms) - 
Args: - A: nx.MultiDiGraph - B: nx.MultiDiGraph - node_cmp: An optional function for comparing node attributes in A and B. - Defaults to testing for equality of the attribute dictionaries - edge_cmp: An optional function for comparing the edges between two nodes. - Defaults to checking that the two sets of edges have the same attributes - Returns: A_node_status, B_node_status, where each is a dictionary - `{node: True/False}` where True means reusable. - Performs a modified BFS of A and B. - """ - if node_cmp is None: - node_cmp = self.is_same_node - if edge_cmp is None: - edge_cmp = self.is_same_edge_attributes - - A_node_status = {node_id: 0 for node_id in A.nodes} - B_node_status = {node_id: 0 for node_id in B.nodes} - app_log.debug(f"A node status: {A_node_status}") - app_log.debug(f"B node status: {B_node_status}") - - virtual_root = -1 - - if virtual_root in A.nodes or virtual_root in B.nodes: - raise RuntimeError(f"Encountered forbidden node: {virtual_root}") - - assert virtual_root not in B.nodes - - nodes_to_visit = deque() - nodes_to_visit.appendleft(virtual_root) - - # Add a temporary root - A_parentless_nodes = [node for node, deg in A.in_degree() if deg == 0] - B_parentless_nodes = [node for node, deg in B.in_degree() if deg == 0] - for node_id in A_parentless_nodes: - A.add_edge(virtual_root, node_id) - - for node_id in B_parentless_nodes: - B.add_edge(virtual_root, node_id) - - # Assume inductively that predecessors subgraphs are the same; - # this is satisfied for the root - while nodes_to_visit: - current_node = nodes_to_visit.pop() - - app_log.debug(f"Visiting node {current_node}") - for y in A.adj[current_node]: - # Don't process already failed nodes - if A_node_status[y] == -1: - continue - - # Check if y is a valid child of current_node in B - if y not in B.adj[current_node]: - app_log.debug(f"A: {y} not adjacent to node {current_node} in B") - self._flag_successors(A, A_node_status, y) - continue - - if y in B.adj[current_node] and 
B_node_status[y] == -1: - app_log.debug(f"A: Node {y} is marked as failed in B") - self._flag_successors(A, A_node_status, y) - continue - - # Compare edges - if not edge_cmp(A, B, current_node, y): - app_log.debug(f"Edges between {current_node} and {y} differ") - self._flag_successors(A, A_node_status, y) - self._flag_successors(B, B_node_status, y) - continue - - # Compare nodes - if not node_cmp(A, B, y): - app_log.debug(f"Attributes of node {y} differ:") - app_log.debug(f"A[y] = {A.nodes[y]}") - app_log.debug(f"B[y] = {B.nodes[y]}") - self._flag_successors(A, A_node_status, y) - self._flag_successors(B, B_node_status, y) - continue - - # Predecessors subgraphs of y are the same in A and B, so - # enqueue y if it hasn't already been visited - assert A_node_status[y] != -1 - if A_node_status[y] == 0: - A_node_status[y] = 1 - B_node_status[y] = 1 - app_log.debug(f"Enqueueing node {y}") - nodes_to_visit.appendleft(y) - - # Prune children of current_node in B that aren't valid children in A - for y in B.adj[current_node]: - if B_node_status[y] == -1: - continue - if y not in A.adj[current_node]: - app_log.debug(f"B: {y} not adjacent to node {current_node} in A") - self._flag_successors(B, B_node_status, y) - continue - if y in A.adj[current_node] and B_node_status[y] == -1: - app_log.debug(f"B: Node {y} is marked as failed in A") - self._flag_successors(B, B_node_status, y) - - A.remove_node(-1) - B.remove_node(-1) - - app_log.debug(f"A node status: {A_node_status}") - app_log.debug(f"B node status: {B_node_status}") - - for k, v in A_node_status.items(): - A_node_status[k] = self._status_map[v] - for k, v in B_node_status.items(): - B_node_status[k] = self._status_map[v] - return A_node_status, B_node_status - - def get_reusable_nodes(self, tg_new: _TransportGraph) -> List[int]: - """Find which nodes are common between the current graph and a new graph.""" - A = self.tg.get_internal_graph_copy() - B = tg_new.get_internal_graph_copy() - status_A, _ = 
self._max_cbms(A, B, node_cmp=self._cmp_name_and_pval) - return [k for k, v in status_A.items() if v] diff --git a/covalent/triggers/base.py b/covalent/triggers/base.py index bf1d38ffa8..baf24e48a4 100644 --- a/covalent/triggers/base.py +++ b/covalent/triggers/base.py @@ -25,10 +25,10 @@ import requests -from .._results_manager import Result +from .._dispatcher_plugins import local from .._shared_files import logger from .._shared_files.config import get_config -from .._shared_files.util_classes import Status +from .._shared_files.util_classes import RESULT_STATUS, Status app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -112,10 +112,10 @@ def _get_status(self) -> Status: """ if self.use_internal_funcs: - from covalent_dispatcher._service.app import get_result + from covalent_dispatcher._service.app import export_result response = asyncio.run_coroutine_threadsafe( - get_result(self.lattice_dispatch_id, status_only=True), + export_result(self.lattice_dispatch_id, status_only=True), self.event_loop, ).result() @@ -141,21 +141,26 @@ def _do_redispatch(self, is_pending: bool = False) -> str: new_dispatch_id: Dispatch id of the newly dispatched workflow """ - if self.use_internal_funcs: - from covalent_dispatcher import run_redispatch + # if self.use_internal_funcs: + # from covalent_dispatcher.entry_point import run_redispatch, start_dispatch - return asyncio.run_coroutine_threadsafe( - run_redispatch(self.lattice_dispatch_id, None, None, False, is_pending), - self.event_loop, - ).result() + # if is_pending: + # return asyncio.run_coroutine_threadsafe( + # start_dispatch(self.lattice_dispatch_id), + # self.event_loop, + # ).result() + # else: + # return asyncio.run_coroutine_threadsafe( + # run_redispatch(self.lattice_dispatch_id, None, None, False), + # self.event_loop, + # ).result() - from .. 
import redispatch - - return redispatch( - dispatch_id=self.lattice_dispatch_id, - dispatcher_addr=self.dispatcher_addr, - is_pending=is_pending, - )() + if is_pending: + return local.LocalDispatcher.start(self.lattice_dispatch_id, self.dispatcher_addr) + else: + return local.LocalDispatcher.redispatch( + self.lattice_dispatch_id, self.dispatcher_addr + )() def trigger(self) -> None: """ @@ -173,7 +178,7 @@ def trigger(self) -> None: status = self._get_status() - if status == Result.NEW_OBJ or status is None: + if status == str(RESULT_STATUS.NEW_OBJECT) or status is None: # To continue the pending dispatch same_dispatch_id = self._do_redispatch(True) app_log.debug(f"Initiating run for pending dispatch_id: {same_dispatch_id}") diff --git a/covalent_dispatcher/.gitignore b/covalent_dispatcher/.gitignore index 6d1d5b0eb5..34313c5edc 100644 --- a/covalent_dispatcher/.gitignore +++ b/covalent_dispatcher/.gitignore @@ -20,7 +20,6 @@ # Ignore results folders results/** -result_* # XML files *.xml diff --git a/covalent_dispatcher/__init__.py b/covalent_dispatcher/__init__.py index 818dab3420..8f2d0a02e3 100644 --- a/covalent_dispatcher/__init__.py +++ b/covalent_dispatcher/__init__.py @@ -18,4 +18,4 @@ # # Relief from the License may be granted by purchasing a commercial license. 
-from .entry_point import cancel_running_dispatch, run_dispatcher, run_redispatch +from .entry_point import cancel_running_dispatch, run_dispatcher diff --git a/covalent_dispatcher/_cli/service.py b/covalent_dispatcher/_cli/service.py index 95939cecac..a7e68840a8 100644 --- a/covalent_dispatcher/_cli/service.py +++ b/covalent_dispatcher/_cli/service.py @@ -43,6 +43,7 @@ from distributed.core import connect, rpc from covalent._shared_files.config import ConfigManager, get_config, set_config +from covalent_dispatcher._service.heartbeat import Heartbeat from .._db.datastore import DataStore from .migrate import migrate_pickled_result_object diff --git a/covalent_dispatcher/_core/__init__.py b/covalent_dispatcher/_core/__init__.py index eb2986767e..421f935830 100644 --- a/covalent_dispatcher/_core/__init__.py +++ b/covalent_dispatcher/_core/__init__.py @@ -18,5 +18,6 @@ # # Relief from the License may be granted by purchasing a commercial license. -from .data_manager import make_derived_dispatch, make_dispatch +from .data_manager import make_dispatch +from .data_modules.importer import copy_futures from .dispatcher import cancel_dispatch, run_dispatch diff --git a/covalent_dispatcher/_core/data_manager.py b/covalent_dispatcher/_core/data_manager.py index 0f015de00a..1f458b4462 100644 --- a/covalent_dispatcher/_core/data_manager.py +++ b/covalent_dispatcher/_core/data_manager.py @@ -23,35 +23,35 @@ """ import asyncio +import tempfile import traceback -import uuid -from datetime import datetime, timezone -from typing import Callable, Dict, Optional +from typing import Dict +from pydantic import ValidationError + +from covalent._dispatcher_plugins.local import LocalDispatcher from covalent._results_manager import Result from covalent._shared_files import logger -from covalent._shared_files.defaults import sublattice_prefix +from covalent._shared_files.schemas.result import ResultSchema from covalent._shared_files.util_classes import RESULT_STATUS from 
covalent._workflow.lattice import Lattice -from covalent._workflow.transport_graph_ops import TransportGraphOps -from .._db import load, update, upsert +from .._dal.result import Result as SRVResult +from .._dal.result import get_result_object as get_result_object_from_db from .._db.write_result_to_db import resolve_electron_id +from . import dispatcher +from .data_modules import lattice # nopycln: import +from .data_modules import dispatch, electron, graph # nopycln: import +from .data_modules import importer as manifest_importer +from .data_modules.utils import run_in_executor app_log = logger.app_log log_stack_info = logger.log_stack_info -# References to result objects of live dispatches -_registered_dispatches = {} - -# Map of dispatch_id -> message_queue for pushing node status updates -# to dispatcher -_dispatch_status_queues = {} - def generate_node_result( - node_id: int, - node_name: str, + node_id, + node_name=None, start_time=None, end_time=None, status=None, @@ -59,28 +59,7 @@ def generate_node_result( error=None, stdout=None, stderr=None, - sub_dispatch_id=None, - sublattice_result=None, ): - """ - Helper routine to prepare the node result - - Arg(s) - node_id: ID of the node in the transport graph - node_name: Name of the node - start_time: Start time of the node - end_time: Time at which the node finished executing - status: Status of the node's execution - output: Output of the node - error: Error from the node - stdout: STDOUT of a node - stderr: STDERR generated during node execution - sub_dispatch_id: Dispatch ID of the sublattice - sublattice_result: Result of the sublattice - - Return(s) - Dictionary of the inputs - """ return { "node_id": node_id, "node_name": node_name, @@ -91,300 +70,245 @@ def generate_node_result( "error": error, "stdout": stdout, "stderr": stderr, - "sub_dispatch_id": sub_dispatch_id, - "sublattice_result": sublattice_result, } -async def _handle_built_sublattice(dispatch_id: str, node_result: Dict) -> None: - """Make 
dispatch for sublattice node. - - Note: The status COMPLETED which invokes this function refers to the graph being built. Once this step is completed, the sublattice is ready to be dispatched. Hence, the status is changed to DISPATCHING. - - Args: - dispatch_id: Dispatch ID - node_result: Node result dictionary - - """ - try: - node_result["status"] = RESULT_STATUS.DISPATCHING_SUBLATTICE - result_object = get_result_object(dispatch_id) - sub_dispatch_id = await make_sublattice_dispatch(result_object, node_result) - node_result["sub_dispatch_id"] = sub_dispatch_id - node_result["start_time"] = datetime.now(timezone.utc) - node_result["end_time"] = None - except Exception as ex: - tb = "".join(traceback.TracebackException.from_exception(ex).format()) - node_result["status"] = RESULT_STATUS.FAILED - node_result["error"] = tb - app_log.debug(f"Failed to make sublattice dispatch: {tb}") - - # Domain: result -async def update_node_result(result_object, node_result) -> None: - """ - Updates the result object with the current node_result - - Arg(s) - result_object: Result object the current dispatch - node_result: Result of the node to be updated in the result object - - Return(s) - None - - """ - app_log.debug(f"Updating node result for {node_result['node_id']}.") - - if ( - node_result["status"] == RESULT_STATUS.COMPLETED - and node_result["node_name"].startswith(sublattice_prefix) - and not node_result["sub_dispatch_id"] - ): - app_log.debug( - f"Sublattice {node_result['node_name']} build graph completed, invoking make sublattice dispatch..." 
+async def update_node_result(dispatch_id, node_result): + app_log.debug("Updating node result (run_planned_workflow).") + valid_update = True + try: + node_id = node_result["node_id"] + node_status = node_result["status"] + node_info = await electron.get(dispatch_id, node_id, ["type", "sub_dispatch_id"]) + node_type = node_info["type"] + sub_dispatch_id = node_info["sub_dispatch_id"] + + # Handle returns from _build_sublattice_graph -- change + # COMPLETED -> DISPATCHING + node_result = _filter_sublattice_status( + dispatch_id, node_id, node_status, node_type, sub_dispatch_id, node_result ) - await _handle_built_sublattice(result_object.dispatch_id, node_result) - try: - update._node(result_object, **node_result) - except Exception as ex: + valid_update = await electron.update(dispatch_id, node_result) + if not valid_update: + app_log.warning( + f"Invalid status update {node_status} for node {dispatch_id}:{node_id}" + ) + return + + if node_result["status"] == RESULT_STATUS.DISPATCHING: + app_log.debug("Received sublattice dispatch") + try: + sub_dispatch_id = await _make_sublattice_dispatch(dispatch_id, node_result) + except Exception as ex: + tb = "".join(traceback.TracebackException.from_exception(ex).format()) + node_result["status"] = RESULT_STATUS.FAILED + node_result["error"] = tb + await electron.update(dispatch_id, node_result) + + except KeyError as ex: + valid_update = False app_log.exception(f"Error persisting node update: {ex}") - node_result["status"] = RESULT_STATUS.FAILED - finally: - sub_dispatch_id = node_result["sub_dispatch_id"] - detail = {"sub_dispatch_id": sub_dispatch_id} if sub_dispatch_id is not None else {} - if node_status := node_result["status"]: - dispatch_id = result_object.dispatch_id - status_queue = get_status_queue(dispatch_id) - node_id = node_result["node_id"] - await status_queue.put((node_id, node_status, detail)) - - -# Domain: result -def initialize_result_object( - json_lattice: str, parent_result_object: Result = None, 
parent_electron_id: int = None -) -> Result: - """Convenience function for constructing a result object from a json-serialized lattice. - - Args: - json_lattice: a JSON-serialized lattice - parent_result_object: the parent result object if json_lattice is a sublattice - parent_electron_id: the DB id of the parent electron (for sublattices) - Returns: - Result: result object - - """ - dispatch_id = get_unique_id() - lattice = Lattice.deserialize_from_json(json_lattice) - result_object = Result(lattice, dispatch_id) - if parent_result_object: - result_object._root_dispatch_id = parent_result_object._root_dispatch_id + except Exception as ex: + app_log.exception(f"Error persisting node update: {ex}") + sub_dispatch_id = None + node_result["status"] = Result.FAILED - result_object._electron_id = parent_electron_id - result_object._initialize_nodes() - app_log.debug("2: Constructed result object and initialized nodes.") + finally: + if not valid_update: + return - update.persist(result_object, electron_id=parent_electron_id) - app_log.debug("Result object persisted.") + node_id = node_result["node_id"] + node_status = node_result["status"] + dispatch_id = dispatch_id - return result_object + detail = {"sub_dispatch_id": sub_dispatch_id} if sub_dispatch_id else {} + if node_status and valid_update: + await dispatcher.notify_node_status(dispatch_id, node_id, node_status, detail) # Domain: result -def get_unique_id() -> str: - """ - Get a unique ID. - - Args: - None - - Returns: - str: Unique ID - - """ - return str(uuid.uuid4()) - - -async def make_dispatch( - json_lattice: str, parent_result_object: Result = None, parent_electron_id: int = None +def _redirect_lattice( + json_lattice: str, + parent_dispatch_id: str, + parent_electron_id: int, + loop: asyncio.AbstractEventLoop, ) -> str: - """Make a dispatch from a json-serialized lattice. + """Redirect a JSON lattice through the new DAL. Args: - json_lattice: a JSON-serialized lattice. 
- parent_result_object: the parent result object if json_lattice is a sublattice. - parent_electron_id: the DB id of the parent electron (for sublattices). - - Returns: - Dispatch ID of the lattice. - - """ - result_object = initialize_result_object( - json_lattice, parent_result_object, parent_electron_id - ) - _register_result_object(result_object) - return result_object.dispatch_id + json_lattice: A JSON-serialized lattice. + parent_dispatch_id: The id of a sublattice's parent dispatch. - -async def make_sublattice_dispatch(result_object: Result, node_result: dict) -> str: - """Get sublattice json lattice (once the transport graph has been built) and invoke make_dispatch. - - Args: - result_object: Result object for parent dispatch of the node. - node_result: Result of the node. + This will only be triggered from either the monolithic /submit + endpoint or a monolithic sublattice dispatch. Returns: - str: Dispatch ID of the sublattice. + The dispatch manifest """ - node_id = node_result["node_id"] - json_lattice = node_result["output"].object_string - parent_electron_id = load.electron_record(result_object.dispatch_id, node_id)["id"] - app_log.debug( - f"Making sublattice dispatch for node_id {node_id} and electron_id {parent_electron_id}." 
- ) - return await make_dispatch(json_lattice, result_object, parent_electron_id) + lattice = Lattice.deserialize_from_json(json_lattice) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(lattice, staging_dir) + + # Trigger an internal asset pull from /tmp to object store + coro = manifest_importer.import_manifest( + manifest, + parent_dispatch_id, + parent_electron_id, + ) + filtered_manifest = manifest_importer._import_manifest( + manifest, + parent_dispatch_id, + parent_electron_id, + ) + manifest_importer._pull_assets(filtered_manifest) -def _get_result_object_from_new_lattice( - json_lattice: str, old_result_object: Result, reuse_previous_results: bool -) -> Result: - """Get new result object for re-dispatching from new lattice json. + return filtered_manifest.metadata.dispatch_id - Args: - json_lattice: JSON-serialized lattice. - old_result_object: Result object of the previous dispatch. - Returns: - Result object. - - """ - lat = Lattice.deserialize_from_json(json_lattice) - result_object = Result(lat, get_unique_id()) - result_object._initialize_nodes() +async def make_dispatch( + json_lattice: str, parent_dispatch_id: str = None, parent_electron_id: int = None +) -> str: + return await run_in_executor( + _redirect_lattice, + json_lattice, + parent_dispatch_id, + parent_electron_id, + asyncio.get_running_loop(), + ) - if reuse_previous_results: - tg = result_object.lattice.transport_graph - tg_old = old_result_object.lattice.transport_graph - reusable_nodes = TransportGraphOps(tg_old).get_reusable_nodes(tg) - TransportGraphOps(tg).copy_nodes_from(tg_old, reusable_nodes) - return result_object +def get_result_object(dispatch_id: str, bare: bool = True) -> SRVResult: + app_log.debug(f"Getting result object from db, bare={bare}") + return get_result_object_from_db(dispatch_id, bare) -def _get_result_object_from_old_result( - old_result_object: Result, reuse_previous_results: bool -) -> Result: - """Get new 
result object for re-dispatching from old result object. +def finalize_dispatch(dispatch_id: str): + app_log.debug(f"Finalizing dispatch {dispatch_id}") - Args: - old_result_object: Result object of the previous dispatch. - reuse_previous_results: Whether to reuse previous results. - Returns: - Result: Result object for the new dispatch. +async def persist_result(dispatch_id: str): + await _update_parent_electron(dispatch_id) - """ - result_object = Result(old_result_object.lattice, get_unique_id()) - result_object._num_nodes = old_result_object._num_nodes - if not reuse_previous_results: - result_object._initialize_nodes() +async def _update_parent_electron(dispatch_id: str): + dispatch_attrs = await dispatch.get(dispatch_id, ["electron_id", "status", "end_time"]) + parent_eid = dispatch_attrs["electron_id"] - return result_object + if parent_eid: + dispatch_id, node_id = resolve_electron_id(parent_eid) + status = dispatch_attrs["status"] + node_result = generate_node_result( + node_id=node_id, + end_time=dispatch_attrs["end_time"], + status=status, + ) + parent_result_obj = get_result_object(dispatch_id) + app_log.debug(f"Updating sublattice parent node {dispatch_id}:{node_id}") + await update_node_result(parent_result_obj.dispatch_id, node_result) -def make_derived_dispatch( - parent_dispatch_id: str, - json_lattice: Optional[str] = None, - electron_updates: Optional[Dict[str, Callable]] = None, - reuse_previous_results: bool = False, -) -> str: - """Make a re-dispatch from a previous dispatch. +def _filter_sublattice_status( + dispatch_id, node_id, status, node_type, sub_dispatch_id, node_result +): + if status == Result.COMPLETED and node_type == "sublattice" and not sub_dispatch_id: + node_result["status"] = RESULT_STATUS.DISPATCHING + return node_result - Args: - parent_dispatch_id: Dispatch ID of the parent dispatch. - json_lattice: JSON-serialized lattice of the new dispatch. - electron_updates: Dictionary of electron updates. 
- reuse_previous_results: Whether to reuse previous results. - Returns: - str: Dispatch ID of the new dispatch. +async def _make_sublattice_dispatch(dispatch_id: str, node_result: dict): + try: + manifest, parent_electron_id = await run_in_executor( + _make_sublattice_dispatch_helper, + dispatch_id, + node_result, + ) - """ - if electron_updates is None: - electron_updates = {} + imported_manifest = await manifest_importer.import_manifest( + manifest=manifest, + parent_dispatch_id=dispatch_id, + parent_electron_id=parent_electron_id, + ) - old_result_object = load.get_result_object_from_storage(parent_dispatch_id) + return imported_manifest.metadata.dispatch_id - if json_lattice: - result_object = _get_result_object_from_new_lattice( - json_lattice, old_result_object, reuse_previous_results + except ValidationError as ex: + # Fall back to legacy sublattice handling + # NB: this loads the JSON sublattice in memory + json_lattice, parent_electron_id = await run_in_executor( + _legacy_sublattice_dispatch_helper, + dispatch_id, + node_result, ) - else: - result_object = _get_result_object_from_old_result( - old_result_object, reuse_previous_results + return await make_dispatch( + json_lattice, + dispatch_id, + parent_electron_id, ) - result_object.lattice.transport_graph.apply_electron_updates(electron_updates) - result_object.lattice.transport_graph.dirty_nodes = list( - result_object.lattice.transport_graph._graph.nodes - ) - update.persist(result_object) - _register_result_object(result_object) - app_log.debug(f"Redispatch result object: {result_object}") - return result_object.dispatch_id +def _legacy_sublattice_dispatch_helper(dispatch_id: str, node_result: Dict): + app_log.debug("falling back to legacy sublattice dispatch") + result_object = get_result_object(dispatch_id, bare=True) + node_id = node_result["node_id"] + parent_node = result_object.lattice.transport_graph.get_node(node_id) + bg_output = parent_node.get_value("output") -def 
get_result_object(dispatch_id: str) -> Result: - return _registered_dispatches[dispatch_id] + parent_electron_id = parent_node._electron_id + json_lattice = bg_output.object_string + return json_lattice, parent_electron_id -def _register_result_object(result_object: Result): - dispatch_id = result_object.dispatch_id - _registered_dispatches[dispatch_id] = result_object - _dispatch_status_queues[dispatch_id] = asyncio.Queue() +def _make_sublattice_dispatch_helper(dispatch_id: str, node_result: Dict): + """Helper function for performing DB queries related to sublattices.""" + result_object = get_result_object(dispatch_id, bare=True) + node_id = node_result["node_id"] + parent_node = result_object.lattice.transport_graph.get_node(node_id) + bg_output = parent_node.get_value("output") + manifest = ResultSchema.parse_raw(bg_output.object_string) + parent_electron_id = parent_node._electron_id -def finalize_dispatch(dispatch_id: str): - del _dispatch_status_queues[dispatch_id] - del _registered_dispatches[dispatch_id] + return manifest, parent_electron_id -def get_status_queue(dispatch_id: str): - return _dispatch_status_queues[dispatch_id] +# Common Result object queries -async def persist_result(dispatch_id: str): - result_object = get_result_object(dispatch_id) - update.persist(result_object) - await _update_parent_electron(result_object) +def generate_dispatch_result( + dispatch_id, + start_time=None, + end_time=None, + status=None, + error=None, + result=None, +): + return { + "start_time": start_time, + "end_time": end_time, + "status": status, + "error": error, + "result": result, + } -async def _update_parent_electron(result_object: Result): - if parent_eid := result_object._electron_id: - dispatch_id, node_id = resolve_electron_id(parent_eid) - status = result_object.status - if status == RESULT_STATUS.POSTPROCESSING_FAILED: - status = RESULT_STATUS.FAILED - parent_result_obj = get_result_object(dispatch_id) - node_result = generate_node_result( - 
node_id=node_id, - node_name=parent_result_obj.lattice.transport_graph.get_node_value(node_id, "name"), - end_time=result_object.end_time, - status=status, - output=result_object._result, - error=result_object._error, - sub_dispatch_id=load.sublattice_dispatch_id(parent_eid), - sublattice_result=result_object, - ) +# Ensure that a dispatch is only run once; in the future, also check +# if all assets have been uploaded - app_log.debug(f"Updating sublattice parent node {dispatch_id}:{node_id}") - await update_node_result(parent_result_obj, node_result) +async def ensure_dispatch(dispatch_id: str) -> bool: + """Check if a dispatch can be run. -def upsert_lattice_data(dispatch_id: str): - result_object = get_result_object(dispatch_id) - upsert.lattice_data(result_object) + The following criteria must be met: + * The dispatch has not been run before. + * (later) all assets have been uploaded + """ + return await run_in_executor( + SRVResult.ensure_run_once, + dispatch_id, + ) diff --git a/covalent_dispatcher/_core/data_modules/asset_manager.py b/covalent_dispatcher/_core/data_modules/asset_manager.py new file mode 100644 index 0000000000..e012e92d9c --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/asset_manager.py @@ -0,0 +1,93 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. 
See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Utilties to transfer data between Covalent and compute backends +""" + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from typing import Dict + +from covalent._shared_files import logger +from covalent._shared_files.schemas.asset import AssetUpdate + +from ..._dal.result import get_result_object as get_result_object +from .utils import run_in_executor + +app_log = logger.app_log +am_pool = ThreadPoolExecutor() + + +# Consumed by Runner +async def upload_asset_for_nodes(dispatch_id: str, key: str, dest_uris: dict): + """Typical keys: "output", "deps", "call_before", "call_after", "function""" + + result_object = get_result_object(dispatch_id, bare=True) + tg = result_object.lattice.transport_graph + loop = asyncio.get_running_loop() + + futs = [] + for node_id, dest_uri in dest_uris.items(): + if dest_uri: + node = tg.get_node(node_id) + asset = node.get_asset(key, session=None) + futs.append(loop.run_in_executor(am_pool, asset.upload, dest_uri)) + + await asyncio.gather(*futs) + + +async def download_assets_for_node( + dispatch_id: str, node_id: int, asset_updates: Dict[str, AssetUpdate] +): + # Keys for src_uris: "output", "stdout", "stderr" + + result_object = get_result_object(dispatch_id, bare=True) + tg = result_object.lattice.transport_graph + node = tg.get_node(node_id) + loop = asyncio.get_running_loop() + + futs = [] + db_updates = {} + + # Mapping from asset key to (non-empty) remote uri + assets_to_download = {} + + # Prepare asset metadata update; prune empty fields + for key in asset_updates: + update = {} + asset = asset_updates[key].dict() + if asset["remote_uri"]: + assets_to_download[key] = asset["remote_uri"] + # Prune empty fields + for attr, val in asset.items(): + if val is not None: + update[attr] = val + if update: + db_updates[key] = update + + # Update metadata using the designated DB worker thread 
+ await run_in_executor(node.update_assets, db_updates) + + for key, remote_uri in assets_to_download.items(): + asset = node.get_asset(key, session=None) + # Download assets concurrently. + futs.append(loop.run_in_executor(am_pool, asset.download, remote_uri)) + await asyncio.gather(*futs) diff --git a/covalent_dispatcher/_core/data_modules/dispatch.py b/covalent_dispatcher/_core/data_modules/dispatch.py new file mode 100644 index 0000000000..fb1ce1cff2 --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/dispatch.py @@ -0,0 +1,70 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. 
+ +""" +Queries involving dispatches +""" + +from typing import Dict, List + +from ..._dal.result import get_result_object +from .utils import run_in_executor + + +def get_sync(dispatch_id: str, keys: List[str]) -> Dict: + refresh = False + result_object = get_result_object(dispatch_id) + return result_object.get_values(keys, refresh=refresh) + + +async def get(dispatch_id: str, keys: List[str]) -> Dict: + return await run_in_executor( + get_sync, + dispatch_id, + keys, + ) + + +def get_incomplete_tasks_sync(dispatch_id: str) -> Dict: + """Query all cancelled or failed tasks""" + result_object = get_result_object(dispatch_id) + return result_object._get_incomplete_nodes() + + +async def get_incomplete_tasks(dispatch_id: str) -> Dict: + """Query all cancelled or failed tasks in a dispatch. + + Args: + dispatch_id: The id of the dispatch + + Returns: + {"cancelled": [node_ids], "failed": [node_ids]} + """ + + return await run_in_executor(get_incomplete_tasks_sync, dispatch_id) + + +def update_sync(dispatch_id, dispatch_result): + result_object = get_result_object(dispatch_id) + result_object._update_dispatch(**dispatch_result) + + +async def update(dispatch_id, dispatch_result): + await run_in_executor(update_sync, dispatch_id, dispatch_result) diff --git a/covalent_dispatcher/_core/data_modules/electron.py b/covalent_dispatcher/_core/data_modules/electron.py new file mode 100644 index 0000000000..568d161eeb --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/electron.py @@ -0,0 +1,114 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. 
Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Utilities for querying the transport graph +""" + +from typing import Dict, List + +from ..._dal.result import get_result_object +from .utils import run_in_executor + + +def get_bulk_sync(dispatch_id: str, node_ids: List[int], keys: List[str]) -> List[Dict]: + result_object = get_result_object(dispatch_id) + attrs = result_object.lattice.transport_graph.get_values_for_nodes( + node_ids=node_ids, + keys=keys, + refresh=False, + ) + return attrs + + +async def get_bulk(dispatch_id: str, node_ids: List[int], keys: List[str]) -> List[Dict]: + """Query attributes for multiple electrons. + + Args: + node_ids: The list of nodes to query + keys: The list of attributes to query for each electron + + Returns: + A list of dictionaries {attr_key: attr_val}, one for + each node id, in the same order as `node_ids` + + Example: + ``` + await get_bulk( + "my_dispatch", [2, 4], ["name", "status"], + ) + ``` + will return + ``` + [ + { + "name": "task_2", "status": RESULT_STATUS.COMPLETED, + }, + { + "name": "task_4, "status": RESULT_STATUS.FAILED, + }, + ] + ``` + + """ + return await run_in_executor( + get_bulk_sync, + dispatch_id, + node_ids, + keys, + ) + + +async def get(dispatch_id: str, node_id: int, keys: List[str]) -> Dict: + """Convenience function to query attributes for an electron. 
+ + Args: + node_id: The node to query + keys: The list of attributes to query + + Returns: + A dictionary {attr_key: attr_val} + + Example: + ``` + await get( + "my_dispatch", 2, ["name", "status"], + ) + ``` + will return + ``` + { + "name": "task_2", "status": RESULT_STATUS.COMPLETED, + } + ``` + + """ + attrs = await get_bulk(dispatch_id, [node_id], keys) + return attrs[0] + + +def update_sync(dispatch_id: str, node_result: Dict): + result_object = get_result_object(dispatch_id, bare=True) + return result_object._update_node(**node_result) + + +async def update(dispatch_id: str, node_result: Dict): + """Update a node's attributes""" + return await run_in_executor(update_sync, dispatch_id, node_result) diff --git a/covalent_dispatcher/_core/data_modules/graph.py b/covalent_dispatcher/_core/data_modules/graph.py new file mode 100644 index 0000000000..c74c9173da --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/graph.py @@ -0,0 +1,106 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. 
+ +""" +Utilities for querying the transport graph +""" + + +# Note: these query static information which should be amenable to caching + +from typing import Dict, List + +import networkx as nx + +from ..._dal.result import get_result_object +from .utils import run_in_executor + + +def get_incoming_edges_sync(dispatch_id: str, node_id: int): + """Query in-edges of a node. + + Returns: + List[Edge], where + + Edge is a dictionary with structure + source: int, + target: int, + attrs: dict + """ + + result_object = get_result_object(dispatch_id) + return result_object.lattice.transport_graph.get_incoming_edges(node_id) + + +def get_node_successors_sync( + dispatch_id: str, + node_id: int, + attrs: List[str], +) -> List[Dict]: + """Get child nodes with multiplicity. + + Parameters: + node_id: id of node + attr_keys: list of node attributes to return, such as task_group_id + + Returns: + List[Dict], where each dictionary is of the form + {"node_id": node_id, attr_key_1: node_attr[attr_key_1], ...} + + """ + + result_object = get_result_object(dispatch_id) + return result_object.lattice.transport_graph.get_successors(node_id, attrs) + + +def get_nodes_links_sync(dispatch_id: str) -> dict: + """Return the internal transport graph in NX node-link form""" + + # Need the whole NX graph here + result_object = get_result_object(dispatch_id, False) + g = result_object.lattice.transport_graph.get_internal_graph_copy() + return nx.readwrite.node_link_data(g) + + +def get_nodes_sync(dispatch_id: str) -> List[int]: + """Return a list of all node ids in the graph.""" + result_object = get_result_object(dispatch_id, False) + g = result_object.lattice.transport_graph.get_internal_graph_copy() + return list(g.nodes) + + +async def get_incoming_edges(dispatch_id: str, node_id: int): + return await run_in_executor(get_incoming_edges_sync, dispatch_id, node_id) + + +async def get_node_successors( + dispatch_id: str, + node_id: int, + attrs: List[str] = ["task_group_id"], +) -> List[Dict]: 
+ return await run_in_executor(get_node_successors_sync, dispatch_id, node_id, attrs) + + +async def get_nodes_links(dispatch_id: str) -> Dict: + return await run_in_executor(get_nodes_links_sync, dispatch_id) + + +async def get_nodes(dispatch_id: str) -> List[int]: + return await run_in_executor(get_nodes_sync, dispatch_id) diff --git a/covalent_dispatcher/_core/data_modules/importer.py b/covalent_dispatcher/_core/data_modules/importer.py new file mode 100644 index 0000000000..bbceec8dd0 --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/importer.py @@ -0,0 +1,152 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. 
+ +""" +Functionality for importing dispatch submissions +""" + +import uuid +from typing import Optional + +from covalent._shared_files import logger +from covalent._shared_files.config import get_config +from covalent._shared_files.schemas.result import ResultSchema +from covalent_dispatcher._dal.asset import copy_asset +from covalent_dispatcher._dal.importers.result import handle_redispatch, import_result +from covalent_dispatcher._dal.result import Result as SRVResult + +from .utils import dm_pool, run_in_executor + +BASE_PATH = get_config("dispatcher.results_dir") + +app_log = logger.app_log + +# Concurrent futures for copying assets during redispatch +copy_futures = {} + + +# Domain: result +def get_unique_id() -> str: + """ + Get a unique ID. + + Args: + None + + Returns: + str: Unique ID + """ + + return str(uuid.uuid4()) + + +def _import_manifest( + res: ResultSchema, + parent_dispatch_id: Optional[str], + parent_electron_id: Optional[int], +) -> ResultSchema: + if not res.metadata.dispatch_id: + res.metadata.dispatch_id = get_unique_id() + + # Compute root_dispatch_id for sublattice dispatches + if parent_dispatch_id: + parent_result_object = SRVResult.from_dispatch_id( + dispatch_id=parent_dispatch_id, + bare=True, + ) + res.metadata.root_dispatch_id = parent_result_object.root_dispatch_id + else: + res.metadata.root_dispatch_id = res.metadata.dispatch_id + + return import_result(res, BASE_PATH, parent_electron_id) + + +def _get_all_assets(dispatch_id: str): + result_object = SRVResult.from_dispatch_id(dispatch_id, bare=True) + return result_object.get_all_assets() + + +def _pull_assets(manifest: ResultSchema) -> None: + dispatch_id = manifest.metadata.dispatch_id + assets = _get_all_assets(dispatch_id) + futs = [] + for asset in assets["lattice"]: + if asset.remote_uri: + asset.download(asset.remote_uri) + + for asset in assets["nodes"]: + if asset.remote_uri: + asset.download(asset.remote_uri) + + app_log.debug(f"imported {len(futs)} assets for 
dispatch {dispatch_id}") + + +async def import_manifest( + manifest: ResultSchema, + parent_dispatch_id: Optional[str], + parent_electron_id: Optional[int], +) -> ResultSchema: + filtered_manifest = await run_in_executor( + _import_manifest, manifest, parent_dispatch_id, parent_electron_id + ) + await run_in_executor(_pull_assets, filtered_manifest) + + return filtered_manifest + + +def _copy_assets(assets_to_copy): + for item in assets_to_copy: + src, dest = item + copy_asset(src, dest) + + +def _import_derived_manifest( + manifest: ResultSchema, + parent_dispatch_id: str, + reuse_previous_results: bool, +) -> ResultSchema: + filtered_manifest = _import_manifest(manifest, None, None) + filtered_manifest, assets_to_copy = handle_redispatch( + filtered_manifest, parent_dispatch_id, reuse_previous_results + ) + + dispatch_id = filtered_manifest.metadata.dispatch_id + fut = dm_pool.submit(_copy_assets, assets_to_copy) + copy_futures[dispatch_id] = fut + fut.add_done_callback(lambda x: copy_futures.pop(dispatch_id)) + + return filtered_manifest + + +async def import_derived_manifest( + manifest: ResultSchema, + parent_dispatch_id: str, + reuse_previous_results: bool, +) -> ResultSchema: + filtered_manifest = await run_in_executor( + _import_derived_manifest, + manifest, + parent_dispatch_id, + reuse_previous_results, + ) + + await run_in_executor(_pull_assets, filtered_manifest) + + return filtered_manifest diff --git a/covalent_dispatcher/_core/data_modules/job_manager.py b/covalent_dispatcher/_core/data_modules/job_manager.py index 1500e1aab3..80f6093850 100644 --- a/covalent_dispatcher/_core/data_modules/job_manager.py +++ b/covalent_dispatcher/_core/data_modules/job_manager.py @@ -102,16 +102,16 @@ async def set_job_handle(dispatch_id: str, task_id: int, job_handle: str) -> Non await _set_job_metadata(dispatch_id, task_id, job_handle=job_handle) -async def set_cancel_result(dispatch_id: str, task_id: int, cancel_status: bool) -> None: +async def 
set_job_status(dispatch_id: str, task_id: int, status: str) -> None: """ - Update the cancel status of the job in the database if task cancellation is requested + Update the status of the job in the database Arg(s) dispatch_id: Dispatch ID of the lattice task_id: ID of the task in the lattice - cancel_status: True/False indicating whether the task is to be cancelled + status: status Return(s) None """ - await _set_job_metadata(dispatch_id, task_id, cancel_successful=cancel_status) + await _set_job_metadata(dispatch_id, task_id, job_status=status) diff --git a/covalent_dispatcher/_core/data_modules/lattice.py b/covalent_dispatcher/_core/data_modules/lattice.py new file mode 100644 index 0000000000..4132abef3b --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/lattice.py @@ -0,0 +1,42 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. 
+ +""" +Queries involving lattice +""" + +from typing import Dict, List + +from ..._dal.result import get_result_object +from .utils import run_in_executor + + +def get_sync(dispatch_id: str, keys: List[str]) -> Dict: + refresh = False + result_object = get_result_object(dispatch_id) + return result_object.lattice.get_values(keys, refresh=refresh) + + +async def get(dispatch_id: str, keys: List[str]) -> Dict: + return await run_in_executor( + get_sync, + dispatch_id, + keys, + ) diff --git a/covalent_dispatcher/_core/data_modules/utils.py b/covalent_dispatcher/_core/data_modules/utils.py new file mode 100644 index 0000000000..be1d66d05b --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/utils.py @@ -0,0 +1,35 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Utils for the data service +""" + +import asyncio +from concurrent.futures import ThreadPoolExecutor + +# Worker thread for Datastore I/O Clamp this threadpool to one +# thread because Sqlite only supports a single writer. 
+dm_pool = ThreadPoolExecutor(max_workers=1) + + +def run_in_executor(func, *args) -> asyncio.Future: + loop = asyncio.get_running_loop() + return loop.run_in_executor(dm_pool, func, *args) diff --git a/covalent_dispatcher/_core/dispatcher.py b/covalent_dispatcher/_core/dispatcher.py index 0d853c458c..ec20553419 100644 --- a/covalent_dispatcher/_core/dispatcher.py +++ b/covalent_dispatcher/_core/dispatcher.py @@ -27,39 +27,37 @@ from datetime import datetime, timezone from typing import Dict, List, Tuple -from covalent._results_manager import Result +import networkx as nx + from covalent._shared_files import logger -from covalent._shared_files.defaults import parameter_prefix +from covalent._shared_files.config import get_config +from covalent._shared_files.defaults import WAIT_EDGE_NAME, parameter_prefix from covalent._shared_files.util_classes import RESULT_STATUS -from covalent_ui import result_webhook from . import data_manager as datasvc -from . import runner -from .data_modules.job_manager import set_cancel_requested +from . import runner_ng +from .data_modules import job_manager as jbmgr +from .dispatcher_modules.caches import _pending_parents, _sorted_task_groups, _unresolved_tasks app_log = logger.app_log log_stack_info = logger.log_stack_info +_global_status_queue = None +_status_queues = {} +_futures = {} +_global_event_listener = None -""" -Dispatcher module is responsible for planning and dispatching workflows. The dispatcher - -1. Submits tasks to the Runner module. -2. Retrieves information using the Data Manager module. -3. Handles the tasks in terminal (COMPLETED, FAILED, CANCELLED) states. -4. Handles sublattice dispatches once the corresponding graph has been built in the Runner module. 
-""" +SYNC_DISPATCHES = get_config("dispatcher.use_async_dispatcher") == "false" # Domain: dispatcher -def _get_abstract_task_inputs(node_id: int, node_name: str, result_object: Result) -> dict: +async def _get_abstract_task_inputs(dispatch_id: str, node_id: int, node_name: str) -> dict: """Return placeholders for the required inputs for a task execution. Args: + dispatch_id: id of the current dispatch node_id: Node id of this task in the transport graph. node_name: Name of the node. - result_object: Result object to be used to update and store execution related - info including the results. Returns: inputs: Input dictionary to be passed to the task with `node_id` placeholders for args, kwargs. These are to be @@ -68,16 +66,17 @@ def _get_abstract_task_inputs(node_id: int, node_name: str, result_object: Resul abstract_task_input = {"args": [], "kwargs": {}} - for parent in result_object.lattice.transport_graph.get_dependencies(node_id): - edge_data = result_object.lattice.transport_graph.get_edge_data(parent, node_id) + for edge in await datasvc.graph.get_incoming_edges(dispatch_id, node_id): + parent = edge["source"] - for _, d in edge_data.items(): - if not d.get("wait_for"): - if d["param_type"] == "arg": - abstract_task_input["args"].append((parent, d["arg_index"])) - elif d["param_type"] == "kwarg": - key = d["edge_name"] - abstract_task_input["kwargs"][key] = parent + d = edge["attrs"] + + if d["edge_name"] != WAIT_EDGE_NAME: + if d["param_type"] == "arg": + abstract_task_input["args"].append((parent, d["arg_index"])) + elif d["param_type"] == "kwarg": + key = d["edge_name"] + abstract_task_input["kwargs"][key] = parent sorted_args = sorted(abstract_task_input["args"], key=lambda x: x[1]) abstract_task_input["args"] = [x[0] for x in sorted_args] @@ -86,54 +85,40 @@ def _get_abstract_task_inputs(node_id: int, node_name: str, result_object: Resul # Domain: dispatcher -async def _handle_completed_node(result_object, node_id, pending_parents): - """ - Process the 
completed node in the transport graph - - Arg(s) - result_object: Result object associated with the workflow - node_id: ID of the node in the transport graph - pending_parents: Parents of this node yet to be executed - - Return(s) - List of nodes ready to be executed - """ - g = result_object.lattice.transport_graph._graph - - ready_nodes = [] +async def _handle_completed_node(dispatch_id: str, node_id: int): + next_task_groups = [] app_log.debug(f"Node {node_id} completed") - for child, edges in g.adj[node_id].items(): - for _ in edges: - pending_parents[child] -= 1 - if pending_parents[child] < 1: - app_log.debug(f"Queuing node {child} for execution") - ready_nodes.append(child) - return ready_nodes + parent_gid = (await datasvc.electron.get(dispatch_id, node_id, ["task_group_id"]))[ + "task_group_id" + ] + for child in await datasvc.graph.get_node_successors(dispatch_id, node_id): + node_id = child["node_id"] + gid = child["task_group_id"] + app_log.debug(f"dispatch {dispatch_id}: parent gid {parent_gid}, child gid {gid}") + if parent_gid != gid: + now_pending = await _pending_parents.decrement(dispatch_id, gid) + if now_pending < 1: + app_log.debug(f"Queuing task group {gid} for execution") + next_task_groups.append(gid) + + return next_task_groups # Domain: dispatcher -async def _handle_failed_node(result_object, node_id): - result_object._task_failed = True - result_object._end_time = datetime.now(timezone.utc) - app_log.debug(f"Node {result_object.dispatch_id}:{node_id} failed") +async def _handle_failed_node(dispatch_id: str, node_id: int): + app_log.debug(f"Node {dispatch_id}:{node_id} failed") app_log.debug("8A: Failed node upsert statement (run_planned_workflow)") - datasvc.upsert_lattice_data(result_object.dispatch_id) - await result_webhook.send_update(result_object) # Domain: dispatcher -async def _handle_cancelled_node(result_object, node_id): - result_object._task_cancelled = True - result_object._end_time = datetime.now(timezone.utc) - 
app_log.debug(f"Node {result_object.dispatch_id}:{node_id} cancelled") +async def _handle_cancelled_node(dispatch_id: str, node_id: int): + app_log.debug(f"Node {dispatch_id}:{node_id} cancelled") app_log.debug("9: Cancelled node upsert statement (run_planned_workflow)") - datasvc.upsert_lattice_data(result_object.dispatch_id) - await result_webhook.send_update(result_object) # Domain: dispatcher -async def _get_initial_tasks_and_deps(result_object: Result) -> Tuple[int, int, Dict]: +async def _get_initial_tasks_and_deps(dispatch_id: str) -> Tuple[int, int, Dict]: """Compute the initial batch of tasks to submit and initialize each task's dep count Returns: (num_tasks, ready_nodes, pending_parents) where num_tasks is @@ -144,188 +129,141 @@ async def _get_initial_tasks_and_deps(result_object: Result) -> Tuple[int, int, """ - num_tasks = 0 - ready_nodes = [] + # Number of pending predecessor nodes for each task group pending_parents = {} - g = result_object.lattice.transport_graph._graph - for node_id, d in g.in_degree(): - app_log.debug(f"Node {node_id} has {d} parents") + g_node_link = await datasvc.graph.get_nodes_links(dispatch_id) + g = nx.readwrite.node_link_graph(g_node_link) + + # Topologically sort each task group + sorted_task_groups = {} + for node_id in nx.topological_sort(g): + gid = g.nodes[node_id]["task_group_id"] + if gid not in sorted_task_groups: + sorted_task_groups[gid] = [node_id] + pending_parents[gid] = 0 + else: + sorted_task_groups[gid].append(node_id) - pending_parents[node_id] = d - num_tasks += 1 - if d == 0: - ready_nodes.append(node_id) + for node_id in g.nodes: + parent_gid = g.nodes[node_id]["task_group_id"] + for succ, datadict in g.adj[node_id].items(): + child_gid = g.nodes[succ]["task_group_id"] + n_edges = len(datadict.keys()) + if parent_gid != child_gid: + pending_parents[child_gid] += n_edges - return num_tasks, ready_nodes, pending_parents + initial_task_groups = [gid for gid, d in pending_parents.items() if d == 0] + 
app_log.debug(f"Sorted task groups: {sorted_task_groups}") + return initial_task_groups, pending_parents, sorted_task_groups # Domain: dispatcher -async def _submit_task(result_object, node_id): +async def _submit_task_group(dispatch_id: str, sorted_nodes: List[int], task_group_id: int): + # Handle parameter nodes # Get name of the node for the current task - node_name = result_object.lattice.transport_graph.get_node_value(node_id, "name") - node_status = result_object.lattice.transport_graph.get_node_value(node_id, "status") + node_name = (await datasvc.electron.get(dispatch_id, sorted_nodes[0], ["name"]))["name"] + app_log.debug(f"7A: Node name: {node_name} (run_planned_workflow).") # Handle parameter nodes if node_name.startswith(parameter_prefix): - output = result_object.lattice.transport_graph.get_node_value(node_id, "value") - timestamp = datetime.now(timezone.utc) - node_result = datasvc.generate_node_result( - node_id=node_id, - node_name=node_name, - start_time=timestamp, - end_time=timestamp, - status=RESULT_STATUS.COMPLETED, - output=output, - ) - await datasvc.update_node_result(result_object, node_result) - app_log.debug(f"Updated parameter node {node_id}.") - - elif node_status == RESULT_STATUS.COMPLETED: - timestamp = datetime.now(timezone.utc) - output = result_object.lattice.transport_graph.get_node_value(node_id, "output") - node_result = datasvc.generate_node_result( - node_id=node_id, - node_name=node_name, - start_time=timestamp, - end_time=timestamp, - status=RESULT_STATUS.COMPLETED, - output=output, - ) - await datasvc.update_node_result(result_object, node_result) - app_log.debug(f"Skipped completed node execution {node_name}.") + if len(sorted_nodes) > 1: + raise RuntimeError("Parameter nodes cannot be packed") + + app_log.debug("7C: Encountered parameter node {node_id}.") + app_log.debug("8: Starting update node (run_planned_workflow).") + + ts = datetime.now(timezone.utc) + node_result = { + "node_id": sorted_nodes[0], + "start_time": 
ts, + "end_time": ts, + "status": RESULT_STATUS.COMPLETED, + } + await datasvc.update_node_result(dispatch_id, node_result) + app_log.debug("8A: Update node success (run_planned_workflow).") else: - # Gather inputs and dispatch task - app_log.debug(f"Gathering inputs for task {node_id}.") - - abs_task_input = _get_abstract_task_inputs(node_id, node_name, result_object) - executor = result_object.lattice.transport_graph.get_node_value(node_id, "metadata")[ - "executor" - ] - executor_data = result_object.lattice.transport_graph.get_node_value(node_id, "metadata")[ - "executor_data" - ] - coro = runner.run_abstract_task( - dispatch_id=result_object.dispatch_id, - node_id=node_id, - executor=[executor, executor_data], - node_name=node_name, - abstract_inputs=abs_task_input, + # Gather inputs for each task and send the task spec sequence to the runner + task_specs = [] + known_nodes = [] + + # Skip the group if all task outputs can be reused from a + # previous dispatch (for redispatch). + statuses = await datasvc.electron.get_bulk(dispatch_id, sorted_nodes, ["status"]) + incomplete = list( + filter(lambda record: record["status"] != RESULT_STATUS.PENDING_REUSE, statuses) ) - app_log.debug(f"Creating task {node_id}.") - asyncio.create_task(coro) - - -# Domain: dispatcher -async def _run_planned_workflow(result_object: Result, status_queue: asyncio.Queue) -> Result: - """ - Run the workflow in the topological order of their position on the - transport graph. Does this in an asynchronous manner so that nodes - at the same level are executed in parallel. Also updates the status - of the whole workflow execution. 
- - Args: - result_object: Result object being used for current dispatch - status_queue: message queue for notifying the main loop of status updates - - Returns: - None - """ - app_log.debug("Starting _run_planned_workflow ...") - result_object._status = RESULT_STATUS.RUNNING - result_object._start_time = datetime.now(timezone.utc) - datasvc.upsert_lattice_data(result_object.dispatch_id) - app_log.debug(f"Wrote lattice status {result_object._status} to DB.") - - tasks_left, initial_nodes, pending_parents = await _get_initial_tasks_and_deps(result_object) - unresolved_tasks = 0 + if len(incomplete) > 0: + for node_id in sorted_nodes: + app_log.debug(f"Gathering inputs for task {node_id} (run_planned_workflow).") + + abs_task_input = await _get_abstract_task_inputs(dispatch_id, node_id, node_name) + + executor_attrs = await datasvc.electron.get( + dispatch_id, + node_id, + ["executor", "executor_data"], + ) + selected_executor = executor_attrs["executor"] + selected_executor_data = executor_attrs["executor_data"] + task_spec = { + "function_id": node_id, + "name": node_name, + "args_ids": abs_task_input["args"], + "kwargs_ids": abs_task_input["kwargs"], + } + known_nodes += abs_task_input["args"] + known_nodes += list(abs_task_input["kwargs"].values()) + task_specs.append(task_spec) - for node_id in initial_nodes: - unresolved_tasks += 1 - await _submit_task(result_object, node_id) - - while unresolved_tasks > 0: - app_log.debug(f"{tasks_left} tasks left to complete.") - app_log.debug(f"Waiting to hear from {unresolved_tasks} tasks.") - - node_id, node_status, detail = await status_queue.get() - - app_log.debug( - f"Status queue msg for node id {node_id}: {node_status} with detail {detail}." - ) - - if node_status == RESULT_STATUS.RUNNING: - continue - - # Note: A node status can only be 'DISPATCHING' if it is a sublattice and the corresponding graph has been built. 
- if node_status == RESULT_STATUS.DISPATCHING_SUBLATTICE: - sub_dispatch_id = detail["sub_dispatch_id"] - run_dispatch(sub_dispatch_id) app_log.debug( - f"Submitted sublattice (dispatch id: {sub_dispatch_id}) to run_dispatch." + f"Submitting task group {dispatch_id}:{task_group_id} ({len(sorted_nodes)} tasks) to runner" + ) + app_log.debug(f"Using new runner for task group {task_group_id}") + + known_nodes = list(set(known_nodes)) + coro = runner_ng.run_abstract_task_group( + dispatch_id=dispatch_id, + task_group_id=task_group_id, + task_seq=task_specs, + known_nodes=known_nodes, + selected_executor=[selected_executor, selected_executor_data], ) - continue - - unresolved_tasks -= 1 - - if node_status == RESULT_STATUS.COMPLETED: - tasks_left -= 1 - ready_nodes = await _handle_completed_node(result_object, node_id, pending_parents) - for node_id in ready_nodes: - unresolved_tasks += 1 - await _submit_task(result_object, node_id) - - if node_status == RESULT_STATUS.FAILED: - await _handle_failed_node(result_object, node_id) - continue - - if node_status == RESULT_STATUS.CANCELLED: - await _handle_cancelled_node(result_object, node_id) - continue - - if result_object._task_failed or result_object._task_cancelled: - app_log.debug(f"Workflow {result_object.dispatch_id} cancelled or failed") - failed_nodes = result_object._get_failed_nodes() - failed_nodes = map(lambda x: f"{x[0]}: {x[1]}", failed_nodes) - failed_nodes_msg = "\n".join(failed_nodes) - result_object._error = "The following tasks failed:\n" + failed_nodes_msg - result_object._status = ( - RESULT_STATUS.FAILED if result_object._task_failed else RESULT_STATUS.CANCELLED - ) - return result_object - - app_log.debug( - f"Tasks for {result_object.dispatch_id} finished running. Updating result webhook ..." 
- ) - await result_webhook.send_update(result_object) - return result_object - -def _plan_workflow(result_object: Result) -> None: + asyncio.create_task(coro) + else: + ts = datetime.now(timezone.utc) + for node_id in sorted_nodes: + app_log.debug(f"Skipping already completed node {dispatch_id}:{node_id}") + node_result = { + "node_id": node_id, + "start_time": ts, + "end_time": ts, + "status": RESULT_STATUS.COMPLETED, + } + await datasvc.update_node_result(dispatch_id, node_result) + app_log.debug("8A: Update node success (run_planned_workflow).") + + +async def _plan_workflow(dispatch_id: str) -> None: """ Function to plan a workflow according to a schedule. Planning means to decide which executors (along with their arguments) will be used by each node. Args: - result_object: Result object being used for current dispatch + dispatch_id: id of current dispatch Returns: None """ - if result_object.lattice.get_metadata("schedule"): - # Custom scheduling logic of the format: - # scheduled_executors = get_schedule(result_object) + pass - # for node_id, executor in scheduled_executors.items(): - # result_object.lattice.transport_graph.set_node_value(node_id, "executor", executor) - pass - -async def run_workflow(result_object: Result) -> Result: +async def run_workflow(dispatch_id: str, wait: bool = SYNC_DISPATCHES) -> RESULT_STATUS: """ Plan and run the workflow by loading the result object corresponding to the dispatch id and retrieving essential information from it. 
@@ -338,35 +276,46 @@ async def run_workflow(result_object: Result) -> Result: Returns: The result object from the workflow execution - """ - app_log.debug(f"Starting run_workflow for dispatch id {result_object.dispatch_id} ...") - if result_object.status == RESULT_STATUS.COMPLETED: - datasvc.finalize_dispatch(result_object.dispatch_id) - return result_object + + app_log.debug("Inside run_workflow.") + + # Ensure that the dispatch is run at most once + can_run = await datasvc.ensure_dispatch(dispatch_id) + + if not can_run: + result_info = await datasvc.dispatch.get(dispatch_id, ["status"]) + dispatch_status = result_info["status"] + app_log.debug(f"Cannot start dispatch {dispatch_id}: current status {dispatch_status}") + return dispatch_status try: - _plan_workflow(result_object) - status_queue = datasvc.get_status_queue(result_object.dispatch_id) - result_object = await _run_planned_workflow(result_object, status_queue) + await _plan_workflow(dispatch_id) - except Exception as ex: - app_log.error(f"Exception during _run_planned_workflow: {ex}") + if wait: + fut = asyncio.Future() + _futures[dispatch_id] = fut - error_msg = "".join(traceback.TracebackException.from_exception(ex).format()) - result_object._status = RESULT_STATUS.FAILED - result_object._error = error_msg - result_object._end_time = datetime.now(timezone.utc) + dispatch_status = await _submit_initial_tasks(dispatch_id) + + if wait: + app_log.debug(f"Waiting for dispatch {dispatch_id}") + dispatch_status = await fut + else: + app_log.debug(f"Running dispatch {dispatch_id} asynchronously") + + except Exception as ex: + dispatch_status = await _handle_dispatch_exception(dispatch_id, ex) finally: - await datasvc.persist_result(result_object.dispatch_id) - datasvc.finalize_dispatch(result_object.dispatch_id) + if dispatch_status != RESULT_STATUS.RUNNING: + datasvc.finalize_dispatch(dispatch_id) - return result_object + return dispatch_status # Domain: dispatcher -async def cancel_dispatch(dispatch_id: 
str, task_ids: List[int] = None) -> None: +async def cancel_dispatch(dispatch_id: str, task_ids: List[int] = []) -> None: """ Cancel an entire dispatch or a specific set of tasks within it @@ -377,38 +326,216 @@ async def cancel_dispatch(dispatch_id: str, task_ids: List[int] = None) -> None: Return(s) None """ - if task_ids is None: - task_ids = [] if not dispatch_id: return - tg = datasvc.get_result_object(dispatch_id=dispatch_id).lattice.transport_graph if task_ids: app_log.debug(f"Cancelling tasks {task_ids} in dispatch {dispatch_id}") else: - task_ids = list(tg._graph.nodes) + task_ids = await datasvc.graph.get_nodes(dispatch_id) + app_log.debug(f"Cancelling dispatch {dispatch_id}") - await set_cancel_requested(dispatch_id, task_ids) - await runner.cancel_tasks(dispatch_id, task_ids) + await jbmgr.set_cancel_requested(dispatch_id, task_ids) + await runner_ng.cancel_tasks(dispatch_id, task_ids) # Recursively cancel running sublattice dispatches - sub_ids = list(map(lambda x: tg.get_node_value(x, "sub_dispatch_id"), task_ids)) + attrs = await datasvc.electron.get_bulk(dispatch_id, task_ids, ["sub_dispatch_id"]) + sub_ids = list(map(lambda x: x["sub_dispatch_id"], attrs)) for sub_dispatch_id in sub_ids: await cancel_dispatch(sub_dispatch_id) def run_dispatch(dispatch_id: str) -> asyncio.Future: - """ - Run the workflow and return immediately + return asyncio.create_task(run_workflow(dispatch_id)) - Arg(s) - dispatch_id: Dispatch ID of the lattice - Return(s) - asyncio.Future +async def notify_node_status( + dispatch_id: str, node_id: int, status: RESULT_STATUS, detail: Dict = {} +): + msg = { + "dispatch_id": dispatch_id, + "node_id": node_id, + "status": status, + "detail": detail, + } - """ - app_log.debug(f"Running dispatch with dispatch_id: {dispatch_id}.") - result_object = datasvc.get_result_object(dispatch_id) - return asyncio.create_task(run_workflow(result_object)) + await _global_status_queue.put(msg) + + +async def _finalize_dispatch(dispatch_id: str): 
+ await _clear_caches(dispatch_id) + app_log.debug(f"Removed unresolved counter for {dispatch_id}") + + incomplete_tasks = await datasvc.dispatch.get_incomplete_tasks(dispatch_id) + failed = incomplete_tasks["failed"] + cancelled = incomplete_tasks["cancelled"] + if failed or cancelled: + app_log.debug(f"Workflow {dispatch_id} cancelled or failed") + failed_nodes = failed + failed_nodes = map(lambda x: f"{x[0]}: {x[1]}", failed_nodes) + failed_nodes_msg = "\n".join(failed_nodes) + error_msg = "The following tasks failed:\n" + failed_nodes_msg + ts = datetime.now(timezone.utc) + status = RESULT_STATUS.FAILED if failed else RESULT_STATUS.CANCELLED + result_update = datasvc.generate_dispatch_result( + dispatch_id, + status=status, + error=error_msg, + end_time=ts, + ) + await datasvc.dispatch.update(dispatch_id, result_update) + + app_log.debug("8: All tasks finished running (run_planned_workflow)") + + app_log.debug("Workflow already postprocessed") + + result_info = await datasvc.dispatch.get(dispatch_id, ["status"]) + return result_info["status"] + + +async def _initialize_caches(dispatch_id, pending_parents, sorted_task_groups): + for gid, indegree in pending_parents.items(): + await _pending_parents.set_pending(dispatch_id, gid, indegree) + + for gid, sorted_nodes in sorted_task_groups.items(): + await _sorted_task_groups.set_task_group(dispatch_id, gid, sorted_nodes) + + await _unresolved_tasks.set_unresolved(dispatch_id, 0) + + +async def _submit_initial_tasks(dispatch_id: str): + app_log.debug("3: Inside run_planned_workflow (run_planned_workflow).") + dispatch_result = datasvc.generate_dispatch_result( + dispatch_id, start_time=datetime.now(timezone.utc), status=RESULT_STATUS.RUNNING + ) + await datasvc.dispatch.update(dispatch_id, dispatch_result) + + app_log.debug(f"4: Workflow status changed to running {dispatch_id} (run_planned_workflow).") + app_log.debug("5: Wrote lattice status to DB (run_planned_workflow).") + + initial_groups, pending_parents, 
sorted_task_groups = await _get_initial_tasks_and_deps( + dispatch_id + ) + + await _initialize_caches(dispatch_id, pending_parents, sorted_task_groups) + + for gid in initial_groups: + sorted_nodes = sorted_task_groups[gid] + app_log.debug(f"Sorted nodes group group {gid}: {sorted_nodes}") + await _unresolved_tasks.increment(dispatch_id, len(sorted_nodes)) + + for gid in initial_groups: + sorted_nodes = sorted_task_groups[gid] + await _submit_task_group(dispatch_id, sorted_nodes, gid) + + return RESULT_STATUS.RUNNING + + +async def _handle_node_status_update(dispatch_id, node_id, node_status, detail): + app_log.debug(f"Received node status update {node_id}: {node_status}") + + if node_status == RESULT_STATUS.RUNNING: + return + + if node_status == RESULT_STATUS.DISPATCHING: + sub_dispatch_id = detail["sub_dispatch_id"] + run_dispatch(sub_dispatch_id) + app_log.debug(f"Running sublattice dispatch {sub_dispatch_id}") + + return + + # Terminal node statuses + + if node_status == RESULT_STATUS.COMPLETED: + next_task_groups = await _handle_completed_node(dispatch_id, node_id) + for gid in next_task_groups: + sorted_nodes = await _sorted_task_groups.get_task_group(dispatch_id, gid) + await _unresolved_tasks.increment(dispatch_id, len(sorted_nodes)) + await _submit_task_group(dispatch_id, sorted_nodes, gid) + + if node_status == RESULT_STATUS.FAILED: + await _handle_failed_node(dispatch_id, node_id) + + if node_status == RESULT_STATUS.CANCELLED: + await _handle_cancelled_node(dispatch_id, node_id) + + # Decrement after any increments to avoid race with + # finalize_dispatch() + await _unresolved_tasks.decrement(dispatch_id) + + +async def _handle_dispatch_exception(dispatch_id: str, ex: Exception) -> RESULT_STATUS: + error_msg = "".join(traceback.TracebackException.from_exception(ex).format()) + app_log.exception(f"Exception during _run_planned_workflow: {error_msg}") + + dispatch_result = datasvc.generate_dispatch_result( + dispatch_id, + 
end_time=datetime.now(timezone.utc), + status=RESULT_STATUS.FAILED, + error=error_msg, + ) + + await datasvc.dispatch.update(dispatch_id, dispatch_result) + return RESULT_STATUS.FAILED + + +# msg = { +# "dispatch_id": dispatch_id, +# "node_id": node_id, +# "status": status, +# "detail": detail, +# } +async def _node_event_listener(): + app_log.debug("Starting event listener") + while True: + msg = await _global_status_queue.get() + + asyncio.create_task(_handle_event(msg)) + + +async def _handle_event(msg: Dict): + dispatch_id = msg["dispatch_id"] + node_id = msg["node_id"] + node_status = msg["status"] + detail = msg["detail"] + + try: + await _handle_node_status_update(dispatch_id, node_id, node_status, detail) + + except Exception as ex: + dispatch_status = await _handle_dispatch_exception(dispatch_id, ex) + await datasvc.persist_result(dispatch_id) + fut = _futures.get(dispatch_id, None) + if fut: + fut.set_result(dispatch_status) + return dispatch_status + + unresolved = await _unresolved_tasks.get_unresolved(dispatch_id) + if unresolved < 1: + app_log.debug("Finalizing dispatch") + try: + dispatch_status = await _finalize_dispatch(dispatch_id) + except Exception as ex: + dispatch_status = await _handle_dispatch_exception(dispatch_id, ex) + + finally: + await datasvc.persist_result(dispatch_id) + fut = _futures.get(dispatch_id, None) + if fut: + fut.set_result(dispatch_status) + + return dispatch_status + + +async def _clear_caches(dispatch_id: str): + """Clean up all keys in caches.""" + await _unresolved_tasks.remove(dispatch_id) + + g_node_link = await datasvc.graph.get_nodes_links(dispatch_id) + g = nx.readwrite.node_link_graph(g_node_link) + task_groups = set([g.nodes[i]["task_group_id"] for i in g.nodes]) + for gid in task_groups: + # Clean up no longer referenced keys + await _pending_parents.remove(dispatch_id, gid) + await _sorted_task_groups.remove(dispatch_id, gid) diff --git a/covalent_dispatcher/_core/dispatcher_modules/__init__.py 
b/covalent_dispatcher/_core/dispatcher_modules/__init__.py new file mode 100644 index 0000000000..9d1b05526a --- /dev/null +++ b/covalent_dispatcher/_core/dispatcher_modules/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. diff --git a/covalent_dispatcher/_core/dispatcher_modules/caches.py b/covalent_dispatcher/_core/dispatcher_modules/caches.py new file mode 100644 index 0000000000..c5f0351b70 --- /dev/null +++ b/covalent_dispatcher/_core/dispatcher_modules/caches.py @@ -0,0 +1,105 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. 
+# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Helper classes for the dispatcher +""" + +from .store import _DictStore, _KeyValueBase + + +def _pending_parents_key(dispatch_id: str, node_id: int): + return f"pending-parents-{dispatch_id}:{node_id}" + + +def _unresolved_tasks_key(dispatch_id: str): + return f"unresolved-{dispatch_id}" + + +def _task_groups_key(dispatch_id: str, task_group_id: int): + return f"task-groups-{dispatch_id}:{task_group_id}" + + +class _UnresolvedTasksCache: + def __init__(self, store: _KeyValueBase = _DictStore()): + self._store = store + + async def get_unresolved(self, dispatch_id: str): + key = _unresolved_tasks_key(dispatch_id) + return await self._store.get(key) + + async def set_unresolved(self, dispatch_id: str, val: int): + key = _unresolved_tasks_key(dispatch_id) + await self._store.insert(key, val) + + async def increment(self, dispatch_id: str, interval: int = 1): + key = _unresolved_tasks_key(dispatch_id) + return await self._store.increment(key, interval) + + async def decrement(self, dispatch_id: str): + key = _unresolved_tasks_key(dispatch_id) + return await self._store.increment(key, -1) + + async def remove(self, dispatch_id: str): + key = _unresolved_tasks_key(dispatch_id) + await self._store.remove(key) + + +class _PendingParentsCache: + def __init__(self, store: _KeyValueBase = _DictStore()): + self._store = store + + async def get_pending(self, dispatch_id: str, task_group_id: int): + key = _pending_parents_key(dispatch_id, task_group_id) + return await self._store.get(key) + + async def set_pending(self, dispatch_id: str, task_group_id: int, val: int): + key = _pending_parents_key(dispatch_id, task_group_id) + await self._store.insert(key, val) + + async def 
decrement(self, dispatch_id: str, task_group_id: int): + key = _pending_parents_key(dispatch_id, task_group_id) + return await self._store.increment(key, -1) + + async def remove(self, dispatch_id: str, task_group_id: int): + key = _pending_parents_key(dispatch_id, task_group_id) + await self._store.remove(key) + + +class _SortedTaskGroups: + def __init__(self, store: _KeyValueBase = _DictStore()): + self._store = store + + async def get_task_group(self, dispatch_id: str, task_group_id: int): + key = _task_groups_key(dispatch_id, task_group_id) + return await self._store.get(key) + + async def set_task_group(self, dispatch_id: str, task_group_id: int, sorted_nodes: list): + key = _task_groups_key(dispatch_id, task_group_id) + await self._store.insert(key, sorted_nodes) + + async def remove(self, dispatch_id: str, task_group_id: int): + key = _task_groups_key(dispatch_id, task_group_id) + await self._store.remove(key) + + +_pending_parents = _PendingParentsCache() +_unresolved_tasks = _UnresolvedTasksCache() +_sorted_task_groups = _SortedTaskGroups() diff --git a/covalent_dispatcher/_core/dispatcher_modules/store.py b/covalent_dispatcher/_core/dispatcher_modules/store.py new file mode 100644 index 0000000000..8c98c3275d --- /dev/null +++ b/covalent_dispatcher/_core/dispatcher_modules/store.py @@ -0,0 +1,70 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. 
+# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Simple Key-Value store base +""" + + +class _KeyValueBase: + async def get(self, key): + raise NotImplementedError + + async def insert(self, key, val): + raise NotImplementedError + + async def belongs(self, key): + raise NotImplementedError + + async def remove(self, key): + raise NotImplementedError + + async def increment(self, key: str, delta: int) -> int: + """Increments value for `key` by amount `delta` + + Parameters: + key: the value to change + delta: the amount to change (can be negative) + Returns: + The new value + """ + + raise NotImplementedError + + +class _DictStore(_KeyValueBase): + def __init__(self): + self._store = {} + + async def get(self, key): + return self._store[key] + + async def insert(self, key, val): + self._store[key] = val + + async def belongs(self, key): + return key in self._store + + async def remove(self, key): + del self._store[key] + + async def increment(self, key, delta: int): + self._store[key] += delta + return self._store[key] diff --git a/covalent_dispatcher/_core/execution.py b/covalent_dispatcher/_core/execution.py index 20c3315d1f..14a8a4a034 100644 --- a/covalent_dispatcher/_core/execution.py +++ b/covalent_dispatcher/_core/execution.py @@ -22,12 +22,45 @@ Defines the core functionality of the dispatcher """ +# Legacy imports +# from .dispatcher import ( +# _build_sublattice_graph, +# _dispatch_sync_sublattice, +# _get_abstract_task_inputs, +# _handle_cancelled_node, +# _handle_completed_node, +# _handle_failed_node, +# _initialize_deps_and_queue, +# _plan_workflow, +# _post_process, +# _postprocess_workflow, +# _run_planned_workflow, +# _submit_task, +# cancel_workflow, +# run_workflow, +# ) +# from .result 
import ( +# _update_node_result, +# generate_node_result, +# get_unique_id, +# initialize_result_object, +# ) +# from .runner import ( +# _gather_deps, +# _get_task_input_values, +# _get_task_inputs, +# _run_abstract_task, +# _run_task, +# _run_task_and_update, +# ) + + from covalent._results_manager import Result -from . import dispatcher, runner +from . import runner -def _get_task_inputs(node_id: int, node_name: str, result_object: Result) -> dict: +async def _get_task_inputs(node_id: int, node_name: str, result_object: Result) -> dict: """ Return the required inputs for a task execution. This makes sure that any node with child nodes isn't executed twice and fetches the @@ -44,8 +77,24 @@ def _get_task_inputs(node_id: int, node_name: str, result_object: Result) -> dic and any parent node execution results if present. """ - abstract_inputs = dispatcher._get_abstract_task_inputs(node_id, node_name, result_object) - input_values = runner._get_task_input_values(result_object, abstract_inputs) + abstract_inputs = {"args": [], "kwargs": {}} + + for parent in result_object.lattice.transport_graph.get_dependencies(node_id): + edge_data = result_object.lattice.transport_graph.get_edge_data(parent, node_id) + # value = result_object.lattice.transport_graph.get_node_value(parent, "output") + + for e_key, d in edge_data.items(): + if not d.get("wait_for"): + if d["param_type"] == "arg": + abstract_inputs["args"].append((parent, d["arg_index"])) + elif d["param_type"] == "kwarg": + key = d["edge_name"] + abstract_inputs["kwargs"][key] = parent + + sorted_args = sorted(abstract_inputs["args"], key=lambda x: x[1]) + abstract_inputs["args"] = [x[0] for x in sorted_args] + + input_values = await runner._get_task_input_values(result_object.dispatch_id, abstract_inputs) abstract_args = abstract_inputs["args"] abstract_kwargs = abstract_inputs["kwargs"] diff --git a/covalent_dispatcher/_core/runner.py b/covalent_dispatcher/_core/runner.py index c87d5111f9..615df4de47 100644 --- 
a/covalent_dispatcher/_core/runner.py +++ b/covalent_dispatcher/_core/runner.py @@ -23,132 +23,61 @@ """ import asyncio -import json import traceback -from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone from functools import partial -from typing import Any, Dict, List, Literal, Tuple, Union +from typing import Any, Dict, List, Tuple -from covalent._results_manager import Result from covalent._shared_files import logger from covalent._shared_files.config import get_config from covalent._shared_files.util_classes import RESULT_STATUS from covalent._workflow import DepsBash, DepsCall, DepsPip -from covalent.executor import _executor_manager -from covalent.executor.base import AsyncBaseExecutor, wrapper_fn +from covalent.executor.utils.wrappers import wrapper_fn from . import data_manager as datasvc -from .data_modules.job_manager import get_jobs_metadata, set_cancel_result from .runner_modules import executor_proxy +from .runner_modules.cancel import cancel_tasks # nopycln: import +from .runner_modules.utils import get_executor app_log = logger.app_log log_stack_info = logger.log_stack_info debug_mode = get_config("sdk.log_level") == "debug" -_cancel_threadpool = ThreadPoolExecutor() - - -# Domain: runner -def get_executor( - executor: Union[Tuple, List], - loop: asyncio.BaseEventLoop = None, - cancel_pool: ThreadPoolExecutor = None, -) -> AsyncBaseExecutor: - """Get unpacked and initialized executor object. - - Args: - executor: Tuple containing short name and object dictionary for the executor. - loop: Running event loop. Defaults to None. - cancel_pool: Threadpool for cancelling tasks. Defaults to None. - - Returns: - Executor object. 
- - """ - short_name, object_dict = executor - executor = _executor_manager.get_executor(short_name) - executor.from_dict(object_dict) - executor._init_runtime(loop=loop, cancel_pool=cancel_pool) - - return executor - # Domain: runner # to be called by _run_abstract_task -def _get_task_input_values(result_object: Result, abs_task_inputs: dict) -> dict: - """ - Retrieve the input values from the result_object for the task - - Arg(s) - result_object: Result object of the workflow - abs_task_inputs: Task inputs dictionary - - Return(s) - node_values: Dictionary of task inputs - - """ +async def _get_task_input_values(dispatch_id: str, abs_task_inputs: dict) -> dict: node_values = {} args = abs_task_inputs["args"] for node_id in args: - value = result_object.lattice.transport_graph.get_node_value(node_id, "output") + value = (await datasvc.electron.get(dispatch_id, node_id, ["output"]))["output"] node_values[node_id] = value kwargs = abs_task_inputs["kwargs"] - for _, node_id in kwargs.items(): - value = result_object.lattice.transport_graph.get_node_value(node_id, "output") + for key, node_id in kwargs.items(): + value = (await datasvc.electron.get(dispatch_id, node_id, ["output"]))["output"] node_values[node_id] = value return node_values -# Domain: runner -async def run_abstract_task( - dispatch_id: str, - node_id: int, - node_name: str, - abstract_inputs: Dict, - executor: Any, -) -> None: - node_result = await _run_abstract_task( - dispatch_id=dispatch_id, - node_id=node_id, - node_name=node_name, - abstract_inputs=abstract_inputs, - executor=executor, - ) - - result_object = datasvc.get_result_object(dispatch_id) - await datasvc.update_node_result(result_object, node_result) - - # Domain: runner async def _run_abstract_task( dispatch_id: str, node_id: int, node_name: str, abstract_inputs: Dict, - executor: Any, + selected_executor: Any, ) -> None: # Resolve abstract task and inputs to their concrete (serialized) values - result_object = 
datasvc.get_result_object(dispatch_id) timestamp = datetime.now(timezone.utc) try: - cancel_req = await executor_proxy._get_cancel_requested(dispatch_id, node_id) - if cancel_req: - app_log.debug(f"Don't run cancelled task {dispatch_id}:{node_id}") - return datasvc.generate_node_result( - node_id=node_id, - node_name=node_name, - start_time=timestamp, - end_time=timestamp, - status=RESULT_STATUS.CANCELLED, - ) - serialized_callable = result_object.lattice.transport_graph.get_node_value( - node_id, "function" - ) - input_values = _get_task_input_values(result_object, abstract_inputs) + serialized_callable = (await datasvc.electron.get(dispatch_id, node_id, ["function"]))[ + "function" + ] + + input_values = await _get_task_input_values(dispatch_id, abstract_inputs) abstract_args = abstract_inputs["args"] abstract_kwargs = abstract_inputs["kwargs"] @@ -158,33 +87,33 @@ async def _run_abstract_task( app_log.debug(f"Collecting deps for task {node_id}") - call_before, call_after = _gather_deps(result_object, node_id) + call_before, call_after = await _gather_deps(dispatch_id, node_id) except Exception as ex: app_log.error(f"Exception when trying to resolve inputs or deps: {ex}") - return datasvc.generate_node_result( + node_result = datasvc.generate_node_result( node_id=node_id, - node_name=node_name, start_time=timestamp, end_time=timestamp, status=RESULT_STATUS.FAILED, error=str(ex), ) + return node_result + node_result = datasvc.generate_node_result( node_id=node_id, - node_name=node_name, start_time=timestamp, status=RESULT_STATUS.RUNNING, ) app_log.debug(f"7: Marking node {node_id} as running (_run_abstract_task)") - await datasvc.update_node_result(result_object, node_result) + await datasvc.update_node_result(dispatch_id, node_result) return await _run_task( - result_object=result_object, + dispatch_id=dispatch_id, node_id=node_id, serialized_callable=serialized_callable, - executor=executor, + selected_executor=selected_executor, node_name=node_name, 
call_before=call_before, call_after=call_after, @@ -194,11 +123,11 @@ async def _run_abstract_task( # Domain: runner async def _run_task( - result_object: Result, + dispatch_id: str, node_id: int, inputs: Dict, serialized_callable: Any, - executor: Any, + selected_executor: Any, call_before: List, call_after: List, node_name: str, @@ -213,39 +142,52 @@ async def _run_task( Args: inputs: Inputs for the task. - result_object: Result object being used for current dispatch node_id: Node id of the task to be executed. Returns: None - """ - dispatch_id = result_object.dispatch_id - results_dir = result_object.results_dir + + dispatch_info = await datasvc.dispatch.get(dispatch_id, ["results_dir"]) + results_dir = dispatch_info["results_dir"] # Instantiate the executor from JSON try: - executor = get_executor(executor=executor, loop=asyncio.get_running_loop()) - + app_log.debug(f"Instantiating executor for {dispatch_id}:{node_id}") + executor = get_executor( + node_id=node_id, + selected_executor=selected_executor, + loop=asyncio.get_running_loop(), + pool=None, + ) except Exception as ex: tb = "".join(traceback.TracebackException.from_exception(ex).format()) app_log.debug("Exception when trying to instantiate executor:") app_log.debug(tb) error_msg = tb if debug_mode else str(ex) - return datasvc.generate_node_result( + node_result = datasvc.generate_node_result( node_id=node_id, - node_name=node_name, end_time=datetime.now(timezone.utc), status=RESULT_STATUS.FAILED, error=error_msg, ) + return node_result - # Run the task on the executor and register any failures. 
+ # run the task on the executor and register any failures try: app_log.debug(f"Executing task {node_name}") assembled_callable = partial(wrapper_fn, serialized_callable, call_before, call_after) + execute_callable = partial( + executor.execute, + function=assembled_callable, + args=inputs["args"], + kwargs=inputs["kwargs"], + dispatch_id=dispatch_id, + results_dir=results_dir, + node_id=node_id, + ) - # Note: Executor proxy monitors the executors instances and watches the send and receive queues of the executor. + # Start listening for messages from the plugin asyncio.create_task(executor_proxy.watch(dispatch_id, node_id, executor)) output, stdout, stderr, status = await executor._execute( @@ -259,7 +201,6 @@ async def _run_task( node_result = datasvc.generate_node_result( node_id=node_id, - node_name=node_name, end_time=datetime.now(timezone.utc), status=status, output=output, @@ -274,28 +215,28 @@ async def _run_task( error_msg = tb if debug_mode else str(ex) node_result = datasvc.generate_node_result( node_id=node_id, - node_name=node_name, end_time=datetime.now(timezone.utc), status=RESULT_STATUS.FAILED, error=error_msg, ) + app_log.debug(f"Node result: {node_result}") return node_result # Domain: runner -def _gather_deps(result_object: Result, node_id: int) -> Tuple[List, List]: +async def _gather_deps(dispatch_id: str, node_id: int) -> Tuple[List, List]: """Assemble deps for a node into the final call_before and call_after""" - deps = result_object.lattice.transport_graph.get_node_value(node_id, "metadata")["deps"] + deps_attrs = await datasvc.electron.get( + dispatch_id, node_id, ["deps", "call_before", "call_after"] + ) + + deps = deps_attrs["deps"] # Assemble call_before and call_after from all the deps - call_before_objs_json = result_object.lattice.transport_graph.get_node_value( - node_id, "metadata" - )["call_before"] - call_after_objs_json = result_object.lattice.transport_graph.get_node_value( - node_id, "metadata" - )["call_after"] + 
call_before_objs_json = deps_attrs["call_before"] + call_after_objs_json = deps_attrs["call_after"] call_before = [] call_after = [] @@ -324,97 +265,19 @@ def _gather_deps(result_object: Result, node_id: int) -> Tuple[List, List]: return call_before, call_after -async def _cancel_task( - dispatch_id: str, task_id: int, executor, executor_data: Dict, job_handle: str -) -> Union[Any, Literal[False]]: - """ - Cancel the task currently being executed by the executor - - Arg(s) - dispatch_id: Dispatch ID - task_id: Task ID of the electron in transport graph to be cancelled - executor: Covalent executor currently being used to execute the task - executor_data: Executor configuration arguments - job_handle: Unique identifier assigned to the task by the backend running the job - - Return(s) - cancel_job_result: Status of the job cancellation action - - """ - app_log.debug(f"Cancel task {task_id} using executor {executor}, {executor_data}") - app_log.debug(f"job_handle: {job_handle}") - - try: - executor = get_executor( - executor=executor, loop=asyncio.get_running_loop(), cancel_pool=_cancel_threadpool - ) - task_metadata = {"dispatch_id": dispatch_id, "node_id": task_id} - cancel_job_result = await executor._cancel(task_metadata, json.loads(job_handle)) - - except Exception as ex: - app_log.debug(f"Exception when cancel task {dispatch_id}:{task_id}: {ex}") - cancel_job_result = False - - await set_cancel_result(dispatch_id, task_id, cancel_job_result) - return cancel_job_result - - -def to_cancel_kwargs( - index: int, node_id: int, node_metadata: List[dict], job_metadata: List[dict] -) -> dict: - """ - Convert node_metadata for a given node `node_id` into a dictionary - - Arg(s) - index: Index into the node_metadata list - node_id: Node ID - node_metadata: List of node metadata attributes - job_metadata: List of metadata for the current job - - Return(s) - Node metadata dictionary - """ - return { - "task_id": node_id, - "executor": node_metadata[index]["executor"], - 
"executor_data": node_metadata[index]["executor_data"], - "job_handle": job_metadata[index]["job_handle"], - } - - -async def cancel_tasks(dispatch_id: str, task_ids: List[int]) -> None: - """ - Request all tasks with `task_ids` to be cancelled in the workflow identified by `dispatch_id` - - Arg(s) - dispatch_id: Dispatch ID of the workflow - task_ids: List of task ids to be cancelled - - Return(s) - None - """ - job_metadata = await get_jobs_metadata(dispatch_id, task_ids) - node_metadata = _get_metadata_for_nodes(dispatch_id, task_ids) - - cancel_task_kwargs = [ - to_cancel_kwargs(i, x, node_metadata, job_metadata) for i, x in enumerate(task_ids) - ] - - for kwargs in cancel_task_kwargs: - asyncio.create_task(_cancel_task(dispatch_id, **kwargs)) - - -def _get_metadata_for_nodes(dispatch_id: str, node_ids: list) -> List[Any]: - """ - Returns all the metadata associated with the node(s) for the workflow identified by `dispatch_id` - - Arg(s) - dispatch_id: Dispatch ID of the workflow - node_ids: List of node ids from the workflow to retrieve the metadata for - - Return(s) - List of node metadata for the given `node_ids` - """ - res = datasvc.get_result_object(dispatch_id) - tg = res.lattice.transport_graph - return list(map(lambda x: tg.get_node_value(x, "metadata"), node_ids)) +# Domain: runner +async def run_abstract_task( + dispatch_id: str, + node_id: int, + node_name: str, + abstract_inputs: Dict, + selected_executor: Any, +) -> None: + node_result = await _run_abstract_task( + dispatch_id=dispatch_id, + node_id=node_id, + node_name=node_name, + abstract_inputs=abstract_inputs, + selected_executor=selected_executor, + ) + await datasvc.update_node_result(dispatch_id, node_result) diff --git a/covalent_dispatcher/_core/runner_modules/cancel.py b/covalent_dispatcher/_core/runner_modules/cancel.py new file mode 100644 index 0000000000..107c4eb654 --- /dev/null +++ b/covalent_dispatcher/_core/runner_modules/cancel.py @@ -0,0 +1,150 @@ +# Copyright 2021 Agnostiq 
Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Functions for cancelling jobs +""" + +import asyncio +import json +from concurrent.futures import ThreadPoolExecutor +from typing import Any, List + +from covalent._shared_files import logger +from covalent._shared_files.util_classes import RESULT_STATUS + +from .. 
import data_manager as datasvc +from ..data_modules import job_manager +from .utils import get_executor + +app_log = logger.app_log + +# Dedicated thread pool for invoking non-async Executor.cancel() +_cancel_threadpool = ThreadPoolExecutor() + +# Collects asyncio task futures +_background_tasks = set() + + +async def _cancel_task( + dispatch_id: str, task_id: int, selected_executor: List, job_handle: str +) -> None: + """ + Cancel the task currently being executed by the executor + + Arg(s) + dispatch_id: Dispatch ID + task_id: Task ID of the electron in transport graph to be cancelled + executor: Covalent executor currently being used to execute the task + executor_data: Executor configuration arguments + job_handle: Unique identifier assigned to the task by the backend running the job + + Return(s) + cancel_job_result: Status of the job cancellation action + """ + app_log.debug(f"Cancel task {task_id} using executor {selected_executor}") + app_log.debug(f"job_handle: {job_handle}") + + try: + executor = get_executor( + node_id=task_id, + selected_executor=selected_executor, + loop=asyncio.get_running_loop(), + pool=_cancel_threadpool, + ) + + task_metadata = {"dispatch_id": dispatch_id, "node_id": task_id} + + cancel_job_result = await executor._cancel(task_metadata, json.loads(job_handle)) + except Exception as ex: + app_log.debug(f"Exception when cancel task {dispatch_id}:{task_id}: {ex}") + cancel_job_result = False + + if cancel_job_result is True: + await job_manager.set_job_status(dispatch_id, task_id, str(RESULT_STATUS.CANCELLED)) + app_log.debug(f"Cancelled task {dispatch_id}:{task_id}") + + +def _to_cancel_kwargs( + index: int, node_id: int, node_metadata: List[dict], job_metadata: List[dict] +) -> dict: + """ + Convert node_metadata for a given node `node_id` into a dictionary + + Arg(s) + index: Index into the node_metadata list + node_id: Node ID + node_metadata: List of node metadata attributes + job_metadata: List of metadata for the current job + 
+ Return(s) + Node metadata dictionary + """ + selected_executor = [node_metadata[index]["executor"], node_metadata[index]["executor_data"]] + return { + "task_id": node_id, + "selected_executor": selected_executor, + "job_handle": job_metadata[index]["job_handle"], + } + + +async def cancel_tasks(dispatch_id: str, task_ids: List[int]) -> None: + """ + Request all tasks with `task_ids` to be cancelled in the workflow identified by `dispatch_id` + + Arg(s) + dispatch_id: Dispatch ID of the workflow + task_ids: List of task ids to be cancelled + + Return(s) + None + """ + job_metadata = await job_manager.get_jobs_metadata(dispatch_id, task_ids) + node_metadata = await _get_metadata_for_nodes(dispatch_id, task_ids) + app_log.debug(f"node metadata: {node_metadata}") + app_log.debug(f"job metadata: {job_metadata}") + cancel_task_kwargs = [ + _to_cancel_kwargs(i, x, node_metadata, job_metadata) for i, x in enumerate(task_ids) + ] + + for kwargs in cancel_task_kwargs: + fut = asyncio.create_task(_cancel_task(dispatch_id, **kwargs)) + _background_tasks.add(fut) + fut.add_done_callback(_background_tasks.discard) + + +async def _get_metadata_for_nodes(dispatch_id: str, node_ids: list) -> List[Any]: + """ + Returns all the metadata associated with the node(s) for the workflow identified by `dispatch_id` + + Arg(s) + dispatch_id: Dispatch ID of the workflow + node_ids: List of node ids from the workflow to retrive the metadata for + + Return(s) + List of node metadata for the given `node_ids` + """ + + attrs = await datasvc.electron.get_bulk( + dispatch_id, + node_ids, + ["executor", "executor_data"], + ) + return attrs diff --git a/covalent_dispatcher/_core/runner_modules/executor_proxy.py b/covalent_dispatcher/_core/runner_modules/executor_proxy.py index 17ae83c235..6fe837bca7 100644 --- a/covalent_dispatcher/_core/runner_modules/executor_proxy.py +++ b/covalent_dispatcher/_core/runner_modules/executor_proxy.py @@ -18,7 +18,7 @@ # # Relief from the License may be granted by 
purchasing a commercial license. -""" Monitor executor instances.""" +""" Monitor executor instances """ from typing import Any @@ -27,7 +27,13 @@ from covalent.executor.base import _AbstractBaseExecutor as _ABE from covalent.executor.utils import Signals -from ..data_modules import job_manager +from .jobs import ( + get_cancel_requested, + get_job_status, + get_version_info, + put_job_handle, + put_job_status, +) app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -36,50 +42,11 @@ _getters = {} -async def _get_cancel_requested(dispatch_id: str, task_id: int): - """ - Query the database for the task's cancellation status - - Arg(s) - dispatch_id: Dispatch ID of the lattice - task_id: ID of the task within the lattice - - Return(s) - Cancellation status of the task - - """ - # Don't hit the DB for post-processing task - if task_id < 0: - return False - - app_log.debug(f"Get _handle_requested for executor {dispatch_id}:{task_id}") - job_records = await job_manager.get_jobs_metadata(dispatch_id, [task_id]) - app_log.debug(f"Job record: {job_records[0]}") - return job_records[0]["cancel_requested"] - - -async def _put_job_handle(dispatch_id: str, task_id: int, job_handle: str) -> bool: - """ - Store the job handle of the task returned by the backend in the database - - Arg(s) - dispatch_id: Dispatch ID of the lattice - task_id: ID of the task within the lattice - job_handle: Unique identifier of the task returned by the execution backend - - Return(s) - True - """ - # Don't hit the DB for post-processing task - if task_id < 0: - return False - app_log.debug(f"Put job_handle for executor {dispatch_id}:{task_id}") - await job_manager.set_job_handle(dispatch_id, task_id, job_handle) - return True - - -_putters["job_handle"] = _put_job_handle -_getters["cancel_requested"] = _get_cancel_requested +_putters["job_handle"] = put_job_handle +_putters["job_status"] = put_job_status +_getters["cancel_requested"] = get_cancel_requested +_getters["job_status"] = 
get_job_status +_getters["version_info"] = get_version_info async def _handle_message( diff --git a/covalent_dispatcher/_core/runner_modules/jobs.py b/covalent_dispatcher/_core/runner_modules/jobs.py new file mode 100644 index 0000000000..dc0529dd11 --- /dev/null +++ b/covalent_dispatcher/_core/runner_modules/jobs.py @@ -0,0 +1,130 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" Handlers for the executor proxy """ + + +from covalent._shared_files import logger +from covalent._shared_files.util_classes import Status +from covalent_dispatcher._core import data_manager as datamgr + +from .. 
import data_manager as datasvc +from ..data_modules import job_manager +from .utils import get_executor + +app_log = logger.app_log +log_stack_info = logger.log_stack_info + + +async def get_cancel_requested(dispatch_id: str, task_id: int): + """ + Query the database for the task's cancellation status + + Arg(s) + dispatch_id: Dispatch ID of the lattice + task_id: ID of the task within the lattice + + Return(s) + Cancellation status of the task + """ + + app_log.debug(f"Get _handle_requested for task {dispatch_id}:{task_id}") + job_records = await job_manager.get_jobs_metadata(dispatch_id, [task_id]) + app_log.debug(f"Job record: {job_records[0]}") + return job_records[0]["cancel_requested"] + + +async def get_version_info(dispatch_id: str, task_id: int): + """ + Query the database for the dispatch version information + + Arg: + dispatch_id: Dispatch ID of the lattice + task_id: ID of the task within the lattice + + Returns: + {"python": python_version, "covalent": covalent_version} + """ + + data = await datamgr.lattice.get(dispatch_id, ["python_version", "covalent_version"]) + + return { + "python": data["python_version"], + "covalent": data["covalent_version"], + } + + +async def get_job_status(dispatch_id: str, task_id: int) -> Status: + """ + Queries the job state for (dispatch_id, task_id) + + Arg(s) + dispatch_id: Dispatch ID of the lattice + task_id: ID of the task within the lattice + + Return(s) + Status + """ + app_log.debug(f"Get for task {dispatch_id}:{task_id}") + job_records = await job_manager.get_jobs_metadata(dispatch_id, [task_id]) + app_log.debug(f"Job record: {job_records[0]}") + return Status(job_records[0]["status"]) + + +async def put_job_handle(dispatch_id: str, task_id: int, job_handle: str) -> bool: + """ + Store the job handle of the task returned by the backend in the database + + Arg(s) + dispatch_id: Dispatch ID of the lattice + task_id: ID of the task within the lattice + job_handle: Unique identifier of the task returned by the
execution backend + + Return(s) + True + """ + app_log.debug(f"Put job_handle for executor {dispatch_id}:{task_id}") + await job_manager.set_job_handle(dispatch_id, task_id, job_handle) + return True + + +async def put_job_status(dispatch_id: str, task_id: int, status: Status) -> bool: + """ + Store the job status for (dispatch_id, task_id) if the task's executor accepts it + + Arg(s) + dispatch_id: Dispatch ID of the lattice + task_id: ID of the task within the lattice + status: A `Status` type representing the job status + + Return(s) + True if the executor validated the status and it was stored, False otherwise + """ + app_log.debug(f"Put cancel result for task {dispatch_id}:{task_id}") + executor_attrs = await datasvc.electron.get( + dispatch_id, task_id, ["executor", "executor_data"] + ) + selected_executor = [executor_attrs["executor"], executor_attrs["executor_data"]] + executor = get_executor(task_id, selected_executor, None, None) + if executor.validate_status(status): + await job_manager.set_job_status(dispatch_id, task_id, str(status)) + return True + else: + return False diff --git a/covalent_dispatcher/_core/runner_modules/utils.py b/covalent_dispatcher/_core/runner_modules/utils.py new file mode 100644 index 0000000000..18886ec6e0 --- /dev/null +++ b/covalent_dispatcher/_core/runner_modules/utils.py @@ -0,0 +1,47 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE.
See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Defines the core functionality of the runner +""" + +from covalent._shared_files import logger +from covalent._shared_files.config import get_config +from covalent.executor import _executor_manager +from covalent.executor.base import AsyncBaseExecutor + +app_log = logger.app_log +log_stack_info = logger.log_stack_info +debug_mode = get_config("sdk.log_level") == "debug" + + +def get_executor(node_id, selected_executor, loop=None, pool=None) -> AsyncBaseExecutor: + # Instantiate the executor from JSON + + short_name, object_dict = selected_executor + + app_log.debug(f"Running task {node_id} using executor {short_name}, {object_dict}") + + # the executor is determined during scheduling and provided in the execution metadata + executor = _executor_manager.get_executor(short_name) + executor.from_dict(object_dict) + executor._init_runtime(loop=loop, cancel_pool=pool) + + return executor diff --git a/covalent_dispatcher/_db/load.py b/covalent_dispatcher/_db/load.py deleted file mode 100644 index 829021ff1e..0000000000 --- a/covalent_dispatcher/_db/load.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright 2023 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the GNU Affero General Public License 3.0 (the "License"). -# A copy of the License may be obtained with this software package or at -# -# https://www.gnu.org/licenses/agpl-3.0.en.html -# -# Use of this file is prohibited except in compliance with the License. Any -# modifications or derivative works of this file must retain this copyright -# notice, and modified files must contain a notice indicating that they have -# been altered from the originals. -# -# Covalent is distributed in the hope that it will be useful, but WITHOUT -# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or -# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. 
-# -# Relief from the License may be granted by purchasing a commercial license. - -"""Functions to load results from the database.""" - - -from typing import Dict, Union - -from covalent import lattice -from covalent._results_manager.result import Result -from covalent._shared_files import logger -from covalent._shared_files.util_classes import Status -from covalent._workflow.transport import TransportableObject -from covalent._workflow.transport import _TransportGraph as SDKGraph -from covalent_dispatcher._dal.electron import ASSET_KEYS as ELECTRON_ASSETS -from covalent_dispatcher._dal.electron import METADATA_KEYS as ELECTRON_META -from covalent_dispatcher._dal.result import get_result_object -from covalent_dispatcher._dal.tg import _TransportGraph as SRVGraph -from covalent_dispatcher._object_store.local import local_store - -from .datastore import workflow_db -from .models import Electron, Lattice - -app_log = logger.app_log -log_stack_info = logger.log_stack_info - -NODE_ATTRIBUTES = ELECTRON_META.union(ELECTRON_ASSETS) -SDK_NODE_META_KEYS = { - "executor", - "executor_data", - "deps", - "call_before", - "call_after", -} - - -def load_file(storage_path, filename): - return local_store.load_file(storage_path, filename) - - -def _to_client_graph(srv_graph: SRVGraph) -> SDKGraph: - """Render a SDK _TransportGraph from a server-side graph""" - - sdk_graph = SDKGraph() - - sdk_graph._graph = srv_graph.get_internal_graph_copy() - for node_id in srv_graph._graph.nodes: - attrs = list(sdk_graph._graph.nodes[node_id].keys()) - for k in attrs: - del sdk_graph._graph.nodes[node_id][k] - attributes = {} - for k in NODE_ATTRIBUTES: - if k not in SDK_NODE_META_KEYS: - attributes[k] = srv_graph.get_node_value(node_id, k) - if srv_graph.get_node_value(node_id, "type") == "parameter": - attributes["value"] = srv_graph.get_node_value(node_id, "value") - attributes["output"] = srv_graph.get_node_value(node_id, "output") - - node_meta = {k: srv_graph.get_node_value(node_id, k) 
for k in SDK_NODE_META_KEYS} - attributes["metadata"] = node_meta - - for k, v in attributes.items(): - sdk_graph.set_node_value(node_id, k, v) - - sdk_graph.lattice_metadata = {} - - return sdk_graph - - -def _result_from(lattice_record: Lattice) -> Result: - """Re-hydrate result object from the lattice record. - - Args: - lattice_record: Lattice record to re-hydrate from. - - Returns: - Result object. - - """ - - srv_res = get_result_object(lattice_record.dispatch_id, bare=False) - - function = srv_res.lattice.get_value("workflow_function") - - function_string = srv_res.lattice.get_value("workflow_function_string") - function_docstring = srv_res.lattice.get_value("doc") - - executor_data = srv_res.lattice.get_value("executor_data") - - workflow_executor_data = srv_res.lattice.get_value("workflow_executor_data") - - inputs = srv_res.lattice.get_value("inputs") - named_args = srv_res.lattice.get_value("named_args") - named_kwargs = srv_res.lattice.get_value("named_kwargs") - error = srv_res.get_value("error") - - transport_graph = _to_client_graph(srv_res.lattice.transport_graph) - - output = srv_res.get_value("result") - deps = srv_res.lattice.get_value("deps") - call_before = srv_res.lattice.get_value("call_before") - call_after = srv_res.lattice.get_value("call_after") - cova_imports = srv_res.lattice.get_value("cova_imports") - lattice_imports = srv_res.lattice.get_value("lattice_imports") - - name = lattice_record.name - executor = lattice_record.executor - workflow_executor = lattice_record.workflow_executor - num_nodes = lattice_record.electron_num - - attributes = { - "workflow_function": function, - "workflow_function_string": function_string, - "__name__": name, - "__doc__": function_docstring, - "metadata": { - "executor": executor, - "executor_data": executor_data, - "workflow_executor": workflow_executor, - "workflow_executor_data": workflow_executor_data, - "deps": deps, - "call_before": call_before, - "call_after": call_after, - }, - "inputs": 
inputs, - "named_args": named_args, - "named_kwargs": named_kwargs, - "transport_graph": transport_graph, - "cova_imports": cova_imports, - "lattice_imports": lattice_imports, - "post_processing": False, - "electron_outputs": {}, - "_bound_electrons": {}, - } - - def dummy_function(x): - return x - - lat = lattice(dummy_function) - lat.__dict__ = attributes - - result = Result( - lat, - dispatch_id=lattice_record.dispatch_id, - ) - result._root_dispatch_id = lattice_record.root_dispatch_id - result._status = Status(lattice_record.status) - result._error = error or "" - result._inputs = inputs - result._start_time = lattice_record.started_at - result._end_time = lattice_record.completed_at - result._result = output if output is not None else TransportableObject(None) - result._num_nodes = num_nodes - return result - - -def get_result_object_from_storage(dispatch_id: str) -> Result: - """Get the result object from the database. - - Args: - dispatch_id: The dispatch id of the result object to load. - - Returns: - The result object. - - """ - with workflow_db.session() as session: - lattice_record = session.query(Lattice).where(Lattice.dispatch_id == dispatch_id).first() - if not lattice_record: - app_log.debug(f"No result object found for dispatch {dispatch_id}") - raise RuntimeError(f"No result object found for dispatch {dispatch_id}") - - return _result_from(lattice_record) - - -def electron_record(dispatch_id: str, node_id: str) -> Dict: - """Get electron record for a given dispatch if and node id. - - Args: - dispatch_id: Dispatch id for lattice. - node_id: Node id of the electron. - - Returns: - Electron record. 
- - """ - with workflow_db.session() as session: - return ( - session.query(Lattice, Electron) - .filter(Lattice.id == Electron.parent_lattice_id) - .filter(Lattice.dispatch_id == dispatch_id) - .filter(Electron.transport_graph_node_id == node_id) - .first() - .Electron.__dict__ - ) - - -def sublattice_dispatch_id(electron_id: int) -> Union[str, None]: - """Get the dispatch id of the sublattice for a given electron id. - - Args: - electron_id: Electron ID. - - Returns: - Dispatch id of sublattice. None, if the electron is not a sublattice. - - """ - with workflow_db.session() as session: - if record := (session.query(Lattice).filter(Lattice.electron_id == electron_id).first()): - return record.dispatch_id diff --git a/covalent_dispatcher/_service/app.py b/covalent_dispatcher/_service/app.py index b96409cf40..99c44fa94c 100644 --- a/covalent_dispatcher/_service/app.py +++ b/covalent_dispatcher/_service/app.py @@ -1,154 +1,244 @@ -# Copyright 2021 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the GNU Affero General Public License 3.0 (the "License"). -# A copy of the License may be obtained with this software package or at -# -# https://www.gnu.org/licenses/agpl-3.0.en.html -# -# Use of this file is prohibited except in compliance with the License. Any -# modifications or derivative works of this file must retain this copyright -# notice, and modified files must contain a notice indicating that they have -# been altered from the originals. -# -# Covalent is distributed in the hope that it will be useful, but WITHOUT -# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or -# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. -# -# Relief from the License may be granted by purchasing a commercial license. 
- -import codecs -import json -from typing import Optional -from uuid import UUID - -import cloudpickle as pickle -from fastapi import APIRouter, HTTPException, Request -from fastapi.responses import JSONResponse - -import covalent_dispatcher as dispatcher -from covalent._results_manager.result import Result -from covalent._shared_files import logger - -from .._db.datastore import workflow_db -from .._db.load import _result_from -from .._db.models import Lattice - -app_log = logger.app_log -log_stack_info = logger.log_stack_info - -router: APIRouter = APIRouter() - - -@router.post("/submit") -async def submit(request: Request, disable_run: bool = False) -> UUID: - """ - Function to accept the submit request of - new dispatch and return the dispatch id - back to the client. - - Args: - disable_run: Whether to disable the execution of this lattice - - Returns: - dispatch_id: The dispatch id in a json format - returned as a Fast API Response object - """ - try: - data = await request.json() - data = json.dumps(data).encode("utf-8") - - return await dispatcher.run_dispatcher(data, disable_run) - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Failed to submit workflow: {e}", - ) from e - - -@router.post("/redispatch") -async def redispatch(request: Request, is_pending: bool = False) -> str: - """Endpoint to redispatch a workflow.""" - try: - data = await request.json() - dispatch_id = data["dispatch_id"] - json_lattice = data["json_lattice"] - electron_updates = data["electron_updates"] - reuse_previous_results = data["reuse_previous_results"] - app_log.debug( - f"Unpacked redispatch request for {dispatch_id}. 
reuse_previous_results: {reuse_previous_results}, electron_updates: {electron_updates}" - ) - return await dispatcher.run_redispatch( - dispatch_id, json_lattice, electron_updates, reuse_previous_results, is_pending - ) - - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Failed to redispatch workflow: {e}", - ) from e - - -@router.post("/cancel") -async def cancel(request: Request) -> str: - """ - Function to accept the cancel request of - a dispatch. - - Args: - None - - Returns: - Fast API Response object confirming that the dispatch - has been cancelled. - """ - - data = await request.json() - - dispatch_id = data["dispatch_id"] - task_ids = data["task_ids"] - - await dispatcher.cancel_running_dispatch(dispatch_id, task_ids) - if task_ids: - return f"Cancelled tasks {task_ids} in dispatch {dispatch_id}." - else: - return f"Dispatch {dispatch_id} cancelled." - - -@router.get("/result/{dispatch_id}") -async def get_result( - dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False -): - with workflow_db.session() as session: - lattice_record = session.query(Lattice).where(Lattice.dispatch_id == dispatch_id).first() - status = lattice_record.status if lattice_record else None - if not lattice_record: - return JSONResponse( - status_code=404, - content={"message": f"The requested dispatch ID {dispatch_id} was not found."}, - ) - if not wait or status in [ - str(Result.COMPLETED), - str(Result.FAILED), - str(Result.CANCELLED), - str(Result.POSTPROCESSING_FAILED), - str(Result.PENDING_POSTPROCESSING), - ]: - output = { - "id": dispatch_id, - "status": lattice_record.status, - } - if not status_only: - output["result"] = codecs.encode( - pickle.dumps(_result_from(lattice_record)), "base64" - ).decode() - return output - - return JSONResponse( - status_code=503, - content={ - "message": "Result not ready to read yet. Please wait for a couple of seconds." 
- }, - headers={"Retry-After": "2"}, - ) +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + + +"""Endpoints for dispatch management""" + +import asyncio +import json +from contextlib import asynccontextmanager +from typing import Optional, Union +from uuid import UUID + +from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse + +import covalent_dispatcher.entry_point as dispatcher +from covalent._shared_files import logger +from covalent._shared_files.schemas.result import ResultSchema +from covalent._shared_files.util_classes import RESULT_STATUS +from covalent_dispatcher._core import dispatcher as core_dispatcher +from covalent_dispatcher._core import runner_ng as core_runner + +from .._dal.exporters.result import export_result_manifest +from .._dal.result import Result, get_result_object +from .._db.dispatchdb import DispatchDB +from .heartbeat import Heartbeat +from .models import ExportResponseSchema + +app_log = logger.app_log +log_stack_info = logger.log_stack_info + +router: APIRouter = APIRouter() + +app_log = logger.app_log +log_stack_info = logger.log_stack_info + +router: APIRouter = APIRouter() + +_background_tasks = 
set() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Initialize global variables""" + + heartbeat = Heartbeat() + fut = asyncio.create_task(heartbeat.start()) + _background_tasks.add(fut) + fut.add_done_callback(_background_tasks.discard) + + # Runner event queue and listener + core_runner._job_events = asyncio.Queue() + core_runner._job_event_listener = asyncio.create_task(core_runner._listen_for_job_events()) + + # Dispatcher event queue and listener + core_dispatcher._global_status_queue = asyncio.Queue() + core_dispatcher._global_event_listener = asyncio.create_task( + core_dispatcher._node_event_listener() + ) + + yield + + core_dispatcher._global_event_listener.cancel() + core_runner._job_event_listener.cancel() + + +@router.post("/dispatch/submit") +async def submit(request: Request) -> UUID: + """ + Function to accept the submit request of + new dispatch and return the dispatch id + back to the client. + + Args: + None + + Returns: + dispatch_id: The dispatch id in a json format + returned as a Fast API Response object. + """ + try: + data = await request.json() + data = json.dumps(data).encode("utf-8") + return await dispatcher.make_dispatch(data) + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Failed to submit workflow: {e}", + ) from e + + +@router.post("/dispatch/cancel") +async def cancel(request: Request) -> str: + """ + Function to accept the cancel request of + a dispatch. + + Args: + None + + Returns: + Fast API Response object confirming that the dispatch + has been cancelled. + """ + + data = await request.json() + + dispatch_id = data["dispatch_id"] + task_ids = data["task_ids"] + + await dispatcher.cancel_running_dispatch(dispatch_id, task_ids) + if task_ids: + return f"Cancelled tasks {task_ids} in dispatch {dispatch_id}." + else: + return f"Dispatch {dispatch_id} cancelled." 
+ + +@router.get("/db-path") +def db_path() -> str: + db_path = DispatchDB()._dbpath + return json.dumps(db_path) + + +@router.post("/dispatch/register") +async def register( + manifest: ResultSchema, parent_dispatch_id: Union[str, None] = None +) -> ResultSchema: + try: + return await dispatcher.register_dispatch(manifest, parent_dispatch_id) + except Exception as e: + app_log.debug(f"Exception in register: {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to submit workflow: {e}", + ) from e + + +@router.post("/dispatch/register/{dispatch_id}") +async def register_redispatch( + manifest: ResultSchema, + dispatch_id: str, + reuse_previous_results: bool = False, +): + try: + return await dispatcher.register_redispatch( + manifest, + dispatch_id, + reuse_previous_results, + ) + except Exception as e: + app_log.debug(f"Exception in register_redispatch: {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to submit workflow: {e}", + ) from e + + +@router.put("/dispatch/start/{dispatch_id}") +async def start(dispatch_id: str): + try: + fut = asyncio.create_task(dispatcher.start_dispatch(dispatch_id)) + _background_tasks.add(fut) + fut.add_done_callback(_background_tasks.discard) + + return dispatch_id + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Failed to start workflow: {e}", + ) from e + + +@router.get("/dispatch/export/{dispatch_id}") +async def export_result( + dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False +) -> ExportResponseSchema: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + _export_result_sync, + dispatch_id, + wait, + status_only, + ) + + +def _export_result_sync( + dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False +) -> ExportResponseSchema: + result_object = _try_get_result_object(dispatch_id) + if not result_object: + return JSONResponse( + status_code=404, + content={"message": f"The 
requested dispatch ID {dispatch_id} was not found."}, + ) + status = str(result_object.get_value("status", refresh=False)) + + if not wait or status in [ + str(RESULT_STATUS.COMPLETED), + str(RESULT_STATUS.FAILED), + str(RESULT_STATUS.CANCELLED), + ]: + output = { + "id": dispatch_id, + "status": status, + } + if not status_only: + output["result_export"] = export_result_manifest(dispatch_id) + + return output + + response = JSONResponse( + status_code=503, + content={"message": "Result not ready to read yet. Please wait for a couple of seconds."}, + headers={"Retry-After": "2"}, + ) + return response + + +def _try_get_result_object(dispatch_id: str) -> Union[Result, None]: + try: + res = get_result_object( + dispatch_id, bare=True, keys=["id", "dispatch_id", "status"], lattice_keys=["id"] + ) + except KeyError: + res = None + return res diff --git a/covalent_dispatcher/_service/assets.py b/covalent_dispatcher/_service/assets.py new file mode 100644 index 0000000000..a34442c84b --- /dev/null +++ b/covalent_dispatcher/_service/assets.py @@ -0,0 +1,436 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. 
+ +"""Endpoints for uploading and downloading workflow assets""" + +import mmap +import os +import shutil +from functools import lru_cache +from typing import BinaryIO, Tuple, Union + +from fastapi import APIRouter, Header, HTTPException, UploadFile +from fastapi.responses import StreamingResponse +from furl import furl + +from covalent._serialize.electron import ASSET_TYPES as ELECTRON_ASSET_TYPES +from covalent._serialize.lattice import ASSET_TYPES as LATTICE_ASSET_TYPES +from covalent._serialize.result import ASSET_TYPES as RESULT_ASSET_TYPES +from covalent._serialize.result import AssetType +from covalent._shared_files import logger +from covalent._shared_files.config import get_config +from covalent._workflow.transportable_object import TOArchiveUtils + +from .._dal.result import get_result_object +from .._db.datastore import workflow_db +from .models import ( + AssetRepresentation, + DispatchAssetKey, + ElectronAssetKey, + LatticeAssetKey, + digest_pattern, + digest_regex, + range_pattern, + range_regex, +) + +app_log = logger.app_log +log_stack_info = logger.log_stack_info + +router: APIRouter = APIRouter() + +app_log = logger.app_log +log_stack_info = logger.log_stack_info + +router: APIRouter = APIRouter() + +_background_tasks = set() + +LRU_CACHE_SIZE = get_config("dispatcher.asset_cache_size") + + +@router.get("/assets/{dispatch_id}/node/{node_id}/{key}") +def get_node_asset( + dispatch_id: str, + node_id: int, + key: ElectronAssetKey, + representation: Union[AssetRepresentation, None] = None, + Range: Union[str, None] = Header(default=None, regex=range_regex), +): + start_byte = 0 + end_byte = -1 + + try: + if Range: + start_byte, end_byte = _extract_byte_range(Range) + + if end_byte >= 0 and end_byte < start_byte: + raise HTTPException( + status_code=400, + detail="Invalid byte range", + ) + app_log.debug( + f"Requested asset {key.value} ([{start_byte}:{end_byte}]) for node {dispatch_id}:{node_id}" + ) + + result_object = 
get_cached_result_object(dispatch_id) + + app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}") + + node = result_object.lattice.transport_graph.get_node(node_id) + with workflow_db.session() as session: + asset = node.get_asset(key=key.value, session=session) + + # Explicit representation overrides the byte range + if representation is None or ELECTRON_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: + start_byte = start_byte + end_byte = end_byte + elif representation == AssetRepresentation.string: + start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri) + else: + start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri) + + app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}") + generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte) + return StreamingResponse(generator) + + except Exception as e: + app_log.debug(e) + raise + + +@router.get("/assets/{dispatch_id}/dispatch/{key}") +def get_dispatch_asset( + dispatch_id: str, + key: DispatchAssetKey, + representation: Union[AssetRepresentation, None] = None, + Range: Union[str, None] = Header(default=None, regex=range_regex), +): + start_byte = 0 + end_byte = -1 + + try: + if Range: + start_byte, end_byte = _extract_byte_range(Range) + + if end_byte >= 0 and end_byte < start_byte: + raise HTTPException( + status_code=400, + detail="Invalid byte range", + ) + app_log.debug( + f"Requested asset {key.value} ([{start_byte}:{end_byte}]) for dispatch {dispatch_id}" + ) + + result_object = get_cached_result_object(dispatch_id) + + app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}") + with workflow_db.session() as session: + asset = result_object.get_asset(key=key.value, session=session) + + # Explicit representation overrides the byte range + if representation is None or RESULT_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: + start_byte = start_byte + end_byte = end_byte + elif representation == 
AssetRepresentation.string: + start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri) + else: + start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri) + + app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}") + generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte) + return StreamingResponse(generator) + except Exception as e: + app_log.debug(e) + raise + + +@router.get("/assets/{dispatch_id}/lattice/{key}") +def get_lattice_asset( + dispatch_id: str, + key: LatticeAssetKey, + representation: Union[AssetRepresentation, None] = None, + Range: Union[str, None] = Header(default=None, regex=range_regex), +): + start_byte = 0 + end_byte = -1 + + try: + if Range: + start_byte, end_byte = _extract_byte_range(Range) + + if end_byte >= 0 and end_byte < start_byte: + raise HTTPException( + status_code=400, + detail="Invalid byte range", + ) + app_log.debug( + f"Requested lattice asset {key.value} ([{start_byte}:{end_byte}])for dispatch {dispatch_id}" + ) + + result_object = get_cached_result_object(dispatch_id) + app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}") + + with workflow_db.session() as session: + asset = result_object.lattice.get_asset(key=key.value, session=session) + + # Explicit representation overrides the byte range + if representation is None or LATTICE_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: + start_byte = start_byte + end_byte = end_byte + elif representation == AssetRepresentation.string: + start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri) + else: + start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri) + + app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}") + generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte) + return StreamingResponse(generator) + + except Exception as e: + app_log.debug(e) + raise e + + 
+@router.post("/assets/{dispatch_id}/node/{node_id}/{key}")
+def upload_node_asset(
+    dispatch_id: str,
+    node_id: int,
+    key: ElectronAssetKey,
+    asset_file: UploadFile,
+    content_length: int = Header(),
+    digest: Union[str, None] = Header(default=None, regex=digest_regex),
+):
+    app_log.debug(f"Requested asset {key} for node {dispatch_id}:{node_id}")
+
+    try:
+        result_object = get_cached_result_object(dispatch_id)
+
+        app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}")
+        node = result_object.lattice.transport_graph.get_node(node_id)
+        with workflow_db.session() as session:
+            asset = node.get_asset(key=key.value, session=session)
+            app_log.debug(f"Asset uri {asset.internal_uri}")
+
+            # Update asset metadata (size, plus digest when the header was sent)
+            update = _get_asset_metadata_update(content_length, digest)
+            node.update_assets(updates={key: update}, session=session)
+            app_log.debug(f"Updated node asset {dispatch_id}:{node_id}:{key}")
+
+            # Copy the tempfile to object store
+            _copy_file_obj(asset_file.file, asset.internal_uri)
+
+            return f"Uploaded file to {asset.internal_uri}"
+    except Exception as e:
+        app_log.debug(e)  # log only; the exception is re-raised unchanged
+        raise
+
+
+@router.post("/assets/{dispatch_id}/dispatch/{key}")
+def upload_dispatch_asset(
+    dispatch_id: str,
+    key: DispatchAssetKey,
+    asset_file: UploadFile,
+    content_length: int = Header(),
+    digest: Union[str, None] = Header(default=None, regex=digest_regex),
+):
+    try:
+        result_object = get_cached_result_object(dispatch_id)
+
+        app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}")
+        with workflow_db.session() as session:
+            asset = result_object.get_asset(key=key.value, session=session)
+
+            # Update asset metadata (size, plus digest when the header was sent)
+            update = _get_asset_metadata_update(content_length, digest)
+            result_object.update_assets(updates={key: update}, session=session)
+            app_log.debug(f"Updated size for dispatch asset {dispatch_id}:{key}")
+
+            # Copy the tempfile to object store
+            _copy_file_obj(asset_file.file, asset.internal_uri)
+
+            return f"Uploaded file to {asset.internal_uri}"
+    except Exception as e:
+        app_log.debug(e)  # log only; the exception is re-raised unchanged
+        raise
+
+
+@router.post("/assets/{dispatch_id}/lattice/{key}")
+def upload_lattice_asset(
+    dispatch_id: str,
+    key: LatticeAssetKey,
+    asset_file: UploadFile,
+    content_length: int = Header(),
+    digest: Union[str, None] = Header(default=None, regex=digest_regex),
+):
+    try:
+        result_object = get_cached_result_object(dispatch_id)
+
+        app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}")
+
+        with workflow_db.session() as session:
+            asset = result_object.lattice.get_asset(key=key.value, session=session)
+
+            # Update asset metadata (size, plus digest when the header was sent)
+            update = _get_asset_metadata_update(content_length, digest)
+            result_object.lattice.update_assets(updates={key: update}, session=session)
+            app_log.debug(f"Updated size for lattice asset {dispatch_id}:{key}")
+
+            # Copy the tempfile to object store
+            _copy_file_obj(asset_file.file, asset.internal_uri)
+
+            return f"Uploaded file to {asset.internal_uri}"
+    except Exception as e:
+        app_log.debug(e)  # log only; the exception is re-raised unchanged
+        raise
+
+
+def _copy_file_obj(src_fileobj: BinaryIO, dest_url: str):
+    dest_path = str(furl(dest_url).path)  # only the path component of the URL is used
+    with open(dest_path, "wb") as dest_fileobj:
+        shutil.copyfileobj(src_fileobj, dest_fileobj)
+
+
+def _generate_file_slice(file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536):
+    """Generator of a byte slice from a file.
+
+    Args:
+        file_url: A file:/// type URL pointing to the file
+        start_byte: The beginning of the byte range
+        end_byte: The end of the byte range, or -1 to select [start_byte:]
+        chunk_size: The size of each chunk
+
+    Returns:
+        Yields chunks of size <= chunk_size
+    """
+    byte_pos = start_byte
+    file_path = str(furl(file_url).path)
+    with open(file_path, "rb") as f:
+        f.seek(start_byte)
+        if end_byte < 0:
+            for chunk in f:  # NOTE: iterates newline-delimited lines; chunk_size is not applied here
+                yield chunk
+        else:
+            while byte_pos + chunk_size < end_byte:
+                byte_pos += chunk_size
+                yield f.read(chunk_size)
+            yield f.read(end_byte - byte_pos)  # final, possibly short, chunk
+
+
+def _extract_byte_range(byte_range_header: str) -> Tuple[int, int]:
+    """Extract the byte range from a range request header."""
+    start_byte = 0
+    end_byte = -1
+    match = range_pattern.match(byte_range_header)
+    start = match.group(1)
+    end = match.group(2)  # empty string when the header is open-ended ("bytes=N-")
+    start_byte = int(start)
+    if end:
+        end_byte = int(end)
+
+    return start_byte, end_byte
+
+
+def _extract_checksum(digest_header: str) -> Tuple[str, str]:
+    match = digest_pattern.match(digest_header)
+    alg = match.group(1)  # algorithm name; group(0) would be the entire "alg=hex" match
+    checksum = match.group(2)  # hex digest
+    return alg, checksum
+
+
+# Helpers for TransportableObject
+
+
+def _get_tobj_string_offsets(file_url: str) -> Tuple[int, int]:
+    """Get the byte range for the str rep of a stored TObj.
+
+    For a first implementation we just query the filesystem directly.
+
+    Args:
+        file_url: A file:/// URL pointing to the TransportableObject
+
+    Returns:
+        (start_byte, end_byte)
+    """
+
+    file_path = str(furl(file_url).path)
+    filelen = os.path.getsize(file_path)
+    with open(file_path, "rb+") as f:
+        with mmap.mmap(f.fileno(), filelen) as mm:
+            # TOArchiveUtils operates on byte arrays
+            return TOArchiveUtils.string_byte_range(mm)
+
+
+def _get_tobj_pickle_offsets(file_url: str) -> Tuple[int, int]:
+    """Get the byte range for the picklebytes of a stored TObj.
+
+    For a first implementation we just query the filesystem directly.
+
+    Args:
+        file_url: A file:/// URL pointing to the TransportableObject
+
+    Returns:
+        (start_byte, -1)
+    """
+
+    file_path = str(furl(file_url).path)
+    filelen = os.path.getsize(file_path)
+    with open(file_path, "rb+") as f:
+        with mmap.mmap(f.fileno(), filelen) as mm:
+            # TOArchiveUtils operates on byte arrays
+            return TOArchiveUtils.data_byte_range(mm)
+
+
+# This must only be used for static data as we don't have yet any
+# intelligent invalidation logic.
+@lru_cache(maxsize=LRU_CACHE_SIZE)
+def get_cached_result_object(dispatch_id: str):
+    try:
+        with workflow_db.session() as session:
+            srv_res = get_result_object(dispatch_id, bare=False, session=session)
+            app_log.debug(f"Caching result {dispatch_id}")
+
+            # Prepopulate asset maps to avoid DB lookups
+
+            srv_res.populate_asset_map(session)
+            srv_res.lattice.populate_asset_map(session)
+
+            tg = srv_res.lattice.transport_graph
+            g = tg.get_internal_graph_copy()
+            for node_id in g.nodes():
+                node = tg.get_node(node_id, session)
+                node.populate_asset_map(session)
+    except KeyError:
+        raise HTTPException(
+            status_code=400,
+            detail=f"The requested dispatch ID {dispatch_id} was not found.",
+        )
+
+    return srv_res
+
+
+def _get_asset_metadata_update(content_length, digest):
+    update = {"size": content_length}
+    if digest:  # digest header is optional; record it only when present
+        alg, checksum = _extract_checksum(digest)
+        update["digest_alg"] = alg
+        update["digest"] = checksum
+
+    return update
diff --git a/covalent_ui/heartbeat.py b/covalent_dispatcher/_service/heartbeat.py
similarity index 60%
rename from covalent_ui/heartbeat.py
rename to covalent_dispatcher/_service/heartbeat.py
index 22d058e3db..4626f6dcf0 100644
--- a/covalent_ui/heartbeat.py
+++ b/covalent_dispatcher/_service/heartbeat.py
@@ -19,16 +19,11 @@
 # Relief from the License may be granted by purchasing a commercial license.
import asyncio -from contextlib import asynccontextmanager from datetime import datetime, timezone -from typing import List import aiofiles -from fastapi import FastAPI from covalent._shared_files.config import get_config -from covalent_ui.api.v1.routes.end_points.summary_routes import get_all_dispatches -from covalent_ui.api.v1.utils.status import Status class Heartbeat: @@ -59,48 +54,3 @@ def stop(): file.write( f"DEAD {datetime.now(tz=timezone.utc).strftime(Heartbeat.TIMESTAMP_FORMAT)}" ) - - -async def cancel_all_with_status(status: Status) -> List[str]: - from covalent_dispatcher._core.dispatcher import cancel_dispatch - - dispatch_ids = [] - page = 0 - count = 100 - - while True: - dispatches = get_all_dispatches( - count=count, - offset=page * count, - status_filter=status, - ) - - dispatch_ids += [dispatch.dispatch_id for dispatch in dispatches.items] - - if dispatches.total_count == page * count + len(dispatches.items): - break - - page += 1 - - for dispatch in dispatch_ids: - await cancel_dispatch(dispatch.dispatch_id) - - return dispatch_ids - - -@asynccontextmanager -async def lifespan(app: FastAPI): - heartbeat = Heartbeat() - asyncio.create_task(heartbeat.start()) - - yield - - for status in [ - Status.NEW_OBJECT, - Status.POSTPROCESSING, - Status.PENDING_POSTPROCESSING, - Status.RUNNING, - ]: - await cancel_all_with_status(status) - - Heartbeat.stop() diff --git a/covalent_dispatcher/_service/models.py b/covalent_dispatcher/_service/models.py new file mode 100644 index 0000000000..5d9c1bb84d --- /dev/null +++ b/covalent_dispatcher/_service/models.py @@ -0,0 +1,117 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. 
Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +"""FastAPI models for /api/v1/resultv2 endpoints""" + +import re +from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +from covalent._shared_files.schemas.result import ResultSchema + +# # Copied from _dal +# RESULT_ASSET_KEYS = { +# "inputs", +# "result", +# "error", +# } + +# # Copied from _dal +# LATTICE_ASSET_KEYS = { +# "workflow_function", +# "workflow_function_string", +# "__doc__", +# "named_args", +# "named_kwargs", +# "cova_imports", +# "lattice_imports", +# # metadata +# "executor_data", +# "workflow_executor_data", +# "deps", +# "call_before", +# "call_after", +# } + +# # Copied from _dal +# ELECTRON_ASSET_KEYS = { +# "function", +# "function_string", +# "output", +# "value", +# "error", +# "stdout", +# "stderr", +# # electron metadata +# "deps", +# "call_before", +# "call_after", +# } + +range_regex = "bytes=([0-9]+)-([0-9]*)" +range_pattern = re.compile(range_regex) + +digest_regex = "(sha|sha-256)=([0-9a-f]+)" +digest_pattern = re.compile(digest_regex) + + +class DispatchAssetKey(str, Enum): + result = "result" + error = "error" + + +class LatticeAssetKey(str, Enum): + workflow_function = "workflow_function" + workflow_function_string = "workflow_function_string" + doc = "doc" + inputs = "inputs" + named_args = "named_args" + named_kwargs = "named_kwargs" + deps = "deps" + call_before = "call_before" + call_after = "call_after" + cova_imports = "cova_imports" + lattice_imports = "lattice_imports" + + +class 
ElectronAssetKey(str, Enum): + function = "function" + function_string = "function_string" + output = "output" + value = "value" + deps = "deps" + error = "error" + stdout = "stdout" + stderr = "stderr" + call_before = "call_before" + call_after = "call_after" + + +class ExportResponseSchema(BaseModel): + id: str + status: str + result_export: Optional[ResultSchema] + + +class AssetRepresentation(str, Enum): + string = "string" + b64pickle = "object" diff --git a/covalent_dispatcher/entry_point.py b/covalent_dispatcher/entry_point.py index c7ef7c5966..e14003cffb 100644 --- a/covalent_dispatcher/entry_point.py +++ b/covalent_dispatcher/entry_point.py @@ -22,9 +22,11 @@ Self-contained entry point for the dispatcher """ -from typing import List +import asyncio +from typing import List, Optional from covalent._shared_files import logger +from covalent._shared_files.schemas.result import ResultSchema from ._core import cancel_dispatch @@ -32,7 +34,7 @@ log_stack_info = logger.log_stack_info -async def run_dispatcher(json_lattice: str, disable_run: bool = False): +async def make_dispatch(json_lattice: str): """ Run the dispatcher from the lattice asynchronously using Dask. Assign a new dispatch id to the result object and return it. @@ -40,47 +42,67 @@ async def run_dispatcher(json_lattice: str, disable_run: bool = False): Args: json_lattice: A JSON-serialized lattice - disable_run: Whether to disable execution of this lattice Returns: dispatch_id: A string containing the dispatch id of current dispatch. 
""" - from ._core import make_dispatch, run_dispatch + from ._core import make_dispatch dispatch_id = await make_dispatch(json_lattice) - if not disable_run: - run_dispatch(dispatch_id) - app_log.debug(f"Submitted dispatch_id {dispatch_id} to run_workflow.") + app_log.debug(f"Created new dispatch {dispatch_id}") return dispatch_id -async def run_redispatch( - dispatch_id: str, - json_lattice: str, - electron_updates: dict, - reuse_previous_results: bool, - is_pending: bool = False, -): - from ._core import make_derived_dispatch, run_dispatch - - app_log.debug("Running redispatch ...") - if is_pending: - run_dispatch(dispatch_id) - app_log.debug(f"Submitted pending dispatch_id {dispatch_id} to run_dispatch.") - return dispatch_id - - redispatch_id = make_derived_dispatch( - dispatch_id, json_lattice, electron_updates, reuse_previous_results - ) - app_log.debug(f"Redispatch id {redispatch_id} created.") - run_dispatch(redispatch_id) +async def start_dispatch(dispatch_id: str): + """ + Run the dispatcher from the lattice asynchronously using Dask. + Assign a new dispatch id to the result object and return it. + Also save the result in this initial stage to the file mentioned in the result object. + + Args: + json_lattice: A JSON-serialized lattice + + Returns: + dispatch_id: A string containing the dispatch id of current dispatch. + """ + + from ._core import copy_futures, run_dispatch + + # Wait for any pending asset transfers + _fut = copy_futures.get(dispatch_id, None) + if _fut is not None: + # _fut is a concurrent.future.Future, so we need to wrap it in + # an asyncio.Future + app_log.debug(f"Waiting on asset transfers for dispatch {dispatch_id}") + await asyncio.wrap_future(_fut) + + # Idempotent + run_dispatch(dispatch_id) + app_log.debug(f"Running dispatch {dispatch_id}") + + +async def run_dispatcher(json_lattice: str): + """ + Run the dispatcher from the lattice asynchronously using Dask. + Assign a new dispatch id to the result object and return it. 
+ Also save the result in this initial stage to the file mentioned in the result object. + + Args: + json_lattice: A JSON-serialized lattice + + Returns: + dispatch_id: A string containing the dispatch id of current dispatch. + """ + + dispatch_id = await make_dispatch(json_lattice) + await start_dispatch(dispatch_id) - app_log.debug(f"Re-dispatching {dispatch_id} as {redispatch_id}") + app_log.debug("Submitted result object to run_workflow.") - return redispatch_id + return dispatch_id async def cancel_running_dispatch(dispatch_id: str, task_ids: List[int] = None) -> None: @@ -98,3 +120,25 @@ async def cancel_running_dispatch(dispatch_id: str, task_ids: List[int] = None) task_ids = [] await cancel_dispatch(dispatch_id, task_ids) + + +async def register_dispatch( + manifest: ResultSchema, parent_dispatch_id: Optional[str] +) -> ResultSchema: + from ._core.data_modules.importer import import_manifest + + return await import_manifest(manifest, parent_dispatch_id, None) + + +async def register_redispatch( + manifest: ResultSchema, + parent_dispatch_id: str, + reuse_previous_results: bool, +) -> ResultSchema: + from ._core.data_modules.importer import import_derived_manifest + + return await import_derived_manifest( + manifest, + parent_dispatch_id, + reuse_previous_results, + ) diff --git a/covalent_ui/api/main.py b/covalent_ui/api/main.py index fc33a3c7a0..c80cb71ced 100644 --- a/covalent_ui/api/main.py +++ b/covalent_ui/api/main.py @@ -38,8 +38,8 @@ from covalent._shared_files import logger from covalent._shared_files.config import get_config +from covalent_dispatcher._service.app import lifespan from covalent_ui.api.v1.routes import routes -from covalent_ui.heartbeat import lifespan file_descriptor = None child_process_id = None diff --git a/covalent_ui/api/v1/data_layer/electron_dal.py b/covalent_ui/api/v1/data_layer/electron_dal.py index e4a6eed3cf..2a13ff9417 100644 --- a/covalent_ui/api/v1/data_layer/electron_dal.py +++ 
b/covalent_ui/api/v1/data_layer/electron_dal.py @@ -50,7 +50,8 @@ def get_electrons_id(self, dispatch_id, electron_id) -> Electron: Electron.function_filename, Electron.function_string_filename, Electron.executor, - Electron.executor_data_filename, + Electron.executor_data, + # Electron.executor_data_filename, Electron.results_filename, Electron.value_filename, Electron.stdout_filename, diff --git a/covalent_ui/api/v1/data_layer/lattice_dal.py b/covalent_ui/api/v1/data_layer/lattice_dal.py index 17321bdeb4..c6258221d5 100644 --- a/covalent_ui/api/v1/data_layer/lattice_dal.py +++ b/covalent_ui/api/v1/data_layer/lattice_dal.py @@ -94,15 +94,17 @@ def get_lattices_id_storage_file(self, dispatch_id: UUID): Lattice.error_filename, Lattice.function_string_filename, Lattice.executor, - Lattice.executor_data_filename, + Lattice.executor_data, + # Lattice.executor_data_filename, Lattice.workflow_executor, - Lattice.workflow_executor_data_filename, + Lattice.workflow_executor_data, + # Lattice.workflow_executor_data_filename, Lattice.error_filename, Lattice.inputs_filename, Lattice.results_filename, Lattice.storage_type, Lattice.function_filename, - Lattice.transport_graph_filename, + # Lattice.transport_graph_filename, Lattice.started_at.label("started_at"), Lattice.completed_at.label("ended_at"), Lattice.electron_num.label("total_electrons"), diff --git a/covalent_ui/api/v1/database/schema/electron.py b/covalent_ui/api/v1/database/schema/electron.py index 206b123cdf..dcb9918df0 100644 --- a/covalent_ui/api/v1/database/schema/electron.py +++ b/covalent_ui/api/v1/database/schema/electron.py @@ -91,8 +91,8 @@ class Electron(Base): # Short name describing the executor ("local", "dask", etc) executor = Column(Text) - # Name of the file containing the serialized executor data - executor_data_filename = Column(Text) + # JSONified executor attributes + executor_data = Column(Text) # name of the file containing the serialized output results_filename = Column(Text) diff --git 
a/covalent_ui/api/v1/database/schema/lattices.py b/covalent_ui/api/v1/database/schema/lattices.py index a21561c64c..06f752b2cc 100644 --- a/covalent_ui/api/v1/database/schema/lattices.py +++ b/covalent_ui/api/v1/database/schema/lattices.py @@ -88,14 +88,14 @@ class Lattice(Base): # Short name describing the executor ("local", "dask", etc) executor = Column(Text) - # Name of the file containing the serialized executor data - executor_data_filename = Column(Text) + # JSONified executor attributes + executor_data = Column(Text) # Short name describing the workflow executor ("local", "dask", etc) workflow_executor = Column(Text) - # Name of the file containing the serialized workflow executor data - workflow_executor_data_filename = Column(Text) + # JSONified executor attributes + workflow_executor_data = Column(Text) # Name of the file containing an error message for the workflow error_filename = Column(Text) @@ -113,7 +113,7 @@ class Lattice(Base): results_filename = Column(Text) # Name of the file containing the transport graph - transport_graph_filename = Column(Text) + # transport_graph_filename = Column(Text) # Name of the file containing the default electron dependencies deps_filename = Column(Text) diff --git a/covalent_ui/api/v1/models/lattices_model.py b/covalent_ui/api/v1/models/lattices_model.py index 780e33aaf4..4dace40ccc 100644 --- a/covalent_ui/api/v1/models/lattices_model.py +++ b/covalent_ui/api/v1/models/lattices_model.py @@ -127,4 +127,6 @@ class LatticeFileOutput(str, Enum): EXECUTOR = "executor" WORKFLOW_EXECUTOR = "workflow_executor" FUNCTION = "function" - TRANSPORT_GRAPH = "transport_graph" + + +# TRANSPORT_GRAPH = "transport_graph" diff --git a/covalent_ui/api/v1/routes/end_points/electron_routes.py b/covalent_ui/api/v1/routes/end_points/electron_routes.py index e7ee36e6e3..c138195a0d 100644 --- a/covalent_ui/api/v1/routes/end_points/electron_routes.py +++ b/covalent_ui/api/v1/routes/end_points/electron_routes.py @@ -20,12 +20,15 @@ 
"""Electrons Route""" +import json import uuid from fastapi import APIRouter, HTTPException from sqlalchemy.orm import Session -from covalent._results_manager.results_manager import get_result +from covalent._shared_files.defaults import WAIT_EDGE_NAME +from covalent_dispatcher._core.data_modules import graph as core_graph +from covalent_dispatcher._dal.result import get_result_object from covalent_ui.api.v1.data_layer.electron_dal import Electrons from covalent_ui.api.v1.database.config.db import engine from covalent_ui.api.v1.models.electrons_model import ( @@ -78,6 +81,41 @@ def get_electron_details(dispatch_id: uuid.UUID, electron_id: int): ) +def _get_abstract_task_inputs(dispatch_id: str, node_id: int) -> dict: + """Return placeholders for the required inputs for a task execution. + + Args: + dispatch_id: id of the current dispatch + node_id: Node id of this task in the transport graph. + node_name: Name of the node. + + Returns: inputs: Input dictionary to be passed to the task with + `node_id` placeholders for args, kwargs. These are to be + resolved to their values later. 
+ """ + + abstract_task_input = {"args": [], "kwargs": {}} + + in_edges = core_graph.get_incoming_edges(dispatch_id, node_id) + for edge in in_edges: + parent = edge["source"] + + d = edge["attrs"] + + if d["edge_name"] != WAIT_EDGE_NAME: + if d["param_type"] == "arg": + abstract_task_input["args"].append((parent, d["arg_index"])) + elif d["param_type"] == "kwarg": + key = d["edge_name"] + abstract_task_input["kwargs"][key] = parent + + sorted_args = sorted(abstract_task_input["args"], key=lambda x: x[1]) + abstract_task_input["args"] = [x[0] for x in sorted_args] + + return abstract_task_input + + +# Domain: data def get_electron_inputs(dispatch_id: uuid.UUID, electron_id: int) -> str: """ Get Electron Inputs @@ -87,17 +125,30 @@ def get_electron_inputs(dispatch_id: uuid.UUID, electron_id: int) -> str: Returns: Returns the inputs data from Result object """ - from covalent_dispatcher._core.execution import _get_task_inputs as get_task_inputs - result_object = get_result(dispatch_id=str(dispatch_id), wait=False) + abstract_inputs = _get_abstract_task_inputs(dispatch_id=str(dispatch_id), node_id=electron_id) + + # Resolve node ids to object strings + input_assets = {"args": [], "kwargs": {}} with Session(engine) as session: - electron = Electrons(session) - result = electron.get_electrons_id(dispatch_id, electron_id) - inputs = get_task_inputs( - node_id=electron_id, node_name=result.name, result_object=result_object - ) - return validate_data(inputs) + result_object = get_result_object(str(dispatch_id), bare=True) + tg = result_object.lattice.transport_graph + for arg in abstract_inputs["args"]: + node = tg.get_node(node_id=arg, session=session) + asset = node.get_asset(key="output", session=session) + input_assets["args"].append(asset) + for k, v in abstract_inputs["kwargs"].items(): + node = tg.get_node(node_id=v, session=session) + asset = node.get_asset(key="output", session=session) + input_assets["kwargs"][k] = asset + + # For now we load the picklefile from 
the object store into memory, but once + # TransportableObjects are no longer pickled we will be + # able to load the byte range for the object string. + input_args = [asset.load_data() for asset in input_assets["args"]] + input_kwargs = {k: asset.load_data() for k, asset in input_assets["kwargs"].items()} + return validate_data({"args": input_args, "kwargs": input_kwargs}) @routes.get("/{dispatch_id}/electron/{electron_id}/details/{name}") @@ -127,31 +178,32 @@ def get_electron_file(dispatch_id: uuid.UUID, electron_id: int, name: ElectronFi response = handler.read_from_text(result["function_string_filename"]) return ElectronFileResponse(data=response) elif name == "function": - response, python_object = handler.read_from_pickle(result["function_filename"]) + response, python_object = handler.read_from_serialized(result["function_filename"]) return ElectronFileResponse(data=response, python_object=python_object) elif name == "executor": executor_name = result["executor"] - executor_data = handler.read_from_pickle(result["executor_data_filename"]) + executor_data = json.loads(result["executor_data"]) + # executor_data = handler.read_from_serialized(result["executor_data_filename"]) return ElectronExecutorResponse( executor_name=executor_name, executor_details=executor_data ) elif name == "result": - response, python_object = handler.read_from_pickle(result["results_filename"]) + response, python_object = handler.read_from_serialized(result["results_filename"]) return ElectronFileResponse(data=str(response), python_object=python_object) elif name == "value": - response = handler.read_from_pickle(result["value_filename"]) + response = handler.read_from_serialized(result["value_filename"]) return ElectronFileResponse(data=str(response)) elif name == "stdout": response = handler.read_from_text(result["stdout_filename"]) return ElectronFileResponse(data=response) elif name == "deps": - response = handler.read_from_pickle(result["deps_filename"]) + response = 
handler.read_from_serialized(result["deps_filename"]) return ElectronFileResponse(data=response) elif name == "call_before": - response = handler.read_from_pickle(result["call_before_filename"]) + response = handler.read_from_serialized(result["call_before_filename"]) return ElectronFileResponse(data=response) elif name == "call_after": - response = handler.read_from_pickle(result["call_after_filename"]) + response = handler.read_from_serialized(result["call_after_filename"]) return ElectronFileResponse(data=response) elif name == "error": # Error and stderr won't be both populated if `error` diff --git a/covalent_ui/api/v1/routes/end_points/lattice_route.py b/covalent_ui/api/v1/routes/end_points/lattice_route.py index 3becb0b53a..8483186cde 100644 --- a/covalent_ui/api/v1/routes/end_points/lattice_route.py +++ b/covalent_ui/api/v1/routes/end_points/lattice_route.py @@ -20,6 +20,7 @@ """Lattice route""" +import json import uuid from typing import Optional @@ -99,27 +100,31 @@ def get_lattice_files(dispatch_id: uuid.UUID, name: LatticeFileOutput): if lattice_data is not None: handler = FileHandler(lattice_data["directory"]) if name == "result": - response, python_object = handler.read_from_pickle( + response, python_object = handler.read_from_serialized( lattice_data["results_filename"] ) return LatticeFileResponse(data=str(response), python_object=python_object) if name == "inputs": - response, python_object = handler.read_from_pickle(lattice_data["inputs_filename"]) + response, python_object = handler.read_from_serialized( + lattice_data["inputs_filename"] + ) return LatticeFileResponse(data=response, python_object=python_object) elif name == "function_string": response = handler.read_from_text(lattice_data["function_string_filename"]) return LatticeFileResponse(data=response) elif name == "executor": executor_name = lattice_data["executor"] - executor_data = handler.read_from_pickle(lattice_data["executor_data_filename"]) + executor_data = 
json.loads(lattice_data["executor_data"]) + # executor_data = handler.read_from_serialized(lattice_data["executor_data_filename"]) return LatticeExecutorResponse( executor_name=executor_name, executor_details=executor_data ) elif name == "workflow_executor": executor_name = lattice_data["workflow_executor"] - executor_data = handler.read_from_pickle( - lattice_data["workflow_executor_data_filename"] - ) + executor_data = json.loads(lattice_data["workflow_executor_data"]) + # executor_data = handler.read_from_serialized( + # lattice_data["workflow_executor_data_filename"] + # ) return LatticeWorkflowExecutorResponse( workflow_executor_name=executor_name, workflow_executor_details=executor_data ) @@ -127,13 +132,13 @@ def get_lattice_files(dispatch_id: uuid.UUID, name: LatticeFileOutput): response = handler.read_from_text(lattice_data["error_filename"]) return LatticeFileResponse(data=response) elif name == "function": - response, python_object = handler.read_from_pickle( + response, python_object = handler.read_from_serialized( lattice_data["function_filename"] ) return LatticeFileResponse(data=response, python_object=python_object) - elif name == "transport_graph": - response = handler.read_from_pickle(lattice_data["transport_graph_filename"]) - return LatticeFileResponse(data=response) + # elif name == "transport_graph": + # response = handler.read_from_pickle(lattice_data["transport_graph_filename"]) + # return LatticeFileResponse(data=response) else: return LatticeFileResponse(data=None) else: diff --git a/covalent_ui/api/v1/routes/routes.py b/covalent_ui/api/v1/routes/routes.py index 1defd768c3..d9b4de8f96 100644 --- a/covalent_ui/api/v1/routes/routes.py +++ b/covalent_ui/api/v1/routes/routes.py @@ -22,7 +22,7 @@ from fastapi import APIRouter -from covalent_dispatcher._service import app +from covalent_dispatcher._service import app, assets, runnersvc from covalent_dispatcher._triggers_app.app import router as tr_router from covalent_ui.api.v1.routes.end_points 
import ( electron_routes, @@ -43,6 +43,7 @@ routes.include_router(electron_routes.routes, prefix=dispatch_prefix, tags=["Electrons"]) routes.include_router(settings_routes.routes, prefix="/api/v1", tags=["Settings"]) routes.include_router(logs_route.routes, prefix="/api/v1/logs", tags=["Logs"]) -routes.include_router(app.router, prefix="/api", tags=["dispatcher"]) -routes.include_router(app.router, prefix="/api", tags=["dispatcher"]) routes.include_router(tr_router, prefix="/api", tags=["Triggers"]) +routes.include_router(app.router, prefix="/api/v1", tags=["Dispatcher"]) +routes.include_router(assets.router, prefix="/api/v1", tags=["Assets"]) +routes.include_router(runnersvc.router, prefix="/api/v1", tags=["Runner"]) diff --git a/covalent_ui/api/v1/utils/file_handle.py b/covalent_ui/api/v1/utils/file_handle.py index f739fa2f9b..7fb2dcf96a 100644 --- a/covalent_ui/api/v1/utils/file_handle.py +++ b/covalent_ui/api/v1/utils/file_handle.py @@ -26,6 +26,7 @@ import cloudpickle as pickle from covalent._workflow.transport import TransportableObject, _TransportGraph +from covalent_dispatcher._dal.asset import local_store def transportable_object(obj): @@ -107,6 +108,14 @@ def read_from_pickle(self, path): except Exception as e: return None + def read_from_serialized(self, path): + """Return data from serialized object""" + try: + deserialized_obj = local_store.load_file(self.location, path) + return validate_data(deserialized_obj) + except Exception as e: + return None + def read_from_text(self, path): """Return data from text file""" try: diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb2..523f776226 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). 
+# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. diff --git a/tests/covalent_dispatcher_tests/__init__.py b/tests/covalent_dispatcher_tests/__init__.py index e69de29bb2..523f776226 100644 --- a/tests/covalent_dispatcher_tests/__init__.py +++ b/tests/covalent_dispatcher_tests/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. 
diff --git a/tests/covalent_dispatcher_tests/_cli/__init__.py b/tests/covalent_dispatcher_tests/_cli/__init__.py index e69de29bb2..523f776226 100644 --- a/tests/covalent_dispatcher_tests/_cli/__init__.py +++ b/tests/covalent_dispatcher_tests/_cli/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. diff --git a/tests/covalent_dispatcher_tests/_core/__init__.py b/tests/covalent_dispatcher_tests/_core/__init__.py index e69de29bb2..523f776226 100644 --- a/tests/covalent_dispatcher_tests/_core/__init__.py +++ b/tests/covalent_dispatcher_tests/_core/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. 
+# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. diff --git a/tests/covalent_dispatcher_tests/_core/data_manager_test.py b/tests/covalent_dispatcher_tests/_core/data_manager_test.py index 0fdf996819..c55c97d111 100644 --- a/tests/covalent_dispatcher_tests/_core/data_manager_test.py +++ b/tests/covalent_dispatcher_tests/_core/data_manager_test.py @@ -23,34 +23,26 @@ """ -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest import covalent as ct from covalent._results_manager import Result -from covalent._shared_files.defaults import sublattice_prefix from covalent._shared_files.util_classes import RESULT_STATUS from covalent._workflow.lattice import Lattice from covalent_dispatcher._core.data_manager import ( - _dispatch_status_queues, - _get_result_object_from_new_lattice, - _get_result_object_from_old_result, - _handle_built_sublattice, - _register_result_object, - _registered_dispatches, + ResultSchema, + _legacy_sublattice_dispatch_helper, + _make_sublattice_dispatch, + _redirect_lattice, _update_parent_electron, + ensure_dispatch, finalize_dispatch, - generate_node_result, get_result_object, - get_status_queue, - initialize_result_object, - make_derived_dispatch, make_dispatch, - make_sublattice_dispatch, persist_result, update_node_result, - upsert_lattice_data, ) from covalent_dispatcher._db.datastore import DataStore @@ -91,415 +83,237 @@ def pipeline(x): return result_object +@pytest.mark.parametrize( + "node_status,node_type,output_status,sub_id", + [ + (Result.COMPLETED, "function", Result.COMPLETED, ""), + (Result.FAILED, "function", Result.FAILED, ""), + (Result.CANCELLED, "function", Result.CANCELLED, ""), + (Result.COMPLETED, "sublattice", 
RESULT_STATUS.DISPATCHING, ""), + (Result.COMPLETED, "sublattice", RESULT_STATUS.COMPLETED, "asdf"), + (Result.FAILED, "sublattice", Result.FAILED, ""), + (Result.CANCELLED, "sublattice", Result.CANCELLED, ""), + ], +) @pytest.mark.asyncio -async def test_handle_built_sublattice(mocker): - """Test the handle_built_sublattice function.""" - - get_result_object_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.get_result_object", return_value="mock-result" - ) - make_sublattice_dispatch_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.make_sublattice_dispatch", - return_value="mock-sub-dispatch-id", - ) - mock_node_result = generate_node_result( - node_id=0, - node_name="mock_node_name", - status=RESULT_STATUS.COMPLETED, - ) - - await _handle_built_sublattice("mock-dispatch-id", mock_node_result) - get_result_object_mock.assert_called_with("mock-dispatch-id") - make_sublattice_dispatch_mock.assert_called_with("mock-result", mock_node_result) - assert mock_node_result["status"] == RESULT_STATUS.DISPATCHING_SUBLATTICE - assert mock_node_result["start_time"] is not None - assert mock_node_result["end_time"] is None - assert mock_node_result["sub_dispatch_id"] == "mock-sub-dispatch-id" - +async def test_update_node_result(mocker, node_status, node_type, output_status, sub_id): + """Check that update_node_result pushes the correct status updates""" -@pytest.mark.asyncio -async def test_handle_built_sublattice_exception(mocker): - """Test the handle_built_sublattice function exception case.""" + result_object = MagicMock() + result_object.dispatch_id = "test_update_node_result" - get_result_object_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.get_result_object", side_effect=Exception - ) - make_sublattice_dispatch_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.make_sublattice_dispatch", - return_value="mock-sub-dispatch-id", - ) - mock_node_result = generate_node_result( - node_id=0, - node_name="mock_node_name", 
- status=RESULT_STATUS.COMPLETED, + node_result = {"node_id": 0, "status": node_status} + mock_update_node = mocker.patch( + "covalent_dispatcher._dal.result.Result._update_node", return_value=True ) + node_info = {"type": node_type, "sub_dispatch_id": sub_id, "status": Result.NEW_OBJ} + mocker.patch("covalent_dispatcher._core.data_manager.electron.get", return_value=node_info) - await _handle_built_sublattice("mock-dispatch-id", mock_node_result) - mock_node_result["error"] - get_result_object_mock.assert_called_with("mock-dispatch-id") - make_sublattice_dispatch_mock.assert_not_called() - assert mock_node_result["status"] == RESULT_STATUS.FAILED - assert "exception" in mock_node_result["error"].lower() - - -def test_initialize_result_object(mocker, test_db): - """Test the `initialize_result_object` function""" - - @ct.electron - def task(x): - return x - - @ct.lattice - def workflow(x): - return task(x) - - workflow.build_graph(1) - json_lattice = workflow.serialize_to_json() - mocker.patch("covalent_dispatcher._db.upsert.workflow_db", return_value=test_db) - mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", return_value=test_db) - result_object = get_mock_result() + mock_notify = mocker.patch( + "covalent_dispatcher._core.dispatcher.notify_node_status", + ) - mock_persist = mocker.patch("covalent_dispatcher._db.update.persist") + mock_get_result = mocker.patch( + "covalent_dispatcher._core.data_modules.electron.get_result_object", + return_value=result_object, + ) - sub_result_object = initialize_result_object( - json_lattice=json_lattice, parent_result_object=result_object, parent_electron_id=5 + mock_make_dispatch = mocker.patch( + "covalent_dispatcher._core.data_manager._make_sublattice_dispatch", + return_value=sub_id, ) - mock_persist.assert_called_with(sub_result_object, electron_id=5) - assert sub_result_object._root_dispatch_id == result_object.dispatch_id + await update_node_result(result_object.dispatch_id, node_result) + detail = 
{"sub_dispatch_id": sub_id} if sub_id else {} + mock_notify.assert_awaited_with(result_object.dispatch_id, 0, output_status, detail) + + if node_status == Result.COMPLETED and node_type == "sublattice" and not sub_id: + mock_make_dispatch.assert_awaited() + else: + mock_make_dispatch.assert_not_awaited() @pytest.mark.parametrize( - "node_name, node_status, sub_dispatch_id, detail", + "node_status,old_status,valid_update", [ - ( - f"{sublattice_prefix}workflow", - RESULT_STATUS.COMPLETED, - "mock-sub-dispatch-id", - {"sub_dispatch_id": "mock-sub-dispatch-id"}, - ), - (f"{sublattice_prefix}workflow", RESULT_STATUS.COMPLETED, None, {}), - ("mock-node-name", RESULT_STATUS.COMPLETED, None, {}), - ("mock-node-name", RESULT_STATUS.FAILED, None, {}), - ("mock-node-name", RESULT_STATUS.CANCELLED, None, {}), + (Result.COMPLETED, Result.RUNNING, True), + (Result.COMPLETED, Result.COMPLETED, False), + (Result.FAILED, Result.COMPLETED, False), ], ) @pytest.mark.asyncio -async def test_update_node_result(mocker, node_name, node_status, sub_dispatch_id, detail): +async def test_update_node_result_filters_illegal_updates( + mocker, node_status, old_status, valid_update +): """Check that update_node_result pushes the correct status updates""" - status_queue = AsyncMock() + result_object = MagicMock() + result_object.dispatch_id = "test_update_node_result_filters_illegal_updates" + result_object._update_node = MagicMock(return_value=valid_update) + node_result = {"node_id": 0, "status": node_status} + node_info = {"type": "function", "sub_dispatch_id": "", "status": old_status} + mocker.patch("covalent_dispatcher._core.data_manager.electron.get", return_value=node_info) - result_object = get_mock_result() - mock_update_node = mocker.patch("covalent_dispatcher._db.update._node") - mocker.patch( - "covalent_dispatcher._core.data_manager.get_status_queue", return_value=status_queue + mock_notify = mocker.patch( + "covalent_dispatcher._core.dispatcher.notify_node_status", ) - 
handle_built_sublattice_mock = mocker.patch( - "covalent_dispatcher._core.data_manager._handle_built_sublattice" + + mock_get_result = mocker.patch( + "covalent_dispatcher._core.data_modules.electron.get_result_object", + return_value=result_object, ) - node_result = { - "node_id": 0, - "node_name": node_name, - "status": node_status, - "sub_dispatch_id": sub_dispatch_id, - } - await update_node_result(result_object, node_result) + mocker.patch( + "covalent_dispatcher._core.data_manager._make_sublattice_dispatch", + ) - status_queue.put.assert_awaited_with((0, node_status, detail)) - mock_update_node.assert_called_with(result_object, **node_result) + await update_node_result(result_object.dispatch_id, node_result) - if ( - node_status == RESULT_STATUS.COMPLETED - and sub_dispatch_id is None - and node_name.startswith(sublattice_prefix) - ): - handle_built_sublattice_mock.assert_called_with(result_object.dispatch_id, node_result) + if not valid_update: + mock_notify.assert_not_awaited() else: - handle_built_sublattice_mock.assert_not_called() + mock_notify.assert_awaited() @pytest.mark.asyncio -async def test_update_node_result_handles_db_exceptions(mocker): - """Check that update_node_result handles db write failures""" +async def test_update_node_result_handles_keyerrors(mocker): + """Check that update_node_result handles invalid dispatch id or node id""" - status_queue = AsyncMock() + result_object = MagicMock() + result_object.dispatch_id = "test_update_node_result_handles_keyerrors" + node_result = {"node_id": -5, "status": RESULT_STATUS.COMPLETED} + mock_update_node = mocker.patch("covalent_dispatcher._dal.result.Result._update_node") + node_info = {"type": "function", "sub_dispatch_id": "", "status": RESULT_STATUS.RUNNING} + mocker.patch("covalent_dispatcher._core.data_manager.electron.get", side_effect=KeyError()) - result_object = get_mock_result() - mock_update_node = mocker.patch( - "covalent_dispatcher._db.update._node", side_effect=RuntimeError() + 
mock_notify = mocker.patch( + "covalent_dispatcher._core.dispatcher.notify_node_status", ) - mocker.patch( - "covalent_dispatcher._core.data_manager.get_status_queue", return_value=status_queue - ) - node_result = { - "node_id": 0, - "node_name": "mock_node_name", - "status": RESULT_STATUS.COMPLETED, - "sub_dispatch_id": None, - } - await update_node_result(result_object, node_result) - status_queue.put.assert_awaited_with((0, RESULT_STATUS.FAILED, {})) + await update_node_result(result_object.dispatch_id, node_result) - -@pytest.mark.asyncio -async def test_make_dispatch(mocker): - res = get_mock_result() - mock_init_result = mocker.patch( - "covalent_dispatcher._core.data_manager.initialize_result_object", return_value=res - ) - mock_register = mocker.patch( - "covalent_dispatcher._core.data_manager._register_result_object", return_value=res - ) - json_lattice = '{"workflow_function": "asdf"}' - dispatch_id = await make_dispatch(json_lattice) - assert dispatch_id == res.dispatch_id - mock_register.assert_called_with(res) + mock_notify.assert_not_awaited() @pytest.mark.asyncio -async def test_make_sublattice_dispatch(mocker): - """Test the make sublattice dispatch method.""" +async def test_update_node_result_handles_subl_exceptions(mocker): + """Check that update_node_result pushes the correct status updates""" - mock_result_object = get_mock_result() - output_mock = MagicMock() - mock_node_result = {"node_id": 0, "output": output_mock} - load_electron_record_mock = mocker.patch( - "covalent_dispatcher._db.load.electron_record", return_value={"id": "mock-electron-id"} - ) - make_dispatch_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.make_dispatch", return_value="mock-dispatch-id" - ) + result_object = MagicMock() + result_object.dispatch_id = "test_update_node_result_handles_subl_exception" - res = await make_sublattice_dispatch(mock_result_object, mock_node_result) - assert res == "mock-dispatch-id" - 
load_electron_record_mock.assert_called_with( - mock_result_object.dispatch_id, mock_node_result["node_id"] + node_type = "sublattice" + sub_id = "" + node_result = {"node_id": 0, "status": Result.COMPLETED} + mock_update_node = mocker.patch("covalent_dispatcher._dal.result.Result._update_node") + node_info = {"type": node_type, "sub_dispatch_id": sub_id, "status": Result.NEW_OBJ} + mocker.patch("covalent_dispatcher._core.data_manager.electron.get", return_value=node_info) + mock_notify = mocker.patch( + "covalent_dispatcher._core.dispatcher.notify_node_status", ) - make_dispatch_mock.assert_called_with( - output_mock.object_string, mock_result_object, "mock-electron-id" - ) - -@pytest.mark.parametrize("reuse", [True, False]) -def test_get_result_object_from_new_lattice(mocker, reuse): - """Test the get result object from new lattice json function.""" - lattice_mock = mocker.patch("covalent_dispatcher._core.data_manager.Lattice") - result_object_mock = mocker.patch("covalent_dispatcher._core.data_manager.Result") - transport_graph_ops_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.TransportGraphOps" - ) - old_result_mock = MagicMock() - res = _get_result_object_from_new_lattice( - json_lattice="mock-lattice", - old_result_object=old_result_mock, - reuse_previous_results=reuse, + mock_get_result = mocker.patch( + "covalent_dispatcher._core.data_modules.electron.get_result_object", + return_value=result_object, ) - assert res == result_object_mock.return_value - lattice_mock.deserialize_from_json.assert_called_with("mock-lattice") - result_object_mock()._initialize_nodes.assert_called_with() - if reuse: - transport_graph_ops_mock().get_reusable_nodes.assert_called_with( - result_object_mock().lattice.transport_graph - ) - transport_graph_ops_mock().copy_nodes_from.assert_called_once_with( - old_result_mock.lattice.transport_graph, - transport_graph_ops_mock().get_reusable_nodes.return_value, - ) - - else: - 
transport_graph_ops_mock().get_reusable_nodes.assert_not_called() - transport_graph_ops_mock().copy_nodes_from.assert_not_called() - - -@pytest.mark.parametrize("reuse", [True, False]) -def test_get_result_object_from_old_result(mocker, reuse): - """Test the get result object from old result function.""" - result_object_mock = mocker.patch("covalent_dispatcher._core.data_manager.Result") - old_result_mock = MagicMock() - res = _get_result_object_from_old_result( - old_result_object=old_result_mock, - reuse_previous_results=reuse, + mock_make_dispatch = mocker.patch( + "covalent_dispatcher._core.data_manager._make_sublattice_dispatch", + side_effect=RuntimeError(), ) - assert res == result_object_mock.return_value - if reuse: - result_object_mock()._initialize_nodes.assert_not_called() - else: - result_object_mock()._initialize_nodes.assert_called_with() + mocker.patch("traceback.TracebackException.from_exception", return_value="error") - assert res._num_nodes == old_result_mock._num_nodes + await update_node_result(result_object.dispatch_id, node_result) + output_status = Result.FAILED + mock_notify.assert_awaited_with(result_object.dispatch_id, 0, output_status, {}) + mock_make_dispatch.assert_awaited() -@pytest.mark.parametrize("reuse", [True, False]) -def test_make_derived_dispatch_from_lattice(mocker, reuse): - """Test the make derived dispatch function.""" - - def mock_func(): - pass +@pytest.mark.asyncio +async def test_update_node_result_handles_db_exceptions(mocker): + """Check that update_node_result handles db write failures""" - mock_old_result = MagicMock() - mock_new_result = MagicMock() - mock_new_result.dispatch_id = "mock-redispatch-id" - mock_new_result.lattice.transport_graph._graph.nodes = ["mock-nodes"] - load_get_result_object_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.load", return_value=mock_old_result - ) - get_result_object_from_new_lattice_mock = mocker.patch( - 
"covalent_dispatcher._core.data_manager._get_result_object_from_new_lattice", - return_value=mock_new_result, - ) - get_result_object_from_old_result_mock = mocker.patch( - "covalent_dispatcher._core.data_manager._get_result_object_from_old_result" - ) - update_mock = mocker.patch("covalent_dispatcher._core.data_manager.update") - register_result_object_mock = mocker.patch( - "covalent_dispatcher._core.data_manager._register_result_object" - ) - mock_electron_updates = {"mock-electron-id": mock_func} - redispatch_id = make_derived_dispatch( - parent_dispatch_id="mock-dispatch-id", - json_lattice="mock-json-lattice", - electron_updates=mock_electron_updates, - reuse_previous_results=reuse, - ) - load_get_result_object_mock.called_once_with("mock-dispatch-id", wait=reuse) - get_result_object_from_new_lattice_mock.called_once_with( - "mock-json-lattice", mock_old_result, reuse + result_object = MagicMock() + result_object.dispatch_id = "test_update_node_result_handles_db_exceptions" + result_object._update_node = MagicMock(side_effect=RuntimeError()) + mock_get_result = mocker.patch( + "covalent_dispatcher._core.data_modules.electron.get_result_object", + return_value=result_object, ) - get_result_object_from_old_result_mock.assert_not_called() - mock_new_result.lattice.transport_graph.apply_electron_updates.assert_called_once_with( - mock_electron_updates + mock_notify = mocker.patch( + "covalent_dispatcher._core.dispatcher.notify_node_status", ) - update_mock().persist.called_once_with(mock_new_result) - register_result_object_mock.assert_called_once_with(mock_new_result) - assert redispatch_id == "mock-redispatch-id" - assert mock_new_result.lattice.transport_graph.dirty_nodes == ["mock-nodes"] + node_result = {"node_id": 0, "status": Result.COMPLETED} + await update_node_result(result_object.dispatch_id, node_result) -@pytest.mark.parametrize("reuse", [True, False]) -def test_make_derived_dispatch_from_old_result(mocker, reuse): - """Test the make derived dispatch 
function.""" - mock_old_result = MagicMock() - mock_new_result = MagicMock() - mock_new_result.dispatch_id = "mock-redispatch-id" - mock_new_result.lattice.transport_graph._graph.nodes = ["mock-nodes"] - load_get_result_object_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.load", return_value=mock_old_result - ) - get_result_object_from_new_lattice_mock = mocker.patch( - "covalent_dispatcher._core.data_manager._get_result_object_from_new_lattice", - ) - get_result_object_from_old_result_mock = mocker.patch( - "covalent_dispatcher._core.data_manager._get_result_object_from_old_result", - return_value=mock_new_result, - ) - update_mock = mocker.patch("covalent_dispatcher._core.data_manager.update") - register_result_object_mock = mocker.patch( - "covalent_dispatcher._core.data_manager._register_result_object" - ) - redispatch_id = make_derived_dispatch( - parent_dispatch_id="mock-dispatch-id", - reuse_previous_results=reuse, + mock_notify.assert_awaited_with(result_object.dispatch_id, 0, Result.FAILED, {}) + + +@pytest.mark.asyncio +async def test_make_dispatch(mocker): + res = MagicMock() + dispatch_id = "test_make_dispatch" + mock_resubmit_lattice = mocker.patch( + "covalent_dispatcher._core.data_manager._redirect_lattice", return_value=dispatch_id ) - load_get_result_object_mock.called_once_with("mock-dispatch-id", wait=reuse) - get_result_object_from_new_lattice_mock.assert_not_called() - get_result_object_from_old_result_mock.called_once_with(mock_old_result, reuse) - mock_new_result.lattice.transport_graph.apply_electron_updates.assert_called_once_with({}) - update_mock().persist.called_once_with(mock_new_result) - register_result_object_mock.assert_called_once_with(mock_new_result) - assert redispatch_id == "mock-redispatch-id" - assert mock_new_result.lattice.transport_graph.dirty_nodes == ["mock-nodes"] + json_lattice = '{"workflow_function": "asdf"}' + assert dispatch_id == await make_dispatch(json_lattice) def 
test_get_result_object(mocker): - """ - Test get result object - """ - result_object = get_mock_result() - dispatch_id = result_object.dispatch_id - _registered_dispatches[dispatch_id] = result_object - assert get_result_object(dispatch_id) is result_object - del _registered_dispatches[dispatch_id] - + result_object = MagicMock() + result_object.dispatch_id = "dispatch_1" + mocker.patch( + "covalent_dispatcher._core.data_manager.get_result_object_from_db", + return_value=result_object, + ) -def test_register_result_object(mocker): - """ - Test registering a result object - """ - result_object = get_mock_result() dispatch_id = result_object.dispatch_id - _register_result_object(result_object) - assert _registered_dispatches[dispatch_id] is result_object - del _registered_dispatches[dispatch_id] + assert get_result_object(dispatch_id) is result_object -def test_unregister_result_object(mocker): - """ - Test unregistering a result object from lattice - """ - result_object = get_mock_result() - dispatch_id = result_object.dispatch_id - _registered_dispatches[dispatch_id] = result_object +@pytest.mark.parametrize("stateless", [False, True]) +def test_unregister_result_object(mocker, stateless): + dispatch_id = "test_unregister_result_object" finalize_dispatch(dispatch_id) - assert dispatch_id not in _registered_dispatches - - -def test_get_status_queue(): - """ - Test querying the dispatch status from the queue - """ - import asyncio - - dispatch_id = "dispatch" - q = asyncio.Queue() - _dispatch_status_queues[dispatch_id] = q - assert get_status_queue(dispatch_id) is q @pytest.mark.asyncio async def test_persist_result(mocker): - """ - Test persisting the result object - """ - result_object = get_mock_result() - - mock_get_result = mocker.patch( - "covalent_dispatcher._core.data_manager.get_result_object", return_value=result_object - ) + dispatch_id = "test_persist_result" mock_update_parent = mocker.patch( 
"covalent_dispatcher._core.data_manager._update_parent_electron" ) - mock_persist = mocker.patch("covalent_dispatcher._core.data_manager.update.persist") - await persist_result(result_object.dispatch_id) - mock_update_parent.assert_awaited_with(result_object) - mock_persist.assert_called_with(result_object) + await persist_result(dispatch_id) + mock_update_parent.assert_awaited_with(dispatch_id) @pytest.mark.parametrize( "sub_status,mapped_status", - [ - (RESULT_STATUS.COMPLETED, RESULT_STATUS.COMPLETED), - (RESULT_STATUS.POSTPROCESSING_FAILED, RESULT_STATUS.FAILED), - ], + [(Result.COMPLETED, Result.COMPLETED), (Result.POSTPROCESSING_FAILED, Result.FAILED)], ) @pytest.mark.asyncio async def test_update_parent_electron(mocker, sub_status, mapped_status): - """ - Test updating parent electron data - """ - parent_result_obj = get_mock_result() - sub_result_obj = get_mock_result() + import datetime + + mock_res = MagicMock() + mock_res.dispatch_id = "test_update_parent_electron" + parent_result_obj = MagicMock() + sub_result_obj = MagicMock() eid = 5 + + parent_result_obj.dispatch_id = mock_res.dispatch_id + parent_dispatch_id = (parent_result_obj.dispatch_id,) parent_node_id = 2 sub_result_obj._electron_id = eid - sub_result_obj._status = sub_status + sub_result_obj.status = sub_status sub_result_obj._result = 42 + sub_result_obj._error = "" + sub_result_obj._end_time = datetime.datetime.now() mock_node_result = { "node_id": parent_node_id, @@ -509,35 +323,155 @@ async def test_update_parent_electron(mocker, sub_status, mapped_status): "error": sub_result_obj._error, } - mocker.patch( + mock_gen_node_result = mocker.patch( "covalent_dispatcher._core.data_manager.generate_node_result", return_value=mock_node_result, ) mock_update_node = mocker.patch("covalent_dispatcher._core.data_manager.update_node_result") - mocker.patch( + mock_resolve_eid = mocker.patch( "covalent_dispatcher._core.data_manager.resolve_electron_id", return_value=(parent_dispatch_id, 
parent_node_id), ) mock_get_res = mocker.patch( - "covalent_dispatcher._core.data_manager.get_result_object", return_value=parent_result_obj + "covalent_dispatcher._core.data_modules.dispatch.get_result_object", + return_value=parent_result_obj, ) - load_mock = mocker.patch("covalent_dispatcher._core.data_manager.load") - load_mock.sublattice_dispatch_id.return_value = "mock-sub-dispatch-id" + + mock_get_res = mocker.patch( + "covalent_dispatcher._core.data_manager.get_result_object", + return_value=parent_result_obj, + ) + await _update_parent_electron(sub_result_obj) mock_get_res.assert_called_with(parent_dispatch_id) - mock_update_node.assert_awaited_with(parent_result_obj, mock_node_result) + mock_update_node.assert_awaited_with(parent_result_obj.dispatch_id, mock_node_result) + + +@pytest.mark.asyncio +async def test_make_sublattice_dispatch(mocker): + node_result = {"node_id": 0, "status": Result.COMPLETED} + output_json = "lattice_json" + + mock_node = MagicMock() + mock_node._electron_id = 5 + + mock_bg_output = MagicMock() + mock_bg_output.object_string = output_json + + mock_node.get_value = MagicMock(return_value=mock_bg_output) + + mock_manifest = MagicMock() + mock_manifest.metadata.dispatch_id = "mock_sublattice_dispatch" + + result_object = MagicMock() + result_object.dispatch_id = "dispatch" + result_object.lattice.transport_graph.get_node = MagicMock(return_value=mock_node) + mocker.patch( + "covalent_dispatcher._core.data_manager.get_result_object", + return_value=result_object, + ) + mocker.patch("covalent._shared_files.schemas.result.ResultSchema.parse_raw") + mocker.patch( + "covalent_dispatcher._core.data_manager.manifest_importer.import_manifest", + return_value=mock_manifest, + ) + + mock_make_dispatch = mocker.patch("covalent_dispatcher._core.data_manager.make_dispatch") + sub_dispatch_id = await _make_sublattice_dispatch(result_object.dispatch_id, node_result) + + # mock_make_dispatch.assert_awaited_with("lattice_json", result_object, 
mock_node._electron_id) + assert sub_dispatch_id == mock_manifest.metadata.dispatch_id + + +@pytest.mark.asyncio +async def test_make_monolithic_sublattice_dispatch(mocker): + """Check that JSON sublattices are handled correctly""" + + dispatch_id = "test_make_monolithic_sublattice_dispatch" + def _mock_helper(dispatch_id, node_result): + return ResultSchema.parse_raw("invalid_input") -def test_upsert_lattice_data(mocker): - """ - Test updating lattice data in database - """ - result_object = get_mock_result() mocker.patch( - "covalent_dispatcher._core.data_manager.get_result_object", return_value=result_object + "covalent_dispatcher._core.data_manager._make_sublattice_dispatch_helper", _mock_helper + ) + + json_lattice = "json_lattice" + parent_electron_id = 5 + mock_legacy_subl_helper = mocker.patch( + "covalent_dispatcher._core.data_manager._legacy_sublattice_dispatch_helper", + return_value=(json_lattice, parent_electron_id), + ) + sub_dispatch_id = "sub_dispatch" + mock_make_dispatch = mocker.patch( + "covalent_dispatcher._core.data_manager.make_dispatch", return_value=sub_dispatch_id + ) + + assert sub_dispatch_id == await _make_sublattice_dispatch(dispatch_id, {}) + + mock_make_dispatch.assert_awaited_with(json_lattice, dispatch_id, parent_electron_id) + + +def test_legacy_sublattice_dispatch_helper(mocker): + dispatch_id = "test_legacy_sublattice_dispatch_helper" + res_obj = MagicMock() + bg_output = MagicMock() + bg_output.object_string = "json_sublattice" + parent_node = MagicMock() + parent_node._electron_id = 2 + parent_node.get_value = MagicMock(return_value=bg_output) + res_obj.lattice.transport_graph.get_node = MagicMock(return_value=parent_node) + node_result = {"node_id": 0} + + mocker.patch("covalent_dispatcher._core.data_manager.get_result_object", return_value=res_obj) + + assert _legacy_sublattice_dispatch_helper(dispatch_id, node_result) == ("json_sublattice", 2) + + +def test_redirect_lattice(mocker): + """Test redirecting JSON lattices to 
new DAL.""" + + dispatch_id = "test_redirect_lattice" + mock_manifest = MagicMock() + mock_manifest.metadata.dispatch_id = dispatch_id + mock_prepare_manifest = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.prepare_manifest", + return_value=mock_manifest, + ) + mock_import_manifest = mocker.patch( + "covalent_dispatcher._core.data_manager.manifest_importer._import_manifest", + return_value=mock_manifest, + ) + + mock_pull = mocker.patch( + "covalent_dispatcher._core.data_manager.manifest_importer._pull_assets", + ) + + mock_lat_deserialize = mocker.patch( + "covalent_dispatcher._core.data_manager.Lattice.deserialize_from_json" + ) + + json_lattice = "json_lattice" + + parent_dispatch_id = "parent_dispatch" + parent_electron_id = 3 + + assert ( + _redirect_lattice(json_lattice, parent_dispatch_id, parent_electron_id, None) + == dispatch_id + ) + + mock_import_manifest.assert_called_with(mock_manifest, parent_dispatch_id, parent_electron_id) + mock_pull.assert_called_with(mock_manifest) + + +@pytest.mark.asyncio +async def test_ensure_dispatch(mocker): + mock_ensure_run_once = mocker.patch( + "covalent_dispatcher._core.data_manager.SRVResult.ensure_run_once", + return_value=True, ) - mock_upsert_lattice = mocker.patch("covalent_dispatcher._db.upsert.lattice_data") - upsert_lattice_data(result_object.dispatch_id) - mock_upsert_lattice.assert_called_with(result_object) + assert await ensure_dispatch("test_ensure_dispatch") is True + mock_ensure_run_once.assert_called_with("test_ensure_dispatch") diff --git a/tests/covalent_dispatcher_tests/_core/data_modules/asset_manager_db_integration_test.py b/tests/covalent_dispatcher_tests/_core/data_modules/asset_manager_db_integration_test.py new file mode 100644 index 0000000000..9a15773200 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/data_modules/asset_manager_db_integration_test.py @@ -0,0 +1,169 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. 
+# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +"""Tests for DB-backed Result""" + + +import os +import tempfile + +import pytest + +import covalent as ct +from covalent._results_manager import Result as SDKResult +from covalent._shared_files.schemas.asset import AssetUpdate +from covalent._workflow.lattice import Lattice as SDKLattice +from covalent_dispatcher._core.data_modules import asset_manager as am +from covalent_dispatcher._dal.result import Result, get_result_object +from covalent_dispatcher._db import update +from covalent_dispatcher._db.datastore import DataStore + +TEMP_RESULTS_DIR = os.environ.get("COVALENT_DATA_DIR") or ct.get_config("dispatcher.results_dir") + + +@pytest.fixture +def test_db(): + """Instantiate and return an in-memory database.""" + + return DataStore( + db_URL="sqlite+pysqlite:///:memory:", + initialize_db=True, + ) + + +def get_mock_result() -> SDKResult: + """Construct a mock result object corresponding to a lattice.""" + + @ct.electron(executor="local") + def task(x): + return x + + @ct.lattice(deps_bash=ct.DepsBash(["ls"])) + def workflow(x): + res1 = task(x) + res2 = task(res1) + return res2 + + workflow.build_graph(x=1) + received_workflow = SDKLattice.deserialize_from_json(workflow.serialize_to_json()) + 
result_object = SDKResult(received_workflow, "mock_dispatch") + + return result_object + + +def get_mock_srvresult(sdkres, test_db) -> Result: + sdkres._initialize_nodes() + + update.persist(sdkres) + + return get_result_object(sdkres.dispatch_id) + + +@pytest.mark.asyncio +async def test_upload_asset_for_nodes(test_db, mocker): + sdkres = get_mock_result() + sdkres._initialize_nodes() + + mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + + srvres = get_mock_srvresult(sdkres, test_db) + + srvres.lattice.transport_graph.set_node_value(0, "stdout", "Hello!\n") + srvres.lattice.transport_graph.set_node_value(2, "stdout", "Bye!\n") + + with tempfile.NamedTemporaryFile("w", delete=True, suffix=".txt") as temp: + dest_path_0 = temp.name + + with tempfile.NamedTemporaryFile("w", delete=True, suffix=".txt") as temp: + dest_path_2 = temp.name + + dest_uri_0 = os.path.join("file://", dest_path_0) + dest_uri_2 = os.path.join("file://", dest_path_2) + + await am.upload_asset_for_nodes(srvres.dispatch_id, "stdout", {0: dest_uri_0, 2: dest_uri_2}) + + with open(dest_path_0, "r") as f: + assert f.read() == "Hello!\n" + + with open(dest_path_2, "r") as f: + assert f.read() == "Bye!\n" + + os.unlink(dest_path_0) + os.unlink(dest_path_2) + + +@pytest.mark.asyncio +async def test_download_assets_for_node(test_db, mocker): + sdkres = get_mock_result() + sdkres._initialize_nodes() + + mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + + mock_update_assets = mocker.patch("covalent_dispatcher._dal.electron.Electron.update_assets") + + srvres = get_mock_srvresult(sdkres, test_db) + + with tempfile.NamedTemporaryFile("w", delete=False, suffix=".txt") as 
temp: + src_path_stdout = temp.name + temp.write("Hello!\n") + + with tempfile.NamedTemporaryFile("w", delete=False, suffix=".txt") as temp: + src_path_stderr = temp.name + temp.write("Bye!\n") + + src_uri_stdout = os.path.join("file://", src_path_stdout) + src_uri_stderr = os.path.join("file://", src_path_stderr) + + assets = { + "output": { + "remote_uri": "", + }, + "stdout": {"remote_uri": src_uri_stdout, "size": None, "digest": "0af23"}, + "stderr": { + "remote_uri": src_uri_stderr, + }, + } + assets = {k: AssetUpdate(**v) for k, v in assets.items()} + + expected_update = { + "output": { + "remote_uri": "", + }, + "stdout": { + "remote_uri": src_uri_stdout, + "digest": "0af23", + }, + "stderr": { + "remote_uri": src_uri_stderr, + }, + } + await am.download_assets_for_node( + srvres.dispatch_id, + 0, + assets, + ) + + mock_update_assets.assert_called_with(expected_update) + assert srvres.lattice.transport_graph.get_node_value(0, "stdout") == "Hello!\n" + assert srvres.lattice.transport_graph.get_node_value(0, "stderr") == "Bye!\n" diff --git a/tests/covalent_dispatcher_tests/_core/data_modules/dispatch_test.py b/tests/covalent_dispatcher_tests/_core/data_modules/dispatch_test.py new file mode 100644 index 0000000000..8c937ae113 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/data_modules/dispatch_test.py @@ -0,0 +1,73 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. 
+# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Tests for querying and updating dispatches +""" + + +from unittest.mock import MagicMock + +import pytest + +from covalent_dispatcher._core.data_modules import dispatch + + +@pytest.mark.asyncio +async def test_get(mocker): + dispatch_id = "test_get_incoming_edges" + + mock_retval = MagicMock() + mock_result_obj = MagicMock() + mock_result_obj.get_values = MagicMock(return_value=mock_retval) + mocker.patch( + "covalent_dispatcher._core.data_modules.dispatch.get_result_object", + return_value=mock_result_obj, + ) + + assert mock_retval == await dispatch.get(dispatch_id, keys=["status"]) + + +@pytest.mark.asyncio +async def test_get_incomplete_tasks(mocker): + dispatch_id = "test_get_node_successors" + mock_retval = MagicMock() + mock_result_obj = MagicMock() + mock_result_obj._get_incomplete_nodes = MagicMock(return_value=mock_retval) + mocker.patch( + "covalent_dispatcher._core.data_modules.dispatch.get_result_object", + return_value=mock_result_obj, + ) + + assert mock_retval == await dispatch.get_incomplete_tasks(dispatch_id) + + +@pytest.mark.asyncio +async def test_update(mocker): + dispatch_id = "test_update_dispatch" + mock_result_obj = MagicMock() + mocker.patch( + "covalent_dispatcher._core.data_modules.dispatch.get_result_object", + return_value=mock_result_obj, + ) + + await dispatch.update(dispatch_id, {"status": "COMPLETED"}) + + mock_result_obj._update_dispatch.assert_called() diff --git a/tests/covalent_dispatcher_tests/_core/data_modules/graph_test.py b/tests/covalent_dispatcher_tests/_core/data_modules/graph_test.py new file mode 100644 index 0000000000..f3db2c3a89 --- /dev/null +++
b/tests/covalent_dispatcher_tests/_core/data_modules/graph_test.py @@ -0,0 +1,98 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Tests for the graph querying functions +""" + + +from unittest.mock import MagicMock + +import pytest + +from covalent_dispatcher._core.data_modules import graph + + +@pytest.mark.asyncio +async def test_get_incoming_edges(mocker): + dispatch_id = "test_get_incoming_edges" + node_id = 0 + + mock_result_obj = MagicMock() + mock_return_val = [{"source": 1, "target": 0, "attrs": {"param_type": "arg"}}] + mock_result_obj.lattice.transport_graph.get_incoming_edges = MagicMock( + return_value=mock_return_val + ) + + mocker.patch( + "covalent_dispatcher._core.data_modules.graph.get_result_object", + return_value=mock_result_obj, + ) + + assert mock_return_val == await graph.get_incoming_edges(dispatch_id, node_id) + + +@pytest.mark.asyncio +async def test_get_node_successors(mocker): + dispatch_id = "test_get_node_successors" + node_id = 0 + + mock_result_obj = MagicMock() + mock_return_val = {"node_id": 0, "status": "NEW_OBJECT"} + mock_result_obj.lattice.transport_graph.get_successors = MagicMock( + return_value=mock_return_val + ) + mocker.patch( 
+ "covalent_dispatcher._core.data_modules.graph.get_result_object", + return_value=mock_result_obj, + ) + assert mock_return_val == await graph.get_node_successors(dispatch_id, node_id) + + +@pytest.mark.asyncio +async def test_get_node_links(mocker): + dispatch_id = "test_get_node_links" + + mock_result_obj = MagicMock() + + mock_return_val = {"nodes": [0, 1], "links": [(1, 0, 0)]} + mocker.patch("networkx.readwrite.node_link_data", return_value=mock_return_val) + mocker.patch( + "covalent_dispatcher._core.data_modules.graph.get_result_object", + return_value=mock_result_obj, + ) + + assert mock_return_val == await graph.get_nodes_links(dispatch_id) + + +@pytest.mark.asyncio +async def test_get_nodes(mocker): + dispatch_id = "test_get_nodes" + mock_result_obj = MagicMock() + + g = MagicMock() + mock_result_obj.lattice.transport_graph.get_internal_graph_copy = MagicMock(return_value=g) + g.nodes = [1, 2, 3] + mocker.patch( + "covalent_dispatcher._core.data_modules.graph.get_result_object", + return_value=mock_result_obj, + ) + + assert [1, 2, 3] == await graph.get_nodes(dispatch_id) diff --git a/tests/covalent_dispatcher_tests/_core/data_modules/importer_test.py b/tests/covalent_dispatcher_tests/_core/data_modules/importer_test.py new file mode 100644 index 0000000000..c36101f396 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/data_modules/importer_test.py @@ -0,0 +1,124 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. 
+# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +"""Unit tests for the importer entry point""" + +from unittest.mock import MagicMock + +import pytest + +from covalent_dispatcher._core.data_modules.importer import ( + _copy_assets, + import_derived_manifest, + import_manifest, +) + + +@pytest.mark.asyncio +async def test_import_manifest(mocker): + mock_manifest = MagicMock() + mock_manifest.metadata.dispatch_id = None + + mock_srvres = MagicMock() + mocker.patch( + "covalent_dispatcher._dal.result.Result.from_dispatch_id", return_value=mock_srvres + ) + + mock_asset = MagicMock() + mock_asset.remote_uri = "s3://mybucket/object.pkl" + + mocker.patch( + "covalent_dispatcher._core.data_modules.importer.import_result", return_value=mock_manifest + ) + + mock_assets = {"lattice": [mock_asset], "nodes": [mock_asset]} + mocker.patch( + "covalent_dispatcher._core.data_modules.importer._get_all_assets", return_value=mock_assets + ) + + return_manifest = await import_manifest(mock_manifest, None, None) + + assert return_manifest.metadata.dispatch_id is not None + + +@pytest.mark.asyncio +async def test_import_sublattice_manifest(mocker): + mock_manifest = MagicMock() + mock_manifest.metadata.dispatch_id = None + + mock_parent_res = MagicMock() + mock_parent_res.root_dispatch_id = "parent_dispatch_id" + + mock_asset = MagicMock() + mock_asset.remote_uri = "s3://mybucket/object.pkl" + + mock_srvres = MagicMock() + mocker.patch( + "covalent_dispatcher._dal.result.Result.from_dispatch_id", return_value=mock_parent_res + ) + + mocker.patch( + "covalent_dispatcher._core.data_modules.importer.import_result", return_value=mock_manifest + ) + + mock_assets = {"lattice": [MagicMock()], "nodes": [MagicMock()]} + + return_manifest = 
await import_manifest(mock_manifest, "parent_dispatch_id", None) + + assert return_manifest.metadata.dispatch_id is not None + assert return_manifest.metadata.root_dispatch_id == "parent_dispatch_id" + + +@pytest.mark.asyncio +async def test_import_derived_manifest(mocker): + mock_manifest = MagicMock() + mock_manifest.metadata.dispatch_id = "test_import_derived_manifest" + + mock_import_manifest = mocker.patch( + "covalent_dispatcher._core.data_modules.importer._import_manifest", + ) + + mock_copy = mocker.patch( + "covalent_dispatcher._core.data_modules.importer._copy_assets", + ) + + mock_handle_redispatch = mocker.patch( + "covalent_dispatcher._core.data_modules.importer.handle_redispatch", + return_value=(mock_manifest, []), + ) + + mock_pull = mocker.patch( + "covalent_dispatcher._core.data_modules.importer._pull_assets", + ) + + mock_manifest = {} + await import_derived_manifest(mock_manifest, "parent_dispatch", True) + + mock_import_manifest.assert_called() + mock_pull.assert_called() + mock_handle_redispatch.assert_called() + mock_copy.assert_called_with([]) + + +def test_copy_assets(mocker): + mock_copy = mocker.patch("covalent_dispatcher._core.data_modules.importer.copy_asset") + + _copy_assets([("src", "dest")]) + mock_copy.assert_called_with("src", "dest") diff --git a/tests/covalent_dispatcher_tests/_core/data_modules/job_manager_test.py b/tests/covalent_dispatcher_tests/_core/data_modules/job_manager_test.py index e363dc0219..243150aa35 100644 --- a/tests/covalent_dispatcher_tests/_core/data_modules/job_manager_test.py +++ b/tests/covalent_dispatcher_tests/_core/data_modules/job_manager_test.py @@ -25,8 +25,8 @@ from covalent_dispatcher._core.data_modules.job_manager import ( get_jobs_metadata, set_cancel_requested, - set_cancel_result, set_job_handle, + set_job_status, ) @@ -104,9 +104,8 @@ async def test_set_job_handle(mocker): mock_update.assert_called_with([{"job_id": 1, "job_handle": "12356"}]) @pytest.mark.asyncio
-@pytest.mark.parametrize("cancel_requested", [True, False]) -async def test_set_cancel_result(cancel_requested, mocker): +async def test_set_job_status(mocker): """ Test requesting a task to be cancelled """ @@ -115,5 +114,5 @@ async def test_set_cancel_result(cancel_requested, mocker): mock_update = mocker.patch( "covalent_dispatcher._core.data_modules.job_manager.update_job_records" ) - await set_cancel_result("dispatch", 0, cancel_status=cancel_requested) - mock_update.assert_called_with([{"job_id": 1, "cancel_successful": cancel_requested}]) + await set_job_status("dispatch", 0, status="COMPLETED") + mock_update.assert_called_with([{"job_id": 1, "status": "COMPLETED"}]) diff --git a/tests/covalent_dispatcher_tests/_core/data_modules/lattice_query_test.py b/tests/covalent_dispatcher_tests/_core/data_modules/lattice_query_test.py new file mode 100644 index 0000000000..6ca7d3a802 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/data_modules/lattice_query_test.py @@ -0,0 +1,45 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license.
+ +""" +Tests for querying lattices +""" + + +from unittest.mock import MagicMock + +import pytest + +from covalent_dispatcher._core.data_modules import lattice + + +@pytest.mark.asyncio +async def test_get(mocker): + dispatch_id = "test_get" + + mock_retval = MagicMock() + mock_result_obj = MagicMock() + mock_result_obj.lattice.get_values = MagicMock(return_value=mock_retval) + mocker.patch( + "covalent_dispatcher._core.data_modules.lattice.get_result_object", + return_value=mock_result_obj, + ) + + assert mock_retval == await lattice.get(dispatch_id, keys=["executor"]) diff --git a/tests/covalent_dispatcher_tests/_core/dispatcher_db_integration_test.py b/tests/covalent_dispatcher_tests/_core/dispatcher_db_integration_test.py new file mode 100644 index 0000000000..810c66410e --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/dispatcher_db_integration_test.py @@ -0,0 +1,326 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Tests for the core functionality of the dispatcher.
+""" + + +from typing import Dict, List + +import pytest + +import covalent as ct +from covalent._results_manager import Result +from covalent._workflow.lattice import Lattice +from covalent_dispatcher._core.dispatcher import ( + _get_abstract_task_inputs, + _get_initial_tasks_and_deps, + _handle_completed_node, +) +from covalent_dispatcher._dal.result import Result as SRVResult +from covalent_dispatcher._dal.result import get_result_object +from covalent_dispatcher._db import models, update +from covalent_dispatcher._db.datastore import DataStore + + +@pytest.fixture +def test_db(): + """Instantiate and return an in-memory database.""" + + return DataStore( + db_URL="sqlite+pysqlite:///:memory:", + initialize_db=True, + ) + + +def get_mock_result() -> Result: + """Construct a mock result object corresponding to a lattice.""" + + import sys + + @ct.electron(executor="local") + def task(x): + print(f"stdout: {x}") + print("Error!", file=sys.stderr) + return x + + @ct.lattice + def pipeline(x): + res1 = task(x) + res2 = task(res1) + return res2 + + pipeline.build_graph(x="absolute") + received_workflow = Lattice.deserialize_from_json(pipeline.serialize_to_json()) + result_object = Result(received_workflow, "pipeline_workflow") + + return result_object + + +def get_mock_srvresult(sdkres, test_db) -> SRVResult: + sdkres._initialize_nodes() + + with test_db.session() as session: + record = session.query(models.Lattice).where(models.Lattice.id == 1).first() + + update.persist(sdkres) + + return get_result_object(sdkres.dispatch_id, bare=False) + + +@pytest.mark.asyncio +async def test_get_abstract_task_inputs(mocker, test_db): + """Test _get_abstract_task_inputs for both dicts and list parameter types""" + + @ct.electron + def list_task(arg: List): + return len(arg) + + @ct.electron + def dict_task(arg: Dict): + return len(arg) + + @ct.electron + def multivariable_task(x, y): + return x, y + + @ct.lattice + def list_workflow(arg): + return list_task(arg) + + @ct.lattice 
+ def dict_workflow(arg): + return dict_task(arg) + + # 1 2 + # \ \ + # 0 3 + # / /\/ + # 4 5 + + @ct.electron + def identity(x): + return x + + @ct.lattice + def multivar_workflow(x, y): + electron_x = identity(x) + electron_y = identity(y) + res1 = multivariable_task(electron_x, electron_y) + res2 = multivariable_task(electron_y, electron_x) + res3 = multivariable_task(electron_y, electron_x) + res4 = multivariable_task(electron_x, electron_y) + return 1 + + mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + + # list-type inputs + + # Nodes 0=task, 1=:electron_list:, 2=1, 3=2, 4=3 + list_workflow.build_graph([1, 2, 3]) + abstract_args = [2, 3, 4] + tg = list_workflow.transport_graph + + sdkres = Result(lattice=list_workflow, dispatch_id="list_input_dispatch") + result_object = get_mock_srvresult(sdkres, test_db) + dispatch_id = result_object.dispatch_id + + async def mock_get_incoming_edges(dispatch_id, node_id): + return result_object.lattice.transport_graph.get_incoming_edges(node_id) + + mocker.patch( + "covalent_dispatcher._core.data_manager.graph.get_incoming_edges", + mock_get_incoming_edges, + ) + + abs_task_inputs = await _get_abstract_task_inputs( + result_object.dispatch_id, 1, tg.get_node_value(1, "name") + ) + + expected_inputs = {"args": abstract_args, "kwargs": {}} + + assert abs_task_inputs == expected_inputs + + # dict-type inputs + + # Nodes 0=task, 1=:electron_dict:, 2=1, 3=2 + dict_workflow.build_graph({"a": 1, "b": 2}) + abstract_args = {"a": 2, "b": 3} + tg = dict_workflow.transport_graph + + sdkres = Result(lattice=dict_workflow, dispatch_id="dict_input_dispatch") + result_object = get_mock_srvresult(sdkres, test_db) + + mocker.patch( + "covalent_dispatcher._core.data_manager.graph.get_incoming_edges", + mock_get_incoming_edges, + ) + + task_inputs = await 
_get_abstract_task_inputs( + result_object.dispatch_id, 1, tg.get_node_value(1, "name") + ) + expected_inputs = {"args": [], "kwargs": abstract_args} + + assert task_inputs == expected_inputs + + # Check arg order + multivar_workflow.build_graph(1, 2) + received_lattice = Lattice.deserialize_from_json(multivar_workflow.serialize_to_json()) + sdkres = Result(lattice=received_lattice, dispatch_id="arg_order_dispatch") + result_object = get_mock_srvresult(sdkres, test_db) + tg = received_lattice.transport_graph + + # Account for injected postprocess electron + assert list(tg._graph.nodes) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + tg.set_node_value(0, "output", ct.TransportableObject(1)) + tg.set_node_value(2, "output", ct.TransportableObject(2)) + + mocker.patch( + "covalent_dispatcher._core.data_manager.graph.get_incoming_edges", + mock_get_incoming_edges, + ) + + task_inputs = await _get_abstract_task_inputs( + result_object.dispatch_id, 4, tg.get_node_value(4, "name") + ) + assert task_inputs["args"] == [0, 2] + + mocker.patch( + "covalent_dispatcher._core.data_manager.graph.get_incoming_edges", + mock_get_incoming_edges, + ) + + task_inputs = await _get_abstract_task_inputs( + result_object.dispatch_id, 5, tg.get_node_value(5, "name") + ) + assert task_inputs["args"] == [2, 0] + + mocker.patch( + "covalent_dispatcher._core.data_manager.graph.get_incoming_edges", + mock_get_incoming_edges, + ) + + task_inputs = await _get_abstract_task_inputs( + result_object.dispatch_id, 6, tg.get_node_value(6, "name") + ) + assert task_inputs["args"] == [2, 0] + mocker.patch( + "covalent_dispatcher._core.data_manager.graph.get_incoming_edges", + mock_get_incoming_edges, + ) + + task_inputs = await _get_abstract_task_inputs( + result_object.dispatch_id, 7, tg.get_node_value(7, "name") + ) + assert task_inputs["args"] == [0, 2] + + +@pytest.mark.asyncio +async def test_handle_completed_node(mocker, test_db): + """Unit test for completed node handler""" + + from 
covalent_dispatcher._core.dispatcher import _initialize_caches, _pending_parents + + mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + + pending_parents = {} + sorted_task_groups = {} + sdkres = get_mock_result() + result_object = get_mock_srvresult(sdkres, test_db) + + async def get_node_successors(dispatch_id: str, node_id: int): + return result_object.lattice.transport_graph.get_successors(node_id, ["task_group_id"]) + + async def electron_get(dispatch_id, node_id, keys): + return {keys[0]: node_id} + + mocker.patch( + "covalent_dispatcher._core.data_manager.graph.get_node_successors", + get_node_successors, + ) + + mocker.patch( + "covalent_dispatcher._core.data_manager.electron.get", + electron_get, + ) + + # tg edges are (1, 0), (0, 2) + pending_parents[0] = 1 + pending_parents[1] = 0 + pending_parents[2] = 1 + sorted_task_groups[0] = [0] + sorted_task_groups[1] = [1] + sorted_task_groups[2] = [2] + + await _initialize_caches(result_object.dispatch_id, pending_parents, sorted_task_groups) + + node_result = {"node_id": 1, "status": Result.COMPLETED} + assert await _pending_parents.get_pending(result_object.dispatch_id, 0) == 1 + assert await _pending_parents.get_pending(result_object.dispatch_id, 1) == 0 + assert await _pending_parents.get_pending(result_object.dispatch_id, 2) == 1 + + next_nodes = await _handle_completed_node(result_object.dispatch_id, 1) + assert next_nodes == [0] + + assert await _pending_parents.get_pending(result_object.dispatch_id, 0) == 0 + assert await _pending_parents.get_pending(result_object.dispatch_id, 1) == 0 + assert await _pending_parents.get_pending(result_object.dispatch_id, 2) == 1 + + +@pytest.mark.asyncio +async def test_get_initial_tasks_and_deps(mocker, test_db): + """Test internal function for initializing status_queue and pending_parents""" + + 
mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + + pending_parents = {} + + sdkres = get_mock_result() + result_object = get_mock_srvresult(sdkres, test_db) + dispatch_id = result_object.dispatch_id + + async def get_graph_nodes_links(dispatch_id: str) -> dict: + import networkx as nx + + """Return the internal transport graph in NX node-link form""" + g = result_object.lattice.transport_graph.get_internal_graph_copy() + return nx.readwrite.node_link_data(g) + + mocker.patch( + "covalent_dispatcher._core.data_manager.graph.get_nodes_links", + side_effect=get_graph_nodes_links, + ) + + initial_nodes, pending_parents, sorted_task_groups = await _get_initial_tasks_and_deps( + dispatch_id + ) + + assert initial_nodes == [1] + + # Account for injected postprocess electron + assert pending_parents == {0: 1, 1: 0, 2: 1, 3: 3} + assert sorted_task_groups == {0: [0], 1: [1], 2: [2], 3: [3]} diff --git a/tests/covalent_dispatcher_tests/_core/dispatcher_test.py b/tests/covalent_dispatcher_tests/_core/dispatcher_test.py index 866f377a97..634fd59321 100644 --- a/tests/covalent_dispatcher_tests/_core/dispatcher_test.py +++ b/tests/covalent_dispatcher_tests/_core/dispatcher_test.py @@ -23,26 +23,22 @@ """ -from typing import Dict, List from unittest.mock import AsyncMock, call -import cloudpickle as pickle import pytest -from mock import MagicMock import covalent as ct from covalent._results_manager import Result -from covalent._shared_files.util_classes import RESULT_STATUS from covalent._workflow.lattice import Lattice from covalent_dispatcher._core.dispatcher import ( - _get_abstract_task_inputs, - _get_initial_tasks_and_deps, + _clear_caches, + _finalize_dispatch, _handle_cancelled_node, - _handle_completed_node, + _handle_event, _handle_failed_node, - _plan_workflow, - _run_planned_workflow, - 
_submit_task, + _handle_node_status_update, + _submit_initial_tasks, + _submit_task_group, cancel_dispatch, run_dispatch, run_workflow, @@ -86,395 +82,602 @@ def pipeline(x): return result_object -def test_plan_workflow(): - """Test workflow planning method.""" +@pytest.mark.asyncio +async def test_handle_failed_node(mocker): + """Unit test for failed node handler""" + dispatch_id = "failed_dispatch" + await _handle_failed_node(dispatch_id, 1) - @ct.electron - def task(x): - return x - @ct.lattice - def workflow(x): - return task(x) +@pytest.mark.asyncio +async def test_handle_cancelled_node(mocker, test_db): + """Unit test for cancelled node handler""" + dispatch_id = "cancelled_dispatch" - workflow.metadata["schedule"] = True - received_workflow = Lattice.deserialize_from_json(workflow.serialize_to_json()) - result_object = Result(received_workflow, "asdf") - _plan_workflow(result_object=result_object) + await _handle_cancelled_node(dispatch_id, 1) - # Updated transport graph post planning - updated_tg = pickle.loads(result_object.lattice.transport_graph.serialize(metadata_only=True)) - assert updated_tg["lattice_metadata"]["schedule"] +@pytest.mark.parametrize( + "wait,expected_status", [(True, Result.COMPLETED), (False, Result.RUNNING)] +) +@pytest.mark.asyncio +async def test_run_workflow_normal(mocker, wait, expected_status): + import asyncio + dispatch_id = "mock_dispatch" -def test_get_abstract_task_inputs(): - """Test _get_abstract_task_inputs for both dicts and list parameter types""" + mock_unregister = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" + ) + mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.ensure_dispatch", return_value=True) - @ct.electron - def list_task(arg: List): - return len(arg) + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.get", + return_value={"status": Result.NEW_OBJ}, + ) + _futures = {dispatch_id: asyncio.Future()} + 
mocker.patch("covalent_dispatcher._core.dispatcher._futures", _futures) - @ct.electron - def dict_task(arg: Dict): - return len(arg) + async def mark_future_done(dispatch_id): + _futures[dispatch_id].set_result(Result.COMPLETED) + return Result.RUNNING - @ct.electron - def multivariable_task(x, y): - return x, y + mocker.patch( + "covalent_dispatcher._core.dispatcher._submit_initial_tasks", + return_value=Result.RUNNING, + side_effect=mark_future_done, + ) - @ct.lattice - def list_workflow(arg): - return list_task(arg) + dispatch_status = await run_workflow(dispatch_id, wait) + assert dispatch_status == expected_status + if wait: + mock_unregister.assert_called_with(dispatch_id) - @ct.lattice - def dict_workflow(arg): - return dict_task(arg) - # 1 2 - # \ \ - # 0 3 - # / /\/ - # 4 5 +@pytest.mark.parametrize("wait", [True, False]) +@pytest.mark.asyncio +async def test_run_completed_workflow(mocker, wait): + import asyncio - @ct.electron - def identity(x): - return x + dispatch_id = "completed_dispatch" + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.ensure_dispatch", return_value=False + ) - @ct.lattice - def multivar_workflow(x, y): - electron_x = identity(x) - electron_y = identity(y) - res1 = multivariable_task(electron_x, electron_y) - res2 = multivariable_task(electron_y, electron_x) - res3 = multivariable_task(electron_y, electron_x) - res4 = multivariable_task(electron_x, electron_y) - return 1 + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.get", + return_value={"status": Result.COMPLETED}, + ) - # list-type inputs + mock_unregister = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" + ) + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.get", + return_value={"status": Result.COMPLETED}, + ) + mock_plan = mocker.patch("covalent_dispatcher._core.dispatcher._plan_workflow") + dispatch_status = await run_workflow(dispatch_id, wait) - # Nodes 0=task, 1=:electron_list:, 
2=1, 3=2, 4=3 - list_workflow.build_graph([1, 2, 3]) - abstract_args = [2, 3, 4] - tg = list_workflow.transport_graph + mock_unregister.assert_not_called() + assert dispatch_status == Result.COMPLETED - result_object = Result(lattice=list_workflow, dispatch_id="asdf") - abs_task_inputs = _get_abstract_task_inputs(1, tg.get_node_value(1, "name"), result_object) - expected_inputs = {"args": abstract_args, "kwargs": {}} +@pytest.mark.parametrize("wait", [True, False]) +@pytest.mark.asyncio +async def test_run_workflow_exception(mocker, wait): + import asyncio - assert abs_task_inputs == expected_inputs + dispatch_id = "mock_dispatch" - # dict-type inputs + mock_unregister = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" + ) + mocker.patch("covalent_dispatcher._core.dispatcher._plan_workflow") + mocker.patch( + "covalent_dispatcher._core.dispatcher._submit_initial_tasks", + side_effect=RuntimeError("Error"), + ) - # Nodes 0=task, 1=:electron_dict:, 2=1, 3=2 - dict_workflow.build_graph({"a": 1, "b": 2}) - abstract_args = {"a": 2, "b": 3} - tg = dict_workflow.transport_graph + mock_dispatch_update = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.update", + ) + mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.ensure_dispatch", return_value=True) - result_object = Result(lattice=dict_workflow, dispatch_id="asdf") - task_inputs = _get_abstract_task_inputs(1, tg.get_node_value(1, "name"), result_object) - expected_inputs = {"args": [], "kwargs": abstract_args} + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.get", + return_value={"status": Result.NEW_OBJ}, + ) - assert task_inputs == expected_inputs + status = await run_workflow(dispatch_id, wait) - # Check arg order - multivar_workflow.build_graph(1, 2) - received_lattice = Lattice.deserialize_from_json(multivar_workflow.serialize_to_json()) - result_object = Result(lattice=received_lattice, dispatch_id="asdf") - tg = 
received_lattice.transport_graph + assert status == Result.FAILED + mock_unregister.assert_called_with(dispatch_id) - assert list(tg._graph.nodes) == list(range(9)) - tg.set_node_value(0, "output", ct.TransportableObject(1)) - tg.set_node_value(2, "output", ct.TransportableObject(2)) - task_inputs = _get_abstract_task_inputs(4, tg.get_node_value(4, "name"), result_object) - assert task_inputs["args"] == [0, 2] +@pytest.mark.asyncio +async def test_run_dispatch(mocker): + dispatch_id = "test_dispatch" + mock_run = mocker.patch("covalent_dispatcher._core.dispatcher.run_workflow") + run_dispatch(dispatch_id) + mock_run.assert_called_with(dispatch_id) - task_inputs = _get_abstract_task_inputs(5, tg.get_node_value(5, "name"), result_object) - assert task_inputs["args"] == [2, 0] - task_inputs = _get_abstract_task_inputs(6, tg.get_node_value(6, "name"), result_object) - assert task_inputs["args"] == [2, 0] +@pytest.mark.asyncio +async def test_handle_completed_node_update(mocker): + import asyncio - task_inputs = _get_abstract_task_inputs(7, tg.get_node_value(7, "name"), result_object) - assert task_inputs["args"] == [0, 2] + dispatch_id = "mock_dispatch" + node_id = 2 + status = Result.COMPLETED + detail = {} + next_groups = [0, 1] + mock_handle_cancelled = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_completed_node", return_value=next_groups + ) + mock_decrement = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.decrement" + ) -@pytest.mark.asyncio -async def test_handle_completed_node(mocker): - """Unit test for completed node handler""" - pending_parents = {} + mock_increment = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.increment" + ) - result_object = get_mock_result() + async def get_task_group(dispatch_id, gid): + return [gid] - # tg edges are (1, 0), (0, 2) - pending_parents[0] = 1 - pending_parents[1] = 0 - pending_parents[2] = 1 + mock_get_sorted_task_groups = mocker.patch( + 
"covalent_dispatcher._core.dispatcher._sorted_task_groups.get_task_group", + get_task_group, + ) + mock_submit_task_group = mocker.patch( + "covalent_dispatcher._core.dispatcher._submit_task_group" + ) + + await _handle_node_status_update(dispatch_id, node_id, status, detail) + mock_decrement.assert_awaited() + assert mock_increment.await_count == 2 + assert mock_submit_task_group.await_count == 2 + + +@pytest.mark.asyncio +async def test_handle_cancelled_node_update(mocker): + import asyncio - mock_upsert_lattice = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data" + dispatch_id = "mock_dispatch" + node_id = 0 + status = Result.CANCELLED + detail = {} + mock_handle_cancelled = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_cancelled_node", + ) + mock_decrement = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.decrement" ) - node_result = {"node_id": 1, "status": Result.COMPLETED} + await _handle_node_status_update(dispatch_id, node_id, status, detail) + mock_handle_cancelled.assert_awaited_with(dispatch_id, 0) + mock_decrement.assert_awaited() + + +@pytest.mark.asyncio +async def test_run_handle_failed_node_update(mocker): + import asyncio + + dispatch_id = "mock_dispatch" + node_id = 0 + status = Result.FAILED + detail = {} + mock_handle_failed = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_failed_node", + ) + mock_decrement = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.decrement" + ) - next_nodes = await _handle_completed_node(result_object, 1, pending_parents) - assert next_nodes == [0] - assert pending_parents == {0: 0, 1: 0, 2: 1} + await _handle_node_status_update(dispatch_id, node_id, status, detail) + mock_handle_failed.assert_awaited_with(dispatch_id, 0) + mock_decrement.assert_awaited() @pytest.mark.asyncio -async def test_handle_failed_node(mocker): - """Unit test for failed node handler""" - pending_parents = {} +async def 
test_run_handle_sublattice_node_update(mocker): + import asyncio - result_object = get_mock_result() - # tg edges are (1, 0), (0, 2) + from covalent._shared_files.util_classes import RESULT_STATUS - mock_upsert_lattice = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data" + dispatch_id = "mock_dispatch" + node_id = 0 + status = RESULT_STATUS.DISPATCHING + detail = {"sub_dispatch_id": "sub_dispatch"} + mock_run_dispatch = mocker.patch( + "covalent_dispatcher._core.dispatcher.run_dispatch", ) - await _handle_failed_node(result_object, 1) - - mock_upsert_lattice.assert_called() + mock_decrement = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.decrement" + ) + await _handle_node_status_update(dispatch_id, node_id, status, detail) + mock_run_dispatch.assert_called_with("sub_dispatch") + mock_decrement.assert_not_awaited() +@pytest.mark.parametrize("unresolved_count", [1, 0]) @pytest.mark.asyncio -async def test_handle_cancelled_node(mocker): - """Unit test for cancelled node handler""" - pending_parents = {} +async def test_handle_event(mocker, unresolved_count): + mock_handle_status_update = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_node_status_update", + ) + mock_handle_dispatch_exception = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_dispatch_exception", + ) - result_object = get_mock_result() - # tg edges are (1, 0), (0, 2) + mock_persist = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.persist_result", + ) + + mock_finalize = mocker.patch( + "covalent_dispatcher._core.dispatcher._finalize_dispatch", + return_value=Result.COMPLETED, + ) - mock_upsert_lattice = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data" + mock_get_unresolved = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.get_unresolved", + return_value=unresolved_count, ) - node_result = {"node_id": 1, "status": Result.CANCELLED} + dispatch_id = 
"mock_dispatch" + node_id = 2 + status = Result.COMPLETED + msg = {"dispatch_id": dispatch_id, "node_id": node_id, "status": status, "detail": {}} - await _handle_cancelled_node(result_object, 1) - assert result_object._task_cancelled is True - mock_upsert_lattice.assert_called() + await _handle_event(msg) + + if unresolved_count < 1: + mock_finalize.assert_awaited() + mock_persist.assert_awaited() + else: + mock_finalize.assert_not_awaited() @pytest.mark.asyncio -async def test_get_initial_tasks_and_deps(mocker): - """Test internal function for initializing status_queue and pending_parents""" - pending_parents = {} +async def test_handle_event_exception(mocker): + import asyncio - result_object = get_mock_result() - num_tasks, initial_nodes, pending_parents = await _get_initial_tasks_and_deps(result_object) + mock_handle_status_update = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_node_status_update", + side_effect=RuntimeError(), + ) + mock_handle_dispatch_exception = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_dispatch_exception", + return_value=Result.FAILED, + ) - assert initial_nodes == [1] - assert pending_parents == {0: 1, 1: 0, 2: 1, 3: 2} - assert num_tasks == len(result_object.lattice.transport_graph._graph.nodes) + mock_persist = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.persist_result", + ) + mock_finalize = mocker.patch( + "covalent_dispatcher._core.dispatcher._finalize_dispatch", + return_value=Result.COMPLETED, + ) + + mock_get_unresolved = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.get_unresolved", + return_value=2, + ) + + dispatch_id = "mock_dispatch" + node_id = 2 + status = Result.COMPLETED + msg = {"dispatch_id": dispatch_id, "node_id": node_id, "status": status, "detail": {}} + + _futures = {dispatch_id: asyncio.Future()} -@pytest.mark.asyncio -async def test_run_dispatch(mocker): - """ - Test running a mock dispatch - """ - res = get_mock_result() 
mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.get_result_object", return_value=res + "covalent_dispatcher._core.dispatcher._futures", + _futures, ) - mock_run = mocker.patch("covalent_dispatcher._core.dispatcher.run_workflow") - run_dispatch(res.dispatch_id) - mock_run.assert_called_with(res) + + assert await _handle_event(msg) == Result.FAILED + + assert _futures[dispatch_id].result() == Result.FAILED + + mock_persist.assert_awaited() + mock_finalize.assert_not_awaited() @pytest.mark.asyncio -async def test_run_workflow_normal(mocker): - """ - Test a normal workflow execution - """ +async def test_handle_event_finalize_exception(mocker): import asyncio - result_object = get_mock_result() - msg_queue = asyncio.Queue() - mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.get_status_queue", return_value=msg_queue + mock_handle_status_update = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_node_status_update", ) - mocker.patch("covalent_dispatcher._core.dispatcher._plan_workflow") - mocker.patch( - "covalent_dispatcher._core.dispatcher._run_planned_workflow", return_value=result_object + mock_handle_dispatch_exception = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_dispatch_exception", + return_value=Result.FAILED, ) - mock_persist = mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.persist_result") - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" + + mock_persist = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.persist_result", + ) + + mock_finalize = mocker.patch( + "covalent_dispatcher._core.dispatcher._finalize_dispatch", + side_effect=RuntimeError(), + ) + + mock_get_unresolved = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.get_unresolved", + return_value=0, ) - await run_workflow(result_object) - mock_persist.assert_awaited_with(result_object.dispatch_id) - 
mock_unregister.assert_called_with(result_object.dispatch_id) + dispatch_id = "mock_dispatch" + node_id = 2 + status = Result.COMPLETED + msg = {"dispatch_id": dispatch_id, "node_id": node_id, "status": status, "detail": {}} + _futures = {dispatch_id: asyncio.Future()} + mocker.patch( + "covalent_dispatcher._core.dispatcher._futures", + _futures, + ) + + assert await _handle_event(msg) == Result.FAILED + + assert _futures[dispatch_id].result() == Result.FAILED + + mock_persist.assert_awaited() + + +@pytest.mark.parametrize( + "failed,cancelled,final_status", + [ + (False, False, Result.COMPLETED), + (False, True, Result.CANCELLED), + (True, False, Result.FAILED), + (True, True, Result.FAILED), + ], +) @pytest.mark.asyncio -async def test_run_completed_workflow(mocker): - """ - Test run completed workflow - """ - import asyncio +async def test_finalize_dispatch(mocker, failed, cancelled, final_status): + mock_clear = mocker.patch("covalent_dispatcher._core.dispatcher._clear_caches") + failed_tasks = [(0, "task_0")] if failed else [] + cancelled_tasks = [(1, "task_1")] if cancelled else [] + + query_result = {"failed": failed_tasks, "cancelled": cancelled_tasks} + mock_incomplete = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.get_incomplete_tasks", + return_value=query_result, + ) + + mock_dispatch_info = {"status": Result.COMPLETED} - result_object = get_mock_result() - result_object._status = Result.COMPLETED - msg_queue = asyncio.Queue() - mock_get_status_queue = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.get_status_queue", return_value=msg_queue + def mock_gen_dispatch_result(dispatch_id, **kwargs): + return {"status": kwargs["status"]} + + async def mock_dispatch_update(dispatch_id, dispatch_result): + mock_dispatch_info["status"] = dispatch_result["status"] + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.generate_dispatch_result", + mock_gen_dispatch_result, ) - mock_unregister = mocker.patch( - 
"covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.update", + mock_dispatch_update, ) - mock_plan = mocker.patch("covalent_dispatcher._core.dispatcher._plan_workflow") + mocker.patch( - "covalent_dispatcher._core.dispatcher._run_planned_workflow", return_value=result_object + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.get", + return_value=mock_dispatch_info, ) - mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.persist_result") - await run_workflow(result_object) + dispatch_id = "dispatch_1" - mock_plan.assert_not_called() - mock_get_status_queue.assert_not_called() - mock_unregister.assert_called_with(result_object.dispatch_id) + assert await _finalize_dispatch(dispatch_id) == final_status @pytest.mark.asyncio -async def test_run_workflow_exception(mocker): - """ - Test any exception raised when running workflow - """ - import asyncio +async def test_submit_initial_tasks(mocker): + dispatch_id = "dispatch_1" - result_object = get_mock_result() - msg_queue = asyncio.Queue() + initial_groups = [1, 2] + sorted_groups = {1: [1], 2: [2]} mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.get_status_queue", return_value=msg_queue + "covalent_dispatcher._core.dispatcher._get_initial_tasks_and_deps", + return_value=(initial_groups, {1: 0, 2: 0}, sorted_groups), ) - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.generate_dispatch_result", ) - mocker.patch("covalent_dispatcher._core.dispatcher._plan_workflow") mocker.patch( - "covalent_dispatcher._core.dispatcher._run_planned_workflow", - return_value=result_object, - side_effect=RuntimeError("Error"), + "covalent_dispatcher._core.dispatcher._initialize_caches", ) - mock_persist = mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.persist_result") - result = await 
run_workflow(result_object) + mock_inc = mocker.patch("covalent_dispatcher._core.dispatcher._unresolved_tasks.increment") + mock_submit_task_group = mocker.patch( + "covalent_dispatcher._core.dispatcher._submit_task_group", + ) + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.update", + ) - assert result.status == Result.FAILED - mock_persist.assert_awaited_with(result_object.dispatch_id) - mock_unregister.assert_called_with(result_object.dispatch_id) + assert await _submit_initial_tasks(dispatch_id) == Result.RUNNING + + assert mock_submit_task_group.await_count == 2 + assert mock_inc.await_count == 2 @pytest.mark.asyncio -async def test_run_planned_workflow_cancelled_update(mocker): - """ - Test run planned workflow with cancelled update - """ - import asyncio +async def test_submit_task_group(mocker): + dispatch_id = "dispatch_1" + gid = 2 + nodes = [4, 3, 2] + + mock_get_abs_input = mocker.patch( + "covalent_dispatcher._core.dispatcher._get_abstract_task_inputs", + return_value={"args": [], "kwargs": {}}, + ) - result_object = get_mock_result() + mock_attrs = { + "name": "task", + "value": 5, + "executor": "local", + "executor_data": {}, + } + + mock_statuses = [ + {"status": Result.NEW_OBJ}, + {"status": Result.NEW_OBJ}, + {"status": Result.NEW_OBJ}, + ] - mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") - tasks_left = 1 - initial_nodes = [0] - pending_deps = {0: 0} + async def get_electron_attrs(dispatch_id, node_id, keys): + return {key: mock_attrs[key] for key in keys} mocker.patch( - "covalent_dispatcher._core.dispatcher._get_initial_tasks_and_deps", - return_value=(tasks_left, initial_nodes, pending_deps), + "covalent_dispatcher._core.dispatcher.datasvc.electron.get", + get_electron_attrs, ) - mock_submit_task = mocker.patch("covalent_dispatcher._core.dispatcher._submit_task") + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.electron.get_bulk", + return_value=mock_statuses, + ) - def 
side_effect(result_object, node_id): - result_object._task_cancelled = True + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.update_node_result", + ) - mock_handle_cancelled = mocker.patch( - "covalent_dispatcher._core.dispatcher._handle_cancelled_node", side_effect=side_effect + mock_run_abs_task = mocker.patch( + "covalent_dispatcher._core.dispatcher.runner_ng.run_abstract_task_group", ) - status_queue = asyncio.Queue() - status_queue.put_nowait((0, Result.CANCELLED, {})) - await _run_planned_workflow(result_object, status_queue) - assert mock_submit_task.await_count == 1 - mock_handle_cancelled.assert_awaited_with(result_object, 0) + + await _submit_task_group(dispatch_id, nodes, gid) + mock_run_abs_task.assert_called() + assert mock_get_abs_input.await_count == len(nodes) @pytest.mark.asyncio -async def test_run_planned_workflow_failed_update(mocker): - """ - Test run planned workflow with mocking a failed job update - """ - import asyncio +async def test_submit_task_group_skips_reusable(mocker): + """Check that submit_task_group skips reusable groups""" + dispatch_id = "dispatch_1" + gid = 2 + nodes = [4, 3, 2] + + mock_get_abs_input = mocker.patch( + "covalent_dispatcher._core.dispatcher._get_abstract_task_inputs", + return_value={"args": [], "kwargs": {}}, + ) - result_object = get_mock_result() + mock_attrs = { + "name": "task", + "value": 5, + "executor": "local", + "executor_data": {}, + } + + mock_statuses = [ + {"status": Result.PENDING_REUSE}, + {"status": Result.PENDING_REUSE}, + {"status": Result.PENDING_REUSE}, + ] - mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") - tasks_left = 1 - initial_nodes = [0] - pending_deps = {0: 0} + async def get_electron_attrs(dispatch_id, node_id, keys): + return {key: mock_attrs[key] for key in keys} mocker.patch( - "covalent_dispatcher._core.dispatcher._get_initial_tasks_and_deps", - return_value=(tasks_left, initial_nodes, pending_deps), + 
"covalent_dispatcher._core.dispatcher.datasvc.electron.get", + get_electron_attrs, ) - mock_submit_task = mocker.patch("covalent_dispatcher._core.dispatcher._submit_task") + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.electron.get_bulk", + return_value=mock_statuses, + ) - def side_effect(result_object, node_id): - result_object._task_failed = True + mock_update = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.update_node_result", + ) - mock_handle_failed = mocker.patch( - "covalent_dispatcher._core.dispatcher._handle_failed_node", side_effect=side_effect + mock_run_abs_task = mocker.patch( + "covalent_dispatcher._core.dispatcher.runner_ng.run_abstract_task_group", ) - status_queue = asyncio.Queue() - status_queue.put_nowait((0, Result.FAILED, {})) - await _run_planned_workflow(result_object, status_queue) - assert mock_submit_task.await_count == 1 - mock_handle_failed.assert_awaited_with(result_object, 0) + + await _submit_task_group(dispatch_id, nodes, gid) + mock_run_abs_task.assert_not_called() + mock_get_abs_input.assert_not_awaited() + assert mock_update.await_count == len(nodes) @pytest.mark.asyncio -async def test_run_planned_workflow_dispatching(mocker): - """Test the run planned workflow for a dispatching node.""" - import asyncio +async def test_submit_parameter(mocker): + from covalent._shared_files.defaults import parameter_prefix + + dispatch_id = "dispatch_1" + node_id = 2 - result_object = get_mock_result() + mock_attrs = { + "name": parameter_prefix, + "value": 5, + "executor": "local", + "executor_data": {}, + } - mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") - tasks_left = 1 - initial_nodes = [0] - pending_deps = {0: 0} + async def get_electron_attrs(dispatch_id, node_id, keys): + return {key: mock_attrs[key] for key in keys} mocker.patch( - "covalent_dispatcher._core.dispatcher._get_initial_tasks_and_deps", - return_value=(tasks_left, initial_nodes, pending_deps), + 
"covalent_dispatcher._core.dispatcher.datasvc.electron.get", + get_electron_attrs, + ) + + mock_update = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.update_node_result", ) - mock_submit_task = mocker.patch("covalent_dispatcher._core.dispatcher._submit_task") + mock_run_abs_task = mocker.patch( + "covalent_dispatcher._core.dispatcher.runner_ng.run_abstract_task_group", + ) + await _submit_task_group(dispatch_id, [node_id], node_id) - def side_effect(result_object, node_id): - result_object._task_failed = True + mock_run_abs_task.assert_not_called() + mock_update.assert_awaited() - mock_handle_failed = mocker.patch( - "covalent_dispatcher._core.dispatcher._handle_failed_node", side_effect=side_effect + +@pytest.mark.asyncio +async def test_clear_caches(mocker): + import networkx as nx + + g = nx.MultiDiGraph() + g.add_node(0, task_group_id=0) + g.add_node(1, task_group_id=0) + g.add_node(2, task_group_id=0) + g.add_node(3, task_group_id=3) + + mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.graph.get_nodes_links") + mocker.patch("networkx.readwrite.node_link_graph", return_value=g) + mock_unresolved_remove = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.remove" ) - mock_run_dispatch = mocker.patch("covalent_dispatcher._core.dispatcher.run_dispatch") - status_queue = asyncio.Queue() - status_queue.put_nowait( - (0, RESULT_STATUS.DISPATCHING_SUBLATTICE, {"sub_dispatch_id": "mock_sub_dispatch_id"}) + mock_pending_remove = mocker.patch( + "covalent_dispatcher._core.dispatcher._pending_parents.remove" ) - status_queue.put_nowait((0, RESULT_STATUS.FAILED, {})) # This ensures that the loop is exited. 
- await _run_planned_workflow(result_object, status_queue) - assert mock_submit_task.await_count == 1 - mock_handle_failed.assert_awaited_with(result_object, 0) - mock_run_dispatch.assert_called_once_with("mock_sub_dispatch_id") + + mock_groups_remove = mocker.patch( + "covalent_dispatcher._core.dispatcher._sorted_task_groups.remove" + ) + + await _clear_caches("dispatch") + + assert mock_unresolved_remove.await_count == 1 + assert mock_pending_remove.await_count == 2 + assert mock_groups_remove.await_count == 2 @pytest.mark.asyncio @@ -486,19 +689,13 @@ async def test_cancel_dispatch(mocker): sub_dispatch_id = "sub_pipeline_workflow" sub_res._dispatch_id = sub_dispatch_id - def mock_get_result_object(dispatch_id): - objs = {res._dispatch_id: res, sub_res._dispatch_id: sub_res} - return objs[dispatch_id] - - mock_data_cancel = mocker.patch("covalent_dispatcher._core.dispatcher.set_cancel_requested") + mock_data_cancel = mocker.patch( + "covalent_dispatcher._core.dispatcher.jbmgr.set_cancel_requested" + ) - mock_runner = mocker.patch("covalent_dispatcher._core.dispatcher.runner") + mock_runner = mocker.patch("covalent_dispatcher._core.dispatcher.runner_ng") mock_runner.cancel_tasks = AsyncMock() - mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.get_result_object", mock_get_result_object - ) - res._initialize_nodes() sub_res._initialize_nodes() @@ -506,6 +703,33 @@ def mock_get_result_object(dispatch_id): tg.set_node_value(2, "sub_dispatch_id", sub_dispatch_id) sub_tg = sub_res.lattice.transport_graph + async def mock_get_nodes(dispatch_id): + if dispatch_id == res.dispatch_id: + return list(tg._graph.nodes) + else: + return list(sub_tg._graph.nodes) + + mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.graph.get_nodes", mock_get_nodes) + + node_attrs = [ + {"sub_dispatch_id": tg.get_node_value(i, "sub_dispatch_id")} for i in tg._graph.nodes + ] + sub_node_attrs = [ + {"sub_dispatch_id": sub_tg.get_node_value(i, "sub_dispatch_id")} + for i in 
sub_tg._graph.nodes + ] + + async def mock_get(dispatch_id, task_ids, keys): + if dispatch_id == res.dispatch_id: + return node_attrs + else: + return sub_node_attrs + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.electron.get_bulk", + mock_get, + ) + await cancel_dispatch("pipeline_workflow") task_ids = list(tg._graph.nodes) @@ -522,30 +746,51 @@ async def test_cancel_dispatch_with_task_ids(mocker): res = get_mock_result() sub_res = get_mock_result() + res._initialize_nodes() + sub_res._initialize_nodes() + sub_dispatch_id = "sub_pipeline_workflow" sub_res._dispatch_id = sub_dispatch_id + tg = res.lattice.transport_graph + tg.set_node_value(2, "sub_dispatch_id", sub_dispatch_id) + sub_tg = sub_res.lattice.transport_graph - def mock_get_result_object(dispatch_id): - objs = {res._dispatch_id: res, sub_res._dispatch_id: sub_res} - return objs[dispatch_id] - - mock_data_cancel = mocker.patch("covalent_dispatcher._core.dispatcher.set_cancel_requested") + mock_data_cancel = mocker.patch( + "covalent_dispatcher._core.dispatcher.jbmgr.set_cancel_requested" + ) - mock_runner = mocker.patch("covalent_dispatcher._core.dispatcher.runner") + mock_runner = mocker.patch("covalent_dispatcher._core.dispatcher.runner_ng") mock_runner.cancel_tasks = AsyncMock() + async def mock_get_nodes(dispatch_id): + if dispatch_id == res.dispatch_id: + return list(tg._graph.nodes) + else: + return list(sub_tg._graph.nodes) + + mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.graph.get_nodes", mock_get_nodes) + + node_attrs = [ + {"sub_dispatch_id": tg.get_node_value(i, "sub_dispatch_id")} for i in tg._graph.nodes + ] + sub_node_attrs = [ + {"sub_dispatch_id": sub_tg.get_node_value(i, "sub_dispatch_id")} + for i in sub_tg._graph.nodes + ] + + async def mock_get(dispatch_id, task_ids, keys): + if dispatch_id == res.dispatch_id: + return node_attrs + else: + return sub_node_attrs + mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.get_result_object", 
mock_get_result_object + "covalent_dispatcher._core.dispatcher.datasvc.electron.get_bulk", + mock_get, ) - mock_app_log = mocker.patch("covalent_dispatcher._core.dispatcher.app_log.debug") - res._initialize_nodes() - sub_res._initialize_nodes() - - tg = res.lattice.transport_graph - tg.set_node_value(2, "sub_dispatch_id", sub_dispatch_id) - sub_tg = sub_res.lattice.transport_graph - task_ids = list(tg._graph.nodes) + mock_app_log = mocker.patch("covalent_dispatcher._core.dispatcher.app_log.debug") + task_ids = [2] sub_task_ids = list(sub_tg._graph.nodes) await cancel_dispatch("pipeline_workflow", task_ids) @@ -554,34 +799,3 @@ def mock_get_result_object(dispatch_id): mock_data_cancel.assert_has_awaits(calls) mock_runner.cancel_tasks.assert_has_awaits(calls) assert mock_app_log.call_count == 2 - - -@pytest.mark.asyncio -async def test_submit_task(mocker): - """Test the submit task function.""" - - def transport_graph_get_value_side_effect(node_id, key): - if key == "name": - return "mock-name" - if key == "status": - return RESULT_STATUS.COMPLETED - - mock_result = MagicMock() - mock_result.lattice.transport_graph.get_node_value.side_effect = ( - transport_graph_get_value_side_effect - ) - - generate_node_result_mock = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.generate_node_result" - ) - update_node_result_mock = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.update_node_result" - ) - await _submit_task(mock_result, 0) - assert mock_result.lattice.transport_graph.get_node_value.mock_calls == [ - call(0, "name"), - call(0, "status"), - call(0, "output"), - ] - update_node_result_mock.assert_called_with(mock_result, generate_node_result_mock.return_value) - generate_node_result_mock.assert_called_once() diff --git a/tests/covalent_dispatcher_tests/_core/execution_test.py b/tests/covalent_dispatcher_tests/_core/execution_test.py index 40d78d55d6..474190a879 100644 --- a/tests/covalent_dispatcher_tests/_core/execution_test.py +++ 
b/tests/covalent_dispatcher_tests/_core/execution_test.py @@ -26,13 +26,17 @@ from typing import Dict, List import pytest +import pytest_asyncio +from sqlalchemy.pool import StaticPool import covalent as ct from covalent._results_manager import Result from covalent._workflow.lattice import Lattice from covalent_dispatcher._core.dispatcher import run_workflow from covalent_dispatcher._core.execution import _get_task_inputs -from covalent_dispatcher._db import update +from covalent_dispatcher._dal.result import Result as SRVResult +from covalent_dispatcher._dal.result import get_result_object +from covalent_dispatcher._db import models, update from covalent_dispatcher._db.datastore import DataStore TEST_RESULTS_DIR = "/tmp/results" @@ -45,9 +49,19 @@ def test_db(): return DataStore( db_URL="sqlite+pysqlite:///:memory:", initialize_db=True, + connect_args={"check_same_thread": False}, + poolclass=StaticPool, ) +@pytest_asyncio.fixture(scope="session") +def event_loop(request): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + def get_mock_result() -> Result: """Construct a mock result object corresponding to a lattice.""" @@ -72,7 +86,19 @@ def pipeline(x): return result_object -def test_get_task_inputs(): +def get_mock_srvresult(sdkres, test_db) -> SRVResult: + sdkres._initialize_nodes() + + with test_db.session() as session: + record = session.query(models.Lattice).where(models.Lattice.id == 1).first() + + update.persist(sdkres) + + return get_result_object(sdkres.dispatch_id, bare=False) + + +@pytest.mark.asyncio +async def test_get_task_inputs(mocker, test_db): """Test _get_task_inputs for both dicts and list parameter types""" @ct.electron @@ -119,14 +145,23 @@ def multivar_workflow(x, y): list_workflow.build_graph([1, 2, 3]) serialized_args = [ct.TransportableObject(i) for i in [1, 2, 3]] - tg = list_workflow.transport_graph + # Nodes 0=task, 
1=:electron_list:, 2=1, 3=2, 4=3 + sdkres = Result(lattice=list_workflow, dispatch_id="asdf") + mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + + result_object = get_mock_srvresult(sdkres, test_db) + tg = result_object.lattice.transport_graph tg.set_node_value(2, "output", ct.TransportableObject(1)) tg.set_node_value(3, "output", ct.TransportableObject(2)) tg.set_node_value(4, "output", ct.TransportableObject(3)) - result_object = Result(lattice=list_workflow, dispatch_id="asdf") - task_inputs = _get_task_inputs(1, tg.get_node_value(1, "name"), result_object) + mock_get_result = mocker.patch( + "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object + ) + task_inputs = await _get_task_inputs(1, tg.get_node_value(1, "name"), result_object) expected_inputs = {"args": serialized_args, "kwargs": {}} @@ -136,13 +171,18 @@ def multivar_workflow(x, y): dict_workflow.build_graph({"a": 1, "b": 2}) serialized_args = {"a": ct.TransportableObject(1), "b": ct.TransportableObject(2)} - tg = dict_workflow.transport_graph + # Nodes 0=task, 1=:electron_dict:, 2=1, 3=2 + sdkres = Result(lattice=dict_workflow, dispatch_id="asdf_dict_workflow") + result_object = get_mock_srvresult(sdkres, test_db) + tg = result_object.lattice.transport_graph tg.set_node_value(2, "output", ct.TransportableObject(1)) tg.set_node_value(3, "output", ct.TransportableObject(2)) - result_object = Result(lattice=dict_workflow, dispatch_id="asdf") - task_inputs = _get_task_inputs(1, tg.get_node_value(1, "name"), result_object) + mock_get_result = mocker.patch( + "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object + ) + task_inputs = await _get_task_inputs(1, tg.get_node_value(1, "name"), result_object) expected_inputs = {"args": [], "kwargs": serialized_args} assert 
task_inputs == expected_inputs @@ -150,279 +190,90 @@ def multivar_workflow(x, y): # Check arg order multivar_workflow.build_graph(1, 2) received_lattice = Lattice.deserialize_from_json(multivar_workflow.serialize_to_json()) - result_object = Result(lattice=received_lattice, dispatch_id="asdf") - tg = received_lattice.transport_graph + sdkres = Result(lattice=received_lattice, dispatch_id="asdf_multivar_workflow") + result_object = get_mock_srvresult(sdkres, test_db) + tg = result_object.lattice.transport_graph - assert list(tg._graph.nodes) == list(range(9)) + # Account for injected postprocess electron + assert list(tg._graph.nodes) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] tg.set_node_value(0, "output", ct.TransportableObject(1)) tg.set_node_value(2, "output", ct.TransportableObject(2)) - task_inputs = _get_task_inputs(4, tg.get_node_value(4, "name"), result_object) + mock_get_result = mocker.patch( + "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object + ) + + task_inputs = await _get_task_inputs(4, tg.get_node_value(4, "name"), result_object) input_args = [arg.get_deserialized() for arg in task_inputs["args"]] assert input_args == [1, 2] - task_inputs = _get_task_inputs(5, tg.get_node_value(5, "name"), result_object) - input_args = [arg.get_deserialized() for arg in task_inputs["args"]] - assert input_args == [2, 1] + mock_get_result = mocker.patch( + "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object + ) - task_inputs = _get_task_inputs(6, tg.get_node_value(6, "name"), result_object) + task_inputs = await _get_task_inputs(5, tg.get_node_value(5, "name"), result_object) input_args = [arg.get_deserialized() for arg in task_inputs["args"]] assert input_args == [2, 1] - task_inputs = _get_task_inputs(7, tg.get_node_value(7, "name"), result_object) - input_args = [arg.get_deserialized() for arg in task_inputs["args"]] - assert input_args == [1, 2] - - -@pytest.mark.asyncio -async def 
test_run_workflow_with_failing_nonleaf(mocker): - """Test running workflow with a failing intermediate node""" - - @ct.electron - def failing_task(x): - assert False - - @ct.lattice - def workflow(x): - res1 = failing_task(x) - res2 = failing_task(res1) - return res2 - - from covalent._workflow.lattice import Lattice - - workflow.build_graph(5) - - json_lattice = workflow.serialize_to_json() - dispatch_id = "asdf" - lattice = Lattice.deserialize_from_json(json_lattice) - result_object = Result(lattice) - result_object._dispatch_id = dispatch_id - result_object._root_dispatch_id = dispatch_id - result_object._initialize_nodes() - - # patch all methods that reference a DB - mocker.patch("covalent_dispatcher._db.upsert._lattice_data") - mocker.patch("covalent_dispatcher._db.upsert._electron_data") - mocker.patch("covalent_dispatcher._db.update.persist") - mocker.patch( - "covalent._results_manager.result.Result._get_node_name", return_value="failing_task" - ) - mocker.patch( - "covalent._results_manager.result.Result._get_node_error", return_value="AssertionError" - ) - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" - ) - mocker.patch( + mock_get_result = mocker.patch( "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object ) - status_queue = asyncio.Queue() - mocker.patch( - "covalent_dispatcher._core.data_manager.get_status_queue", return_value=status_queue - ) - mock_get_failed_nodes = mocker.patch( - "covalent._results_manager.result.Result._get_failed_nodes", - return_value=[(0, "failing_task")], - ) - - update.persist(result_object) - result_object = await run_workflow(result_object) - mock_unregister.assert_called_with(result_object.dispatch_id) - assert result_object.status == Result.FAILED - mock_get_failed_nodes.assert_called() - assert result_object._error == "The following tasks failed:\n0: failing_task" - - -@pytest.mark.asyncio -async def 
test_run_workflow_with_failing_leaf(mocker): - """Test running workflow with a failing leaf node""" - - @ct.electron - def failing_task(x): - assert False - return x - - @ct.lattice - def workflow(x): - res1 = failing_task(x) - return res1 - - from covalent._workflow.lattice import Lattice - - workflow.build_graph(5) - - json_lattice = workflow.serialize_to_json() - dispatch_id = "asdf" - lattice = Lattice.deserialize_from_json(json_lattice) - result_object = Result(lattice) - result_object._dispatch_id = dispatch_id - result_object._root_dispatch_id = dispatch_id - result_object._initialize_nodes() + task_inputs = await _get_task_inputs(6, tg.get_node_value(6, "name"), result_object) + input_args = [arg.get_deserialized() for arg in task_inputs["args"]] + assert input_args == [2, 1] - mocker.patch("covalent_dispatcher._db.upsert._lattice_data") - mocker.patch("covalent_dispatcher._db.upsert._electron_data") - mocker.patch("covalent_dispatcher._db.update.persist") - mocker.patch( - "covalent._results_manager.result.Result._get_node_name", return_value="failing_task" - ) - mocker.patch( - "covalent._results_manager.result.Result._get_node_error", return_value="AssertionError" - ) - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" - ) - mocker.patch( + mock_get_result = mocker.patch( "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object ) - status_queue = asyncio.Queue() - mocker.patch( - "covalent_dispatcher._core.data_manager.get_status_queue", return_value=status_queue - ) - mock_get_failed_nodes = mocker.patch( - "covalent._results_manager.result.Result._get_failed_nodes", - return_value=[(0, "failing_task")], - ) - - update.persist(result_object) - - result_object = await run_workflow(result_object) - mock_unregister.assert_called_with(result_object.dispatch_id) - assert result_object.status == Result.FAILED - assert result_object._error == "The following tasks failed:\n0: 
failing_task" + task_inputs = await _get_task_inputs(7, tg.get_node_value(7, "name"), result_object) + input_args = [arg.get_deserialized() for arg in task_inputs["args"]] + assert input_args == [1, 2] -@pytest.mark.asyncio async def test_run_workflow_does_not_deserialize(mocker): """Check that dispatcher does not deserialize user data when using out-of-process `workflow_executor`""" - @ct.electron(executor="local") + from dask.distributed import LocalCluster + + from covalent._workflow.lattice import Lattice + from covalent.executor import DaskExecutor + + lc = LocalCluster() + dask_exec = DaskExecutor(lc.scheduler_address) + + @ct.electron(executor=dask_exec) def task(x): return x - @ct.lattice(executor="local", workflow_executor="local") + @ct.lattice(executor=dask_exec, workflow_executor=dask_exec) def workflow(x): # Exercise both sublatticing and postprocessing - sublattice_task = ct.lattice(task, workflow_executor="local") - res1 = ct.electron(sublattice_task(x), executor="local") + sublattice_task = ct.lattice(task, workflow_executor=dask_exec) + res1 = ct.electron(sublattice_task(x), executor=dask_exec) return res1 - dispatch_id = "asdf" workflow.build_graph(5) json_lattice = workflow.serialize_to_json() + dispatch_id = "asdf" lattice = Lattice.deserialize_from_json(json_lattice) - result_object = Result(lattice, dispatch_id=dispatch_id) + result_object = Result(lattice, lattice.metadata["results_dir"]) + result_object._dispatch_id = dispatch_id result_object._initialize_nodes() - mocker.patch("covalent_dispatcher._db.upsert._lattice_data") - mocker.patch("covalent_dispatcher._db.upsert._electron_data") - mocker.patch("covalent_dispatcher._db.update.persist") - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" - ) - mock_run_abstract_task = mocker.patch("covalent_dispatcher._core.runner._run_abstract_task") + mocker.patch("covalent_dispatcher._db.datastore.DataStore.factory", return_value=test_db) 
mocker.patch( "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object ) - - status_queue = asyncio.Queue() - mocker.patch( - "covalent_dispatcher._core.data_manager.get_status_queue", return_value=status_queue - ) - update.persist(result_object) mock_to_deserialize = mocker.patch("covalent.TransportableObject.get_deserialized") - result_object = await run_workflow(result_object) - mock_unregister.assert_called_with(result_object.dispatch_id) + status = await run_workflow(result_object.dispatch_id) mock_to_deserialize.assert_not_called() - assert result_object.status == Result.RUNNING - assert mock_run_abstract_task.call_count == 1 - - -@pytest.mark.asyncio -async def test_run_workflow_with_client_side_postprocess(test_db, mocker): - """Check that run_workflow handles "client" workflow_executor for - postprocessing""" - - dispatch_id = "asdf" - result_object = get_mock_result() - result_object.lattice.set_metadata("workflow_executor", "client") - result_object._dispatch_id = dispatch_id - result_object._initialize_nodes() - - mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) - mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" - ) - mocker.patch( - "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object - ) - - status_queue = asyncio.Queue() - mocker.patch( - "covalent_dispatcher._core.data_manager.get_status_queue", return_value=status_queue - ) - mocker.patch("covalent_dispatcher._core.runner._gather_deps") - mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") - mock_run_abstract_task = mocker.patch("covalent_dispatcher._core.runner._run_abstract_task") - - update.persist(result_object) - - result_object = await run_workflow(result_object) - mock_unregister.assert_called_with(result_object.dispatch_id) - assert 
result_object.status == Result.RUNNING - assert mock_run_abstract_task.call_count == 1 - - -@pytest.mark.asyncio -async def test_run_workflow_with_failed_postprocess(test_db, mocker): - """Check that run_workflow handles postprocessing failures""" - - dispatch_id = "asdf" - result_object = get_mock_result() - result_object._dispatch_id = dispatch_id - result_object._initialize_nodes() - - mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) - mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" - ) - mocker.patch( - "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object - ) - mocker.patch("covalent_dispatcher._core.runner._run_abstract_task") - - update.persist(result_object) - - status_queue = asyncio.Queue() - mocker.patch( - "covalent_dispatcher._core.data_manager.get_status_queue", return_value=status_queue - ) - mock_run_abstract_task = mocker.patch("covalent_dispatcher._core.runner._run_abstract_task") - - def failing_workflow(x): - assert False - - result_object.lattice.set_metadata("workflow_executor", "bogus") - result_object = await run_workflow(result_object) - mock_unregister.assert_called_with(result_object.dispatch_id) - - assert result_object.status == Result.RUNNING - - result_object.lattice.workflow_function = ct.TransportableObject(failing_workflow) - result_object.lattice.set_metadata("workflow_executor", "local") - - result_object = await run_workflow(result_object) - mock_unregister.assert_called_with(result_object.dispatch_id) - - assert result_object.status == Result.RUNNING - assert mock_run_abstract_task.call_count == 2 + assert status == Result.COMPLETED diff --git a/tests/covalent_dispatcher_tests/_core/runner_db_integration_test.py b/tests/covalent_dispatcher_tests/_core/runner_db_integration_test.py new file mode 100644 index 0000000000..15b8acb282 --- /dev/null 
+++ b/tests/covalent_dispatcher_tests/_core/runner_db_integration_test.py @@ -0,0 +1,127 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Tests for the core functionality of the runner. +""" + + +import pytest +from sqlalchemy.pool import StaticPool + +import covalent as ct +from covalent._results_manager import Result +from covalent._workflow.lattice import Lattice +from covalent_dispatcher._core.runner import _gather_deps +from covalent_dispatcher._dal.result import Result as SRVResult +from covalent_dispatcher._dal.result import get_result_object +from covalent_dispatcher._db import update +from covalent_dispatcher._db.datastore import DataStore + +TEST_RESULTS_DIR = "/tmp/results" + + +@pytest.fixture +def test_db(): + """Instantiate and return an in-memory database.""" + + return DataStore( + db_URL="sqlite+pysqlite:///:memory:", + initialize_db=True, + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + + +def get_mock_result() -> Result: + """Construct a mock result object corresponding to a lattice.""" + + import sys + + @ct.electron(executor="local") + def task(x): + print(f"stdout: {x}") + print("Error!", file=sys.stderr) + return x + + 
@ct.lattice(deps_bash=ct.DepsBash(["ls"])) + def pipeline(x): + res1 = task(x) + res2 = task(res1) + return res2 + + pipeline.build_graph(x="absolute") + received_workflow = Lattice.deserialize_from_json(pipeline.serialize_to_json()) + result_object = Result(received_workflow, "pipeline_workflow") + + return result_object + + +def get_mock_srvresult(sdkres, test_db) -> SRVResult: + sdkres._initialize_nodes() + + update.persist(sdkres) + + return get_result_object(sdkres.dispatch_id) + + +@pytest.mark.asyncio +async def test_gather_deps(mocker, test_db): + """Test internal _gather_deps for assembling deps into call_before and + call_after""" + + def square(x): + return x * x + + @ct.electron( + deps_bash=ct.DepsBash("ls -l"), + deps_pip=ct.DepsPip(["pandas"]), + call_before=[ct.DepsCall(square, [5])], + call_after=[ct.DepsCall(square, [3])], + ) + def task(x): + return x + + @ct.lattice + def workflow(x): + return task(x) + + mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + workflow.build_graph(5) + + received_workflow = Lattice.deserialize_from_json(workflow.serialize_to_json()) + sdkres = Result(received_workflow, "test_gather_deps") + result_object = get_mock_srvresult(sdkres, test_db) + + async def get_electron_attrs(dispatch_id, node_id, keys): + return { + key: result_object.lattice.transport_graph.get_node_value(node_id, key) for key in keys + } + + mocker.patch( + "covalent_dispatcher._core.data_manager.electron.get", + get_electron_attrs, + ) + + before, after = await _gather_deps(result_object.dispatch_id, 0) + assert len(before) == 3 + assert len(after) == 1 diff --git a/tests/covalent_dispatcher_tests/_core/runner_modules/cancel_test.py b/tests/covalent_dispatcher_tests/_core/runner_modules/cancel_test.py new file mode 100644 index 0000000000..8c1b24a55e --- /dev/null +++ 
b/tests/covalent_dispatcher_tests/_core/runner_modules/cancel_test.py @@ -0,0 +1,106 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Tests for the cancellation module +""" + +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from covalent._shared_files.util_classes import RESULT_STATUS +from covalent_dispatcher._core.runner_modules import cancel + + +@pytest.mark.asyncio +async def test_cancel_tasks(mocker): + """Test the public `cancel_tasks` function""" + dispatch_id = "test_cancel_tasks" + node_id = 0 + mock_node_metadata = [{"executor": "dask", "executor_data": {}}] + mock_job_metadata = [{"job_handle": 42}] + mock_cancel_priv = mocker.patch("covalent_dispatcher._core.runner_modules.cancel._cancel_task") + + mocker.patch( + "covalent_dispatcher._core.runner_modules.cancel._get_metadata_for_nodes", + return_value=mock_node_metadata, + ) + mocker.patch( + "covalent_dispatcher._core.data_modules.job_manager.get_jobs_metadata", + return_value=mock_job_metadata, + ) + + await cancel.cancel_tasks(dispatch_id, [node_id]) + + assert mock_cancel_priv.call_count == 1 + + +@pytest.mark.asyncio +async def test_cancel_task_priv(mocker): + """Test 
the internal `_cancel_task` function""" + mock_executor = MagicMock() + mock_executor._cancel = AsyncMock(return_value=True) + mock_set_status = mocker.patch( + "covalent_dispatcher._core.data_modules.job_manager.set_job_status" + ) + + mocker.patch( + "covalent_dispatcher._core.runner_modules.cancel.get_executor", return_value=mock_executor + ) + + dispatch_id = "test_cancel_task_priv" + job_handle = json.dumps(42) + task_id = 0 + + await cancel._cancel_task(dispatch_id, task_id, ["dask", {}], job_handle) + + task_meta = {"dispatch_id": dispatch_id, "node_id": task_id} + + mock_executor._cancel.assert_awaited_with(task_meta, 42) + + mock_set_status.assert_awaited_with(dispatch_id, task_id, str(RESULT_STATUS.CANCELLED)) + + +@pytest.mark.asyncio +async def test_cancel_task_priv_exception(mocker): + """Test the internal `_cancel_task` function""" + mock_executor = MagicMock() + mock_executor._cancel = AsyncMock(side_effect=RuntimeError()) + mock_set_status = mocker.patch( + "covalent_dispatcher._core.data_modules.job_manager.set_job_status" + ) + + mocker.patch( + "covalent_dispatcher._core.runner_modules.cancel.get_executor", return_value=mock_executor + ) + + dispatch_id = "test_cancel_task_priv" + job_handle = json.dumps(42) + task_id = 0 + + await cancel._cancel_task(dispatch_id, task_id, ["dask", {}], job_handle) + + task_meta = {"dispatch_id": dispatch_id, "node_id": task_id} + + mock_executor._cancel.assert_awaited_with(task_meta, 42) + + mock_set_status.assert_not_awaited() diff --git a/tests/covalent_dispatcher_tests/_core/runner_modules/jobs_test.py b/tests/covalent_dispatcher_tests/_core/runner_modules/jobs_test.py new file mode 100644 index 0000000000..a20482a262 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/runner_modules/jobs_test.py @@ -0,0 +1,103 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). 
+# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Tests for the executor proxy handlers to get/set job info +""" + +from unittest.mock import MagicMock + +import pytest + +from covalent._shared_files.util_classes import RESULT_STATUS +from covalent_dispatcher._core.runner_modules import jobs + + +@pytest.mark.asyncio +async def test_get_cancel_requested(mocker): + dispatch_id = "test_get_cancel_requested" + mock_job_records = [{"cancel_requested": True}] + + mocker.patch( + "covalent_dispatcher._core.data_modules.job_manager.get_jobs_metadata", + return_value=mock_job_records, + ) + + assert await jobs.get_cancel_requested(dispatch_id, 0) is True + + +@pytest.mark.asyncio +async def test_get_version_info(mocker): + dispatch_id = "test_get_version_info" + mock_ver_info = {"python_version": "3.10", "covalent_version": "0.220"} + + mocker.patch( + "covalent_dispatcher._core.runner_modules.jobs.datamgr.lattice.get", + return_value=mock_ver_info, + ) + assert await jobs.get_version_info(dispatch_id, 0) == {"python": "3.10", "covalent": "0.220"} + + +@pytest.mark.asyncio +async def test_get_job_status(mocker): + dispatch_id = "test_job_status" + mock_job_records = [{"status": str(RESULT_STATUS.RUNNING)}] + + mocker.patch( + "covalent_dispatcher._core.data_modules.job_manager.get_jobs_metadata", + 
return_value=mock_job_records, + ) + + assert await jobs.get_job_status(dispatch_id, 0) == RESULT_STATUS.RUNNING + + +@pytest.mark.asyncio +async def test_put_job_handle(mocker): + dispatch_id = "test_put_job_handle" + task_id = 0 + job_handle = "jobArn" + + mock_set = mocker.patch("covalent_dispatcher._core.data_modules.job_manager.set_job_handle") + + assert await jobs.put_job_handle(dispatch_id, task_id, job_handle) is True + mock_set.assert_awaited_with(dispatch_id, task_id, job_handle) + + +@pytest.mark.asyncio +async def test_put_job_status(mocker): + dispatch_id = "test_put_job_handle" + task_id = 0 + status = RESULT_STATUS.RUNNING + + mock_exec_attrs = {"executor": "dask", "executor_data": {}} + executor = MagicMock() + executor.validate_status = MagicMock(return_value=True) + + mocker.patch( + "covalent_dispatcher._core.data_modules.electron.get", return_value=mock_exec_attrs + ) + + mocker.patch( + "covalent_dispatcher._core.runner_modules.jobs.get_executor", return_value=executor + ) + mock_set = mocker.patch("covalent_dispatcher._core.data_modules.job_manager.set_job_status") + + assert await jobs.put_job_status(dispatch_id, task_id, status) is True + mock_set.assert_awaited_with(dispatch_id, task_id, str(status)) diff --git a/tests/covalent_dispatcher_tests/_core/runner_test.py b/tests/covalent_dispatcher_tests/_core/runner_test.py index 12103824eb..17f19e0b54 100644 --- a/tests/covalent_dispatcher_tests/_core/runner_test.py +++ b/tests/covalent_dispatcher_tests/_core/runner_test.py @@ -23,25 +23,14 @@ """ -import json from unittest.mock import AsyncMock, MagicMock import pytest -from mock import call import covalent as ct from covalent._results_manager import Result from covalent._workflow.lattice import Lattice -from covalent_dispatcher._core.runner import ( - _cancel_task, - _gather_deps, - _get_metadata_for_nodes, - _run_abstract_task, - _run_task, - cancel_tasks, - get_executor, -) -from covalent_dispatcher._core.runner_modules.executor_proxy 
import _get_cancel_requested +from covalent_dispatcher._core.runner import _run_abstract_task, _run_task from covalent_dispatcher._db.datastore import DataStore TEST_RESULTS_DIR = "/tmp/results" @@ -77,160 +66,91 @@ def pipeline(x): pipeline.build_graph(x="absolute") received_workflow = Lattice.deserialize_from_json(pipeline.serialize_to_json()) result_object = Result(received_workflow, "pipeline_workflow") - result_object._initialize_nodes() return result_object -def test_get_executor(mocker): - """Test that get_executor returns the correct executor""" - - executor_manager_mock = mocker.patch("covalent_dispatcher._core.runner._executor_manager") - executor = get_executor(["local", {"mock-key": "mock-value"}], "mock-loop", "mock-pool") - assert executor_manager_mock.get_executor.mock_calls == [ - call("local"), - call().from_dict({"mock-key": "mock-value"}), - call()._init_runtime(loop="mock-loop", cancel_pool="mock-pool"), - ] - assert executor == executor_manager_mock.get_executor() - - -def test_gather_deps(): - """Test internal _gather_deps for assembling deps into call_before and - call_after""" - - def square(x): - return x * x - - @ct.electron( - deps_bash=ct.DepsBash("ls -l"), - deps_pip=ct.DepsPip(["pandas"]), - call_before=[ct.DepsCall(square, [5])], - call_after=[ct.DepsCall(square, [3])], - ) - def task(x): - return x - - @ct.lattice - def workflow(x): - return task(x) - - workflow.build_graph(5) - - received_workflow = Lattice.deserialize_from_json(workflow.serialize_to_json()) - result_object = Result(received_workflow, "asdf") - - before, after = _gather_deps(result_object, 0) - assert len(before) == 3 - assert len(after) == 1 - - @pytest.mark.asyncio async def test_run_abstract_task_exception_handling(mocker): """Test that exceptions from resolving abstract inputs are handled""" - result_object = get_mock_result() + dispatch_id = "mock_dispatch" + inputs = {"args": [], "kwargs": {}} - mock_get_result = mocker.patch( - 
"covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object - ) - mock_get_task_input_values = mocker.patch( - "covalent_dispatcher._core.runner._get_task_input_values", - side_effect=RuntimeError(), + + mocker.patch("covalent_dispatcher._core.runner._gather_deps", side_effect=RuntimeError()) + mocker.patch( + "covalent_dispatcher._core.data_manager.electron.get", + return_value={"function": "function"}, ) node_result = await _run_abstract_task( - dispatch_id=result_object.dispatch_id, + dispatch_id=dispatch_id, node_id=0, node_name="test_node", abstract_inputs=inputs, - executor=["local", {}], + selected_executor=["local", {}], ) assert node_result["status"] == Result.FAILED @pytest.mark.asyncio -async def test_run_abstract_task_get_cancel_requested(mocker): - """Test that get_cancel_requested is properly handled""" - mock_result = MagicMock() - - result_object = get_mock_result() +async def test_run_task_runtime_exception_handling(mocker): inputs = {"args": [], "kwargs": {}} - mock_app_log = mocker.patch("covalent_dispatcher._core.runner.app_log.debug") - mock_get_result = mocker.patch( - "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object - ) - mock_get_task_input_values = mocker.patch( - "covalent_dispatcher._core.runner._get_task_input_values", - side_effect=RuntimeError(), - ) - mock_get_cancel_requested = mocker.patch( - "covalent_dispatcher._core.runner_modules.executor_proxy._get_cancel_requested", - return_value=AsyncMock(return_value=True), - ) - mock_generate_node_result = mocker.patch( - "covalent_dispatcher._core.runner.datasvc.generate_node_result", - return_value=mock_result, + mock_executor = MagicMock() + mock_executor._execute = AsyncMock(return_value=("", "", "error", Result.FAILED)) + mock_get_executor = mocker.patch( + "covalent_dispatcher._core.runner.get_executor", + return_value=mock_executor, ) - node_result = await _run_abstract_task( - 
dispatch_id=result_object.dispatch_id, - node_id=0, - node_name="test_node", - abstract_inputs=inputs, - executor=["local", {}], - ) - - mock_get_result.assert_called_with(result_object.dispatch_id) - mock_get_cancel_requested.assert_awaited_once_with(result_object.dispatch_id, 0) - mock_generate_node_result.assert_called() - mock_app_log.assert_called_with(f"Don't run cancelled task {result_object.dispatch_id}:0") - assert node_result == mock_result - - -@pytest.mark.asyncio -async def test_run_task_executor_exception_handling(mocker): - """Test that exceptions from initializing executors are caught""" - - result_object = get_mock_result() - inputs = {"args": [], "kwargs": {}} - mock_get_executor = mocker.patch( - "covalent_dispatcher._core.runner._executor_manager.get_executor", - side_effect=Exception(), + dispatch_id = "mock_dispatch" + mocker.patch( + "covalent_dispatcher._core.data_manager.dispatch.get", + return_value={"results_dir": "/tmp/result"}, ) node_result = await _run_task( - result_object=result_object, + dispatch_id=dispatch_id, node_id=1, inputs=inputs, serialized_callable=None, - executor=["nonexistent", {}], + selected_executor=["local", {}], call_before=[], call_after=[], - node_name="test_node", + node_name="task", ) + mock_executor._execute.assert_awaited_once() + assert node_result["status"] == Result.FAILED + assert node_result["stderr"] == "error" @pytest.mark.asyncio -async def test_run_task_runtime_exception_handling(mocker): - result_object = get_mock_result() +async def test_run_task_exception_handling(mocker): + dispatch_id = "mock_dispatch" inputs = {"args": [], "kwargs": {}} mock_executor = MagicMock() - mock_executor._execute = AsyncMock(return_value=("", "", "error", True)) + mock_executor._execute = AsyncMock(side_effect=RuntimeError("error")) + mock_get_executor = mocker.patch( - "covalent_dispatcher._core.runner._executor_manager.get_executor", + "covalent_dispatcher._core.runner.get_executor", return_value=mock_executor, ) + 
mocker.patch( + "covalent_dispatcher._core.data_manager.dispatch.get", + return_value={"results_dir": "/tmp/result"}, + ) + mocker.patch("traceback.TracebackException.from_exception", return_value="error") node_result = await _run_task( - result_object=result_object, + dispatch_id=dispatch_id, node_id=1, inputs=inputs, serialized_callable=None, - executor=["local", {}], + selected_executor=["local", {}], call_before=[], call_after=[], node_name="task", @@ -238,119 +158,35 @@ async def test_run_task_runtime_exception_handling(mocker): mock_executor._execute.assert_awaited_once() - assert node_result["stderr"] == "error" - - -@pytest.mark.asyncio -async def test__cancel_task(mocker): - """ - Test module private _cancel_task method - """ - mock_executor = AsyncMock() - mock_executor.from_dict = MagicMock() - mock_executor._init_runtime = MagicMock() - mock_executor._cancel = AsyncMock() - - mock_app_log = mocker.patch("covalent_dispatcher._core.runner.app_log.debug") - get_executor_mock = mocker.patch( - "covalent_dispatcher._core.runner.get_executor", return_value=mock_executor - ) - mock_set_cancel_result = mocker.patch("covalent_dispatcher._core.runner.set_cancel_result") - - dispatch_id = "abcd" - task_id = 0 - executor = "mock_executor" - executor_data = {} - job_handle = "42" - - task_metadata = {"dispatch_id": dispatch_id, "node_id": task_id} - - await _cancel_task(dispatch_id, task_id, executor, executor_data, job_handle) - - assert mock_app_log.call_count == 2 - get_executor_mock.assert_called_once() - mock_executor._cancel.assert_called_with(task_metadata, json.loads(job_handle)) - mock_set_cancel_result.assert_called() + assert node_result["status"] == Result.FAILED + assert node_result["error"] == "error" @pytest.mark.asyncio -async def test__cancel_task_exception(mocker): - """ - Test exception raised in module private _cancel task exception - """ - mock_executor = AsyncMock() - mock_executor.from_dict = MagicMock() - mock_executor._init_runtime = 
MagicMock() - mock_executor._cancel = AsyncMock(side_effect=Exception("cancel")) - - mock_app_log = mocker.patch("covalent_dispatcher._core.runner.app_log.debug") - get_executor_mock = mocker.patch( - "covalent_dispatcher._core.runner.get_executor", return_value=mock_executor - ) - mocker.patch("covalent_dispatcher._core.runner.set_cancel_result") - - dispatch_id = "abcd" - task_id = 0 - executor = "mock_executor" - executor_data = {} - job_handle = "42" - - task_metadata = {"dispatch_id": dispatch_id, "node_id": task_id} - - cancel_result = await _cancel_task(dispatch_id, task_id, executor, executor_data, job_handle) - assert mock_app_log.call_count == 3 - get_executor_mock.assert_called_once() - mock_executor._cancel.assert_called_with(task_metadata, json.loads(job_handle)) - assert cancel_result is False - +async def test_run_task_executor_exception_handling(mocker): + """Test that exceptions from initializing executors are caught""" -@pytest.mark.asyncio -async def test_cancel_tasks(mocker): - """ - Test cancelling multiple tasks - """ - mock_get_jobs_metadata = mocker.patch( - "covalent_dispatcher._core.runner.get_jobs_metadata", return_value=AsyncMock() - ) - mock_get_metadata_for_nodes = mocker.patch( - "covalent_dispatcher._core.runner._get_metadata_for_nodes", return_value=MagicMock() + dispatch_id = "mock_dispatch" + inputs = {"args": [], "kwargs": {}} + mock_get_executor = mocker.patch( + "covalent_dispatcher._core.runner.get_executor", + side_effect=Exception(), ) - dispatch_id = "abcd" - task_ids = [0, 1] - - await cancel_tasks(dispatch_id, task_ids) - - mock_get_jobs_metadata.assert_awaited_with(dispatch_id, task_ids) - mock_get_metadata_for_nodes.assert_called_with(dispatch_id, task_ids) - - -def test__get_metadata_for_nodes(mocker): - """ - Test module private method for getting nodes metadata - """ - dispatch_id = "abcd" - node_ids = [0, 1] - - mock_get_result_object = mocker.patch( - "covalent_dispatcher._core.runner.datasvc.get_result_object", 
return_value=MagicMock() + mocker.patch( + "covalent_dispatcher._core.data_manager.dispatch.get", + return_value={"results_dir": "/tmp/result"}, ) - _get_metadata_for_nodes(dispatch_id, node_ids) - mock_get_result_object.assert_called_with(dispatch_id) - -@pytest.mark.asyncio -async def test__get_cancel_requested(mocker): - """ - Test module private method for querying if a task was requested to be cancelled - """ - dispatch_id = "abcd" - task_id = 0 - mock_get_jobs_metadata = mocker.patch( - "covalent_dispatcher._core.runner_modules.executor_proxy.job_manager.get_jobs_metadata", - return_value=AsyncMock(), + node_result = await _run_task( + dispatch_id=dispatch_id, + node_id=1, + inputs=inputs, + serialized_callable=None, + selected_executor=["nonexistent", {}], + call_before=[], + call_after=[], + node_name="test_node", ) - await _get_cancel_requested(dispatch_id, task_id) - - mock_get_jobs_metadata.assert_awaited_with(dispatch_id, [task_id]) + assert node_result["status"] == Result.FAILED diff --git a/tests/covalent_dispatcher_tests/_db/load_test.py b/tests/covalent_dispatcher_tests/_db/load_test.py deleted file mode 100644 index 853c7b558b..0000000000 --- a/tests/covalent_dispatcher_tests/_db/load_test.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2023 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the GNU Affero General Public License 3.0 (the "License"). -# A copy of the License may be obtained with this software package or at -# -# https://www.gnu.org/licenses/agpl-3.0.en.html -# -# Use of this file is prohibited except in compliance with the License. Any -# modifications or derivative works of this file must retain this copyright -# notice, and modified files must contain a notice indicating that they have -# been altered from the originals. -# -# Covalent is distributed in the hope that it will be useful, but WITHOUT -# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or -# FITNESS FOR A PARTICULAR PURPOSE. 
See the License for more details. -# -# Relief from the License may be granted by purchasing a commercial license. - - -"""Unit tests for result loading (from database) module.""" - -from unittest.mock import call - -import pytest -from sqlalchemy import select - -import covalent as ct -from covalent._results_manager.result import Result as SDKResult -from covalent._shared_files.util_classes import Status -from covalent._workflow.lattice import Lattice as SDKLattice -from covalent_dispatcher._db import models, update -from covalent_dispatcher._db.datastore import DataStore -from covalent_dispatcher._db.load import ( - _result_from, - electron_record, - get_result_object_from_storage, - sublattice_dispatch_id, -) - - -@pytest.fixture -def test_db(): - """Instantiate and return an in-memory database.""" - - return DataStore( - db_URL="sqlite+pysqlite:///:memory:", - initialize_db=True, - ) - - -def get_mock_result(dispatch_id) -> SDKResult: - """Construct a mock result object corresponding to a lattice.""" - - @ct.electron - def task(x): - return x - - @ct.lattice - def workflow(x): - res1 = task(x) - return res1 - - workflow.build_graph(x=1) - received_workflow = SDKLattice.deserialize_from_json(workflow.serialize_to_json()) - result_object = SDKResult(received_workflow, dispatch_id) - - return result_object - - -def test_result_from(mocker, test_db): - """Test the result from function in the load module.""" - - dispatch_id = "test_result_from" - res = get_mock_result(dispatch_id) - res._initialize_nodes() - - mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) - mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) - mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) - - update.persist(res) - - with test_db.session() as session: - mock_lattice_record = session.scalars( - select(models.Lattice).where(models.Lattice.dispatch_id == dispatch_id) - ).first() - - result_object = _result_from(mock_lattice_record) - 
- assert result_object._root_dispatch_id == mock_lattice_record.root_dispatch_id - assert result_object._status == Status(mock_lattice_record.status) - assert result_object._error == "" - assert result_object.inputs == res.inputs - assert result_object._start_time == mock_lattice_record.started_at - assert result_object._end_time == mock_lattice_record.completed_at - assert result_object.result == res.result - - -def test_get_result_object_from_storage(mocker): - """Test the get_result_object_from_storage method.""" - from covalent_dispatcher._db.load import Lattice - - result_from_mock = mocker.patch("covalent_dispatcher._db.load._result_from") - - workflow_db_mock = mocker.patch("covalent_dispatcher._db.load.workflow_db") - session_mock = workflow_db_mock.session.return_value.__enter__.return_value - - result_object = get_result_object_from_storage("mock-dispatch-id") - - assert call(Lattice) in session_mock.query.mock_calls - session_mock.query().where().first.assert_called_once() - - assert result_object == result_from_mock.return_value - result_from_mock.assert_called_once_with(session_mock.query().where().first.return_value) - - -def test_get_result_object_from_storage_exception(mocker): - """Test the get_result_object_from_storage method.""" - from covalent_dispatcher._db.load import Lattice - - result_from_mock = mocker.patch("covalent_dispatcher._db.load._result_from") - - workflow_db_mock = mocker.patch("covalent_dispatcher._db.load.workflow_db") - session_mock = workflow_db_mock.session.return_value.__enter__.return_value - session_mock.query().where().first.return_value = None - - with pytest.raises(RuntimeError): - get_result_object_from_storage("mock-dispatch-id") - - assert call(Lattice) in session_mock.query.mock_calls - session_mock.query().where().first.assert_called_once() - - result_from_mock.assert_not_called() - - -def test_electron_record(mocker): - """Test the electron_record method.""" - - workflow_db_mock = 
mocker.patch("covalent_dispatcher._db.load.workflow_db") - session_mock = workflow_db_mock.session.return_value.__enter__.return_value - - electron_record("mock-dispatch-id", "mock-node-id") - session_mock.query().filter().filter().filter().first.assert_called_once() - - -def test_sublattice_dispatch_id(mocker): - """Test the sublattice_dispatch_id method.""" - - class MockObject: - dispatch_id = "mock-dispatch-id" - - workflow_db_mock = mocker.patch("covalent_dispatcher._db.load.workflow_db") - session_mock = workflow_db_mock.session.return_value.__enter__.return_value - - session_mock.query().filter().first.return_value = MockObject() - res = sublattice_dispatch_id("mock-electron-id") - assert res == "mock-dispatch-id" - - session_mock.query().filter().first.return_value = [] - res = sublattice_dispatch_id("mock-electron-id") - assert res is None diff --git a/tests/covalent_dispatcher_tests/_service/app_test.py b/tests/covalent_dispatcher_tests/_service/app_test.py index 48db5069f5..5c563a5bfc 100644 --- a/tests/covalent_dispatcher_tests/_service/app_test.py +++ b/tests/covalent_dispatcher_tests/_service/app_test.py @@ -21,17 +21,21 @@ """Unit tests for the FastAPI app.""" import json -import os +import tempfile from contextlib import contextmanager from typing import Generator +from unittest.mock import MagicMock import pytest from fastapi.testclient import TestClient from sqlalchemy import Column, Integer, String, create_engine from sqlalchemy.orm import Session, declarative_base, sessionmaker -from covalent._results_manager.result import Result +import covalent as ct +from covalent._dispatcher_plugins.local import LocalDispatcher +from covalent._shared_files.util_classes import RESULT_STATUS from covalent_dispatcher._db.dispatchdb import DispatchDB +from covalent_dispatcher._service.app import _try_get_result_object from covalent_ui.app import fastapi_app as fast_app DISPATCH_ID = "f34671d1-48f2-41ce-89d9-9a8cb5c60e5d" @@ -78,6 +82,25 @@ def test_db(): return 
MockDataStore(db_URL="sqlite+pysqlite:///:memory:") +@pytest.fixture +def mock_manifest(): + """Create a mock workflow manifest""" + + @ct.electron + def task(x): + return x**2 + + @ct.lattice + def workflow(x): + return task(x) + + workflow.build_graph(3) + + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + return manifest + + @pytest.fixture def test_db_file(): """Instantiate and return a database.""" @@ -85,61 +108,34 @@ def test_db_file(): @pytest.mark.asyncio -@pytest.mark.parametrize("disable_run", [True, False]) -async def test_submit(mocker, client, disable_run): +async def test_submit(mocker, client): """Test the submit endpoint.""" mock_data = json.dumps({}).encode("utf-8") run_dispatcher_mock = mocker.patch( - "covalent_dispatcher.run_dispatcher", return_value=DISPATCH_ID + "covalent_dispatcher.entry_point.make_dispatch", return_value=DISPATCH_ID ) - response = client.post("/api/submit", data=mock_data, params={"disable_run": disable_run}) + response = client.post("/api/v1/dispatch/submit", data=mock_data) assert response.json() == DISPATCH_ID - run_dispatcher_mock.assert_called_once_with(mock_data, disable_run) + run_dispatcher_mock.assert_called_once_with(mock_data) @pytest.mark.asyncio async def test_submit_exception(mocker, client): """Test the submit endpoint.""" mock_data = json.dumps({}).encode("utf-8") - mocker.patch("covalent_dispatcher.run_dispatcher", side_effect=Exception("mock")) - response = client.post("/api/submit", data=mock_data) + mocker.patch("covalent_dispatcher.entry_point.make_dispatch", side_effect=Exception("mock")) + response = client.post("/api/v1/dispatch/submit", data=mock_data) assert response.status_code == 400 assert response.json()["detail"] == "Failed to submit workflow: mock" -@pytest.mark.asyncio -@pytest.mark.parametrize("is_pending", [True, False]) -async def test_redispatch(mocker, client, is_pending): - """Test the redispatch endpoint.""" - 
json_lattice = None - electron_updates = None - reuse_previous_results = False - mock_data = json.dumps( - { - "dispatch_id": DISPATCH_ID, - "json_lattice": json_lattice, - "electron_updates": electron_updates, - "reuse_previous_results": reuse_previous_results, - } - ).encode("utf-8") - run_redispatch_mock = mocker.patch( - "covalent_dispatcher.run_redispatch", return_value=DISPATCH_ID - ) - - response = client.post("/api/redispatch", data=mock_data, params={"is_pending": is_pending}) - assert response.json() == DISPATCH_ID - run_redispatch_mock.assert_called_once_with( - DISPATCH_ID, json_lattice, electron_updates, reuse_previous_results, is_pending - ) - - def test_cancel_dispatch(mocker, app, client): """ Test cancelling dispatch """ - mocker.patch("covalent_dispatcher.cancel_running_dispatch") + mocker.patch("covalent_dispatcher.entry_point.cancel_running_dispatch") response = client.post( - "/api/cancel", data=json.dumps({"dispatch_id": DISPATCH_ID, "task_ids": []}) + "/api/v1/dispatch/cancel", data=json.dumps({"dispatch_id": DISPATCH_ID, "task_ids": []}) ) assert response.json() == f"Dispatch {DISPATCH_ID} cancelled." @@ -148,32 +144,22 @@ def test_cancel_tasks(mocker, app, client): """ Test cancelling tasks within a lattice after dispatch """ - mocker.patch("covalent_dispatcher.cancel_running_dispatch") + mocker.patch("covalent_dispatcher.entry_point.cancel_running_dispatch") response = client.post( - "/api/cancel", data=json.dumps({"dispatch_id": DISPATCH_ID, "task_ids": [0, 1]}) + "/api/v1/dispatch/cancel", + data=json.dumps({"dispatch_id": DISPATCH_ID, "task_ids": [0, 1]}), ) assert response.json() == f"Cancelled tasks [0, 1] in dispatch {DISPATCH_ID}." 
-@pytest.mark.asyncio -async def test_redispatch_exception(mocker, client): - """Test the redispatch endpoint.""" - response = client.post("/api/redispatch", data="bad data") - assert response.status_code == 400 - assert ( - response.json()["detail"] - == "Failed to redispatch workflow: Expecting value: line 1 column 1 (char 0)" - ) - - @pytest.mark.asyncio async def test_cancel(mocker, client): """Test the cancel endpoint.""" cancel_running_dispatch_mock = mocker.patch( - "covalent_dispatcher.cancel_running_dispatch", return_value=DISPATCH_ID + "covalent_dispatcher.entry_point.cancel_running_dispatch", return_value=DISPATCH_ID ) response = client.post( - "/api/cancel", data=json.dumps({"dispatch_id": DISPATCH_ID, "task_ids": []}) + "/api/v1/dispatch/cancel", data=json.dumps({"dispatch_id": DISPATCH_ID, "task_ids": []}) ) assert response.json() == f"Dispatch {DISPATCH_ID} cancelled." cancel_running_dispatch_mock.assert_called_once_with(DISPATCH_ID, []) @@ -183,69 +169,155 @@ async def test_cancel(mocker, client): async def test_cancel_exception(mocker, client): """Test the cancel endpoint.""" cancel_running_dispatch_mock = mocker.patch( - "covalent_dispatcher.cancel_running_dispatch", side_effect=Exception("mock") + "covalent_dispatcher.entry_point.cancel_running_dispatch", side_effect=Exception("mock") ) with pytest.raises(Exception): response = client.post( - "/api/cancel", data=json.dumps({"dispatch_id": DISPATCH_ID, "task_ids": []}) + "/api/v1/dispatch/cancel", + data=json.dumps({"dispatch_id": DISPATCH_ID, "task_ids": []}), ) assert response.status_code == 400 assert response.json()["detail"] == "Failed to cancel workflow: mock" cancel_running_dispatch_mock.assert_called_once_with(DISPATCH_ID, []) -def test_get_result(mocker, client, test_db_file): - """Test the get-result endpoint.""" - lattice = MockLattice( - status=str(Result.COMPLETED), - dispatch_id=DISPATCH_ID, +def test_db_path_get_config(mocker): + """Test that the db path is retrieved from the 
config.""" "" + get_config_mock = mocker.patch("covalent_dispatcher._db.dispatchdb.get_config") + + DispatchDB() + + get_config_mock.assert_called_once() + + +def test_register(mocker, app, client, mock_manifest): + mock_register_dispatch = mocker.patch( + "covalent_dispatcher._service.app.dispatcher.register_dispatch", return_value=mock_manifest + ) + resp = client.post("/api/v1/dispatch/register", data=mock_manifest.json()) + + assert resp.json() == json.loads(mock_manifest.json()) + mock_register_dispatch.assert_awaited_with(mock_manifest, None) + + +def test_register_exception(mocker, app, client, mock_manifest): + mock_register_dispatch = mocker.patch( + "covalent_dispatcher._service.app.dispatcher.register_dispatch", side_effect=RuntimeError() ) + resp = client.post("/api/v1/dispatch/register", data=mock_manifest.json()) + assert resp.status_code == 400 - with test_db_file.session() as session: - session.add(lattice) - session.commit() - mocker.patch("covalent_dispatcher._service.app._result_from", return_value={}) - mocker.patch("covalent_dispatcher._service.app.workflow_db", test_db_file) - mocker.patch("covalent_dispatcher._service.app.Lattice", MockLattice) - response = client.get(f"/api/result/{DISPATCH_ID}") - result = response.json() - assert result["id"] == DISPATCH_ID - assert result["status"] == Result.COMPLETED - os.remove("/tmp/testdb.sqlite") +def test_register_sublattice(mocker, app, client, mock_manifest): + mock_register_dispatch = mocker.patch( + "covalent_dispatcher._service.app.dispatcher.register_dispatch", return_value=mock_manifest + ) + resp = client.post( + "/api/v1/dispatch/register", + data=mock_manifest.json(), + params={"parent_dispatch_id": "parent_dispatch"}, + ) + + assert resp.json() == json.loads(mock_manifest.json()) + mock_register_dispatch.assert_awaited_with(mock_manifest, "parent_dispatch") -def test_get_result_503(mocker, client, test_db_file): - """Test the get-result endpoint.""" - lattice = MockLattice( - 
status=str(Result.NEW_OBJ), - dispatch_id=DISPATCH_ID, +def test_register_redispatch(mocker, app, client, mock_manifest): + dispatch_id = "test_register_redispatch" + mock_register_redispatch = mocker.patch( + "covalent_dispatcher._service.app.dispatcher.register_redispatch", + return_value=mock_manifest, ) - with test_db_file.session() as session: - session.add(lattice) - session.commit() - mocker.patch("covalent_dispatcher._service.app._result_from", side_effect=FileNotFoundError()) - mocker.patch("covalent_dispatcher._service.app.workflow_db", test_db_file) - mocker.patch("covalent_dispatcher._service.app.Lattice", MockLattice) - response = client.get(f"/api/result/{DISPATCH_ID}?wait=True&status_only=True") - assert response.status_code == 503 - os.remove("/tmp/testdb.sqlite") + resp = client.post(f"/api/v1/dispatch/register/{dispatch_id}", data=mock_manifest.json()) + mock_register_redispatch.assert_awaited_with(mock_manifest, dispatch_id, False) + assert resp.json() == json.loads(mock_manifest.json()) -def test_get_result_dispatch_id_not_found(mocker, test_db_file, client): - """Test the get-result endpoint and that 404 is returned if the dispatch ID is not found in the database.""" - mocker.patch("covalent_dispatcher._service.app._result_from", return_value={}) - mocker.patch("covalent_dispatcher._service.app.workflow_db", test_db_file) - mocker.patch("covalent_dispatcher._service.app.Lattice", MockLattice) - response = client.get(f"/api/result/{DISPATCH_ID}") - assert response.status_code == 404 +def test_register_redispatch_reuse(mocker, app, client, mock_manifest): + dispatch_id = "test_register_redispatch" + mock_register_redispatch = mocker.patch( + "covalent_dispatcher._service.app.dispatcher.register_redispatch", + return_value=mock_manifest, + ) + resp = client.post( + f"/api/v1/dispatch/register/{dispatch_id}", + data=mock_manifest.json(), + params={"reuse_previous_results": True}, + ) + mock_register_redispatch.assert_awaited_with(mock_manifest, 
dispatch_id, True) + assert resp.json() == json.loads(mock_manifest.json()) -def test_db_path_get_config(mocker): - """Test that the db path is retrieved from the config.""" "" - get_config_mock = mocker.patch("covalent_dispatcher._db.dispatchdb.get_config") +def test_register_redispatch_exception(mocker, app, client, mock_manifest): + dispatch_id = "test_register_redispatch" + mock_register_redispatch = mocker.patch( + "covalent_dispatcher._service.app.dispatcher.register_redispatch", + side_effect=RuntimeError(), + ) + resp = client.post(f"/api/v1/dispatch/register/{dispatch_id}", data=mock_manifest.json()) + assert resp.status_code == 400 - DispatchDB() - get_config_mock.assert_called_once() +def test_start(mocker, app, client): + dispatch_id = "test_start" + mock_start = mocker.patch("covalent_dispatcher._service.app.dispatcher.start_dispatch") + mock_create_task = mocker.patch("asyncio.create_task") + resp = client.put(f"/api/v1/dispatch/start/{dispatch_id}") + assert resp.json() == dispatch_id + + +def test_export_result_nowait(mocker, app, client, mock_manifest): + dispatch_id = "test_export_result" + mock_result_object = MagicMock() + mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.NEW_OBJECT)) + mocker.patch( + "covalent_dispatcher._service.app._try_get_result_object", return_value=mock_result_object + ) + mock_export = mocker.patch( + "covalent_dispatcher._service.app.export_result_manifest", return_value=mock_manifest + ) + resp = client.get(f"/api/v1/dispatch/export/{dispatch_id}") + assert resp.status_code == 200 + assert resp.json()["id"] == dispatch_id + assert resp.json()["status"] == str(RESULT_STATUS.NEW_OBJECT) + assert resp.json()["result_export"] == json.loads(mock_manifest.json()) + + +def test_export_result_wait_not_ready(mocker, app, client, mock_manifest): + dispatch_id = "test_export_result" + mock_result_object = MagicMock() + mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.RUNNING)) + 
mocker.patch( + "covalent_dispatcher._service.app._try_get_result_object", return_value=mock_result_object + ) + mock_export = mocker.patch( + "covalent_dispatcher._service.app.export_result_manifest", return_value=mock_manifest + ) + resp = client.get(f"/api/v1/dispatch/export/{dispatch_id}", params={"wait": True}) + assert resp.status_code == 503 + + +def test_export_result_bad_dispatch_id(mocker, app, client, mock_manifest): + dispatch_id = "test_export_result" + mock_result_object = MagicMock() + mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.NEW_OBJECT)) + mocker.patch("covalent_dispatcher._service.app._try_get_result_object", return_value=None) + resp = client.get(f"/api/v1/dispatch/export/{dispatch_id}") + assert resp.status_code == 404 + + +def test_try_get_result_object(mocker, app, client, mock_manifest): + dispatch_id = "test_try_get_result_object" + mock_result_object = MagicMock() + mocker.patch( + "covalent_dispatcher._service.app.get_result_object", return_value=mock_result_object + ) + assert _try_get_result_object(dispatch_id) == mock_result_object + + +def test_try_get_result_object_not_found(mocker, app, client, mock_manifest): + dispatch_id = "test_try_get_result_object" + mock_result_object = MagicMock() + mocker.patch("covalent_dispatcher._service.app.get_result_object", side_effect=KeyError()) + assert _try_get_result_object(dispatch_id) is None diff --git a/tests/covalent_dispatcher_tests/_service/assets_test.py b/tests/covalent_dispatcher_tests/_service/assets_test.py new file mode 100644 index 0000000000..4d3a16c2da --- /dev/null +++ b/tests/covalent_dispatcher_tests/_service/assets_test.py @@ -0,0 +1,735 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). 
+# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +"""Unit tests for the FastAPI asset endpoints""" + +import tempfile +from contextlib import contextmanager +from typing import Generator +from unittest.mock import MagicMock + +import pytest +from fastapi import HTTPException +from fastapi.testclient import TestClient +from sqlalchemy import Column, Integer, String, create_engine +from sqlalchemy.orm import Session, declarative_base, sessionmaker + +from covalent._workflow.transportable_object import TransportableObject +from covalent_dispatcher._service.assets import ( + _copy_file_obj, + _generate_file_slice, + _get_tobj_pickle_offsets, + _get_tobj_string_offsets, + get_cached_result_object, +) +from covalent_ui.app import fastapi_app as fast_app + +DISPATCH_ID = "f34671d1-48f2-41ce-89d9-9a8cb5c60e5d" + +INTERNAL_URI = "file:///tmp/object.pkl" + +# Mock SqlAlchemy models +MockBase = declarative_base() + + +class MockLattice(MockBase): + __tablename__ = "lattices" + id = Column(Integer, primary_key=True) + dispatch_id = Column(String(64), nullable=False) + status = Column(String(24), nullable=False) + + +class MockDataStore: + def __init__(self, db_URL): + self.db_URL = db_URL + self.engine = create_engine(self.db_URL) + self.Session = sessionmaker(self.engine) + + MockBase.metadata.create_all(self.engine) + + 
@contextmanager + def session(self) -> Generator[Session, None, None]: + with self.Session.begin() as session: + yield session + + +@pytest.fixture +def app(): + yield fast_app + + +@pytest.fixture +def client(): + with TestClient(fast_app) as c: + yield c + + +@pytest.fixture +def test_db(): + """Instantiate and return an in-memory database.""" + return MockDataStore(db_URL="sqlite+pysqlite:///:memory:") + + +@pytest.fixture +def mock_result_object(): + res_obj = MagicMock() + mock_node = MagicMock() + mock_asset = MagicMock() + mock_asset.internal_uri = INTERNAL_URI + + res_obj.get_asset = MagicMock(return_value=mock_asset) + res_obj.update_assets = MagicMock() + res_obj.lattice.get_asset = MagicMock(return_value=mock_asset) + res_obj.lattice.update_assets = MagicMock() + + res_obj.lattice.transport_graph.get_node = MagicMock(return_value=mock_node) + + mock_node.get_asset = MagicMock(return_value=mock_asset) + mock_node.update_assets = MagicMock() + + return res_obj + + +def test_get_node_asset(mocker, client, test_db, mock_result_object): + """ + Test get node asset + """ + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + yield "Hi" + + key = "output" + node_id = 0 + dispatch_id = "test_get_node_asset_no_dispatch_id" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + + resp = client.get(f"/api/v1/assets/{dispatch_id}/node/{node_id}/{key}") + + assert resp.text == "Hi" + assert (INTERNAL_URI, 0, -1, 65536) == 
mock_generator.calls[0] + + +def test_get_node_asset_byte_range(mocker, client, test_db, mock_result_object): + """ + Test get node asset + """ + + test_str = "test_get_node_asset_string_rep" + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + if end_byte >= 0: + yield test_str[start_byte:end_byte] + else: + yield test_str[start_byte:] + + key = "output" + node_id = 0 + dispatch_id = "test_get_node_asset_no_dispatch_id" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + + headers = {"Range": "bytes=0-6"} + + resp = client.get(f"/api/v1/assets/{dispatch_id}/node/{node_id}/{key}", headers=headers) + + assert resp.text == test_str[0:6] + assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0] + + +@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)]) +def test_get_node_asset_rep( + mocker, client, test_db, mock_result_object, rep, start_byte, end_byte +): + """ + Test get node asset + """ + + test_str = "test_get_node_asset_rep" + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + if end_byte >= 0: + yield test_str[start_byte:end_byte] + else: + yield test_str[start_byte:] + + key = "output" + node_id = 0 + dispatch_id = "test_get_node_asset_no_dispatch_id" + mock_generator = MockGenerateFileSlice() + + 
mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + mocker.patch( + "covalent_dispatcher._service.assets._get_tobj_string_offsets", return_value=(0, 6) + ) + mocker.patch( + "covalent_dispatcher._service.assets._get_tobj_pickle_offsets", return_value=(6, 12) + ) + + params = {"representation": rep} + + resp = client.get(f"/api/v1/assets/{dispatch_id}/node/{node_id}/{key}", params=params) + + assert resp.text == test_str[start_byte:end_byte] + assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0] + + +def test_get_node_asset_bad_dispatch_id(mocker, client): + """ + Test get node asset + """ + key = "output" + node_id = 0 + dispatch_id = "test_get_node_asset" + + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + side_effect=HTTPException(status_code=400), + ) + resp = client.get(f"/api/v1/assets/{dispatch_id}/node/{node_id}/{key}") + assert resp.status_code == 400 + + +def test_get_lattice_asset(mocker, client, test_db, mock_result_object): + """ + Test get lattice asset + """ + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + yield "Hi" + + key = "workflow_function" + dispatch_id = "test_get_lattice_asset_no_dispatch_id" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + 
mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + + resp = client.get(f"/api/v1/assets/{dispatch_id}/lattice/{key}") + + assert resp.text == "Hi" + assert (INTERNAL_URI, 0, -1, 65536) == mock_generator.calls[0] + + +def test_get_lattice_asset_byte_range(mocker, client, test_db, mock_result_object): + """ + Test get lattice asset + """ + + test_str = "test_lattice_asset_byte_range" + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + if end_byte >= 0: + yield test_str[start_byte:end_byte] + else: + yield test_str[start_byte:] + + key = "workflow_function" + dispatch_id = "test_get_lattice_asset_no_dispatch_id" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + + headers = {"Range": "bytes=0-6"} + resp = client.get(f"/api/v1/assets/{dispatch_id}/lattice/{key}", headers=headers) + + assert resp.text == test_str[0:6] + assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0] + + +@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)]) +def test_get_lattice_asset_rep( + mocker, client, test_db, mock_result_object, rep, start_byte, end_byte +): + """ + Test get lattice asset + """ + + test_str = "test_get_lattice_asset_rep" + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, 
start_byte, end_byte, chunk_size)) + if end_byte >= 0: + yield test_str[start_byte:end_byte] + else: + yield test_str[start_byte:] + + key = "workflow_function" + dispatch_id = "test_get_lattice_asset_rep" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + mocker.patch( + "covalent_dispatcher._service.assets._get_tobj_string_offsets", return_value=(0, 6) + ) + mocker.patch( + "covalent_dispatcher._service.assets._get_tobj_pickle_offsets", return_value=(6, 12) + ) + + params = {"representation": rep} + + resp = client.get(f"/api/v1/assets/{dispatch_id}/lattice/{key}", params=params) + + assert resp.text == test_str[start_byte:end_byte] + assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0] + + +def test_get_lattice_asset_bad_dispatch_id(mocker, client): + """ + Test get lattice asset + """ + + key = "workflow_function" + dispatch_id = "test_get_lattice_asset_no_dispatch_id" + + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + side_effect=HTTPException(status_code=400), + ) + + resp = client.get(f"/api/v1/assets/{dispatch_id}/lattice/{key}") + assert resp.status_code == 400 + + +def test_get_dispatch_asset(mocker, client, test_db, mock_result_object): + """ + Test get dispatch asset + """ + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + yield "Hi" + + key = "result" + dispatch_id = "test_get_dispatch_asset" + mock_generator = MockGenerateFileSlice() + + 
mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + + resp = client.get(f"/api/v1/assets/{dispatch_id}/dispatch/{key}") + + assert resp.text == "Hi" + assert (INTERNAL_URI, 0, -1, 65536) == mock_generator.calls[0] + + +def test_get_dispatch_asset_byte_range(mocker, client, test_db, mock_result_object): + """ + Test get dispatch asset + """ + + test_str = "test_dispatch_asset_byte_range" + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + if end_byte >= 0: + yield test_str[start_byte:end_byte] + else: + yield test_str[start_byte:] + + key = "result" + dispatch_id = "test_get_dispatch_asset_byte_range" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + + headers = {"Range": "bytes=0-6"} + resp = client.get(f"/api/v1/assets/{dispatch_id}/dispatch/{key}", headers=headers) + + assert resp.text == test_str[0:6] + assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0] + + +@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)]) +def test_get_dispatch_asset_rep( + mocker, client, test_db, mock_result_object, rep, start_byte, end_byte +): + """ + Test get dispatch asset + """ + + test_str = 
"test_get_dispatch_asset_rep" + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + if end_byte >= 0: + yield test_str[start_byte:end_byte] + else: + yield test_str[start_byte:] + + key = "result" + dispatch_id = "test_get_dispatch_asset_rep" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + mocker.patch( + "covalent_dispatcher._service.assets._get_tobj_string_offsets", return_value=(0, 6) + ) + mocker.patch( + "covalent_dispatcher._service.assets._get_tobj_pickle_offsets", return_value=(6, 12) + ) + + params = {"representation": rep} + + resp = client.get(f"/api/v1/assets/{dispatch_id}/dispatch/{key}", params=params) + + assert resp.text == test_str[start_byte:end_byte] + assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0] + + +def test_get_dispatch_asset_bad_dispatch_id(mocker, client): + """ + Test get dispatch asset + """ + + key = "result" + dispatch_id = "test_get_dispatch_asset_no_dispatch_id" + + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + side_effect=HTTPException(status_code=400), + ) + + resp = client.get(f"/api/v1/assets/{dispatch_id}/dispatch/{key}") + assert resp.status_code == 400 + + +def test_post_node_asset(test_db, mocker, client, mock_result_object): + """ + Test post node asset + """ + + key = "function" + node_id = 0 + dispatch_id = "test_post_node_asset" + + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( 
+ "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + + mock_copy = mocker.patch("covalent_dispatcher._service.assets._copy_file_obj") + + with tempfile.NamedTemporaryFile("w") as writer: + writer.write(f"{dispatch_id}") + writer.flush() + + files = {"asset_file": open(writer.name, "rb")} + headers = {"Digest": "sha=0af"} + resp = client.post( + f"/api/v1/assets/{dispatch_id}/node/{node_id}/{key}", files=files, headers=headers + ) + mock_node = mock_result_object.lattice.transport_graph.get_node(node_id) + mock_node.update_assets.assert_called() + assert resp.status_code == 200 + + mock_copy.assert_called() + + +def test_post_node_asset_bad_dispatch_id(mocker, client): + """ + Test post node asset + """ + key = "function" + node_id = 0 + dispatch_id = "test_post_node_asset_no_dispatch_id" + + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + side_effect=HTTPException(status_code=400), + ) + + with tempfile.NamedTemporaryFile("w") as writer: + writer.write(f"{dispatch_id}") + writer.flush() + + files = {"asset_file": open(writer.name, "rb")} + resp = client.post(f"/api/v1/assets/{dispatch_id}/node/{node_id}/{key}", files=files) + + assert resp.status_code == 400 + + +def test_post_lattice_asset(mocker, client, test_db, mock_result_object): + """ + Test post lattice asset + """ + key = "workflow_function" + dispatch_id = "test_post_lattice_asset" + + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + + mock_copy = mocker.patch("covalent_dispatcher._service.assets._copy_file_obj") + + with tempfile.NamedTemporaryFile("w") as writer: + writer.write(f"{dispatch_id}") + writer.flush() + + files = {"asset_file": open(writer.name, "rb")} + resp = client.post(f"/api/v1/assets/{dispatch_id}/lattice/{key}", files=files) + mock_lattice = mock_result_object.lattice + 
mock_lattice.update_assets.assert_called() + assert resp.status_code == 200 + + mock_copy.assert_called() + + +def test_post_lattice_asset_bad_dispatch_id(mocker, client): + """ + Test post lattice asset + """ + key = "workflow_function" + dispatch_id = "test_post_lattice_asset_no_dispatch_id" + + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + side_effect=HTTPException(status_code=400), + ) + + with tempfile.NamedTemporaryFile("w") as writer: + writer.write(f"{dispatch_id}") + writer.flush() + + files = {"asset_file": open(writer.name, "rb")} + resp = client.post(f"/api/v1/assets/{dispatch_id}/lattice/{key}", files=files) + + assert resp.status_code == 400 + + +def test_post_dispatch_asset(mocker, client, test_db, mock_result_object): + """ + Test post dispatch asset + """ + key = "result" + dispatch_id = "test_post_dispatch_asset" + + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + + mock_copy = mocker.patch("covalent_dispatcher._service.assets._copy_file_obj") + + with tempfile.NamedTemporaryFile("w") as writer: + writer.write(f"{dispatch_id}") + writer.flush() + + files = {"asset_file": open(writer.name, "rb")} + resp = client.post(f"/api/v1/assets/{dispatch_id}/dispatch/{key}", files=files) + mock_result_object.update_assets.assert_called() + assert resp.status_code == 200 + + mock_copy.assert_called() + + +def test_post_dispatch_asset_bad_dispatch_id(mocker, client): + """ + Test post dispatch asset + """ + key = "result" + dispatch_id = "test_post_dispatch_asset_no_dispatch_id" + + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + side_effect=HTTPException(status_code=400), + ) + + with tempfile.NamedTemporaryFile("w") as writer: + writer.write(f"{dispatch_id}") + writer.flush() + + files = {"asset_file": open(writer.name, "rb")} + resp = 
client.post(f"/api/v1/assets/{dispatch_id}/dispatch/{key}", files=files) + + assert resp.status_code == 400 + + +def test_get_string_offsets(): + tobj = TransportableObject("test_get_string_offsets") + + data = tobj.serialize() + with tempfile.NamedTemporaryFile("wb") as write_file: + write_file.write(data) + write_file.flush() + + start, end = _get_tobj_string_offsets(f"file://{write_file.name}") + + assert data[start:end].decode("utf-8") == tobj.object_string + + +def test_get_pickle_offsets(): + tobj = TransportableObject("test_get_pickle_offsets") + + data = tobj.serialize() + with tempfile.NamedTemporaryFile("wb") as write_file: + write_file.write(data) + write_file.flush() + + start, end = _get_tobj_pickle_offsets(f"file://{write_file.name}") + + assert data[start:].decode("utf-8") == tobj.get_serialized() + + +def test_generate_partial_file_slice(): + """Test generating slices of files.""" + + data = "test_generate_file_slice".encode("utf-8") + with tempfile.NamedTemporaryFile("wb") as write_file: + write_file.write(data) + write_file.flush() + gen = _generate_file_slice(f"file://{write_file.name}", 1, 5, 2) + assert next(gen) == data[1:3] + assert next(gen) == data[3:5] + with pytest.raises(StopIteration): + next(gen) + + +def test_generate_whole_file_slice(): + """Test generating slices of files.""" + + data = "test_generate_file_slice".encode("utf-8") + with tempfile.NamedTemporaryFile("wb") as write_file: + write_file.write(data) + write_file.flush() + gen = _generate_file_slice(f"file://{write_file.name}", 0, -1) + assert next(gen) == data + + +def test_get_cached_result_obj(mocker, test_db): + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch("covalent_dispatcher._service.assets.get_result_object", side_effect=KeyError()) + with pytest.raises(HTTPException): + get_cached_result_object("test_get_cached_result_obj") + + +def test_copy_file_obj(mocker): + with tempfile.NamedTemporaryFile("rb+") as src: + 
src.write("Hello".encode("utf-8")) + src.flush() + src.seek(0) + with tempfile.NamedTemporaryFile("r") as dest: + _copy_file_obj(src, f"file://{dest.name}") + assert dest.read() == "Hello" diff --git a/tests/covalent_dispatcher_tests/entry_point_test.py b/tests/covalent_dispatcher_tests/entry_point_test.py index 47fc1e780d..72ba1931fd 100644 --- a/tests/covalent_dispatcher_tests/entry_point_test.py +++ b/tests/covalent_dispatcher_tests/entry_point_test.py @@ -21,69 +21,98 @@ """Unit tests for the FastAPI app.""" +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from unittest.mock import MagicMock + import pytest -from covalent_dispatcher.entry_point import cancel_running_dispatch, run_dispatcher, run_redispatch +from covalent_dispatcher.entry_point import ( + cancel_running_dispatch, + register_dispatch, + register_redispatch, + run_dispatcher, + start_dispatch, +) DISPATCH_ID = "f34671d1-48f2-41ce-89d9-9a8cb5c60e5d" -class MockObject: - pass +@pytest.mark.asyncio +async def test_run_dispatcher(mocker): + mock_run_dispatch = mocker.patch("covalent_dispatcher._core.run_dispatch") + mock_make_dispatch = mocker.patch( + "covalent_dispatcher._core.make_dispatch", return_value=DISPATCH_ID + ) + json_lattice = '{"workflow_function": "asdf"}' + dispatch_id = await run_dispatcher(json_lattice) + assert dispatch_id == DISPATCH_ID + mock_make_dispatch.assert_awaited_with(json_lattice) + mock_run_dispatch.assert_called_with(dispatch_id) -def mock_initialize_result_object(lattice): - result = MockObject() - result.dispatch_id = lattice["dispatch_id"] - return result +@pytest.mark.asyncio +async def test_cancel_running_dispatch(mocker): + mock_cancel_workflow = mocker.patch("covalent_dispatcher.entry_point.cancel_dispatch") + await cancel_running_dispatch(DISPATCH_ID) + mock_cancel_workflow.assert_awaited_once_with(DISPATCH_ID, []) @pytest.mark.asyncio -@pytest.mark.parametrize("disable_run", [True, False]) -async def test_run_dispatcher(mocker, 
disable_run): - """ - Test run_dispatcher is called with the - right arguments in different conditions - """ +async def test_start_dispatch_waits(mocker): + """Check that start_dispatch waits for any assets to be copied.""" + + dispatch_id = "test_start_dispatch_waits" + + def mock_copy(): + import time + + time.sleep(3) + + mock_futures = {} + ex = ThreadPoolExecutor(max_workers=1) + mocker.patch("covalent_dispatcher._core.copy_futures", mock_futures) mock_run_dispatch = mocker.patch("covalent_dispatcher._core.run_dispatch") - mock_make_dispatch = mocker.patch( - "covalent_dispatcher._core.make_dispatch", return_value=DISPATCH_ID - ) - json_lattice = '{"workflow_function": "asdf"}' - dispatch_id = await run_dispatcher(json_lattice, disable_run) - assert dispatch_id == DISPATCH_ID + fut = ex.submit(mock_copy) + mock_futures[dispatch_id] = fut + fut.add_done_callback(lambda x: mock_futures.pop(dispatch_id)) + + start_time = datetime.now() + await start_dispatch(dispatch_id) + end_time = datetime.now() - mock_make_dispatch.assert_called_with(json_lattice) - if not disable_run: - mock_run_dispatch.assert_called_with(dispatch_id) + assert (end_time - start_time).total_seconds() > 2 + + mock_run_dispatch.assert_called() @pytest.mark.asyncio -@pytest.mark.parametrize("is_pending", [True, False]) -async def test_run_redispatch(mocker, is_pending): - """ - Test the run_redispatch function is called - with the right arguments in differnet conditions - """ - - make_derived_dispatch_mock = mocker.patch( - "covalent_dispatcher._core.make_derived_dispatch", return_value="mock-redispatch-id" - ) - run_dispatch_mock = mocker.patch("covalent_dispatcher._core.run_dispatch") - redispatch_id = await run_redispatch(DISPATCH_ID, "mock-json-lattice", {}, False, is_pending) +async def test_register_dispatch(mocker): + """Check register_dispatch""" - if not is_pending: - make_derived_dispatch_mock.assert_called_once_with( - DISPATCH_ID, "mock-json-lattice", {}, False - ) + mock_manifest = 
MagicMock() + + mock_importer = mocker.patch( + "covalent_dispatcher._core.data_modules.importer.import_manifest", + return_value=mock_manifest, + ) - run_dispatch_mock.assert_called_once_with(redispatch_id) + assert await register_dispatch("manifest", "parent_dispatch_id") is mock_manifest + mock_importer.assert_awaited_with("manifest", "parent_dispatch_id", None) @pytest.mark.asyncio -async def test_cancel_running_dispatch(mocker): - mock_cancel_workflow = mocker.patch("covalent_dispatcher.entry_point.cancel_dispatch") - await cancel_running_dispatch(DISPATCH_ID) - mock_cancel_workflow.assert_awaited_once_with(DISPATCH_ID, []) +async def test_register_redispatch(mocker): + """Check register_dispatch""" + + mock_manifest = MagicMock() + + mock_importer = mocker.patch( + "covalent_dispatcher._core.data_modules.importer.import_derived_manifest", + return_value=mock_manifest, + ) + + assert await register_redispatch("manifest", "parent_dispatch_id", True) is mock_manifest + mock_importer.assert_awaited_with("manifest", "parent_dispatch_id", True) diff --git a/tests/covalent_tests/__init__.py b/tests/covalent_tests/__init__.py index e69de29bb2..523f776226 100644 --- a/tests/covalent_tests/__init__.py +++ b/tests/covalent_tests/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. 
+# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. diff --git a/tests/covalent_tests/dispatcher_plugins/__init__.py b/tests/covalent_tests/dispatcher_plugins/__init__.py index e69de29bb2..523f776226 100644 --- a/tests/covalent_tests/dispatcher_plugins/__init__.py +++ b/tests/covalent_tests/dispatcher_plugins/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. 
diff --git a/tests/covalent_tests/dispatcher_plugins/local_test.py b/tests/covalent_tests/dispatcher_plugins/local_test.py index 3022ea8746..dcf196eb84 100644 --- a/tests/covalent_tests/dispatcher_plugins/local_test.py +++ b/tests/covalent_tests/dispatcher_plugins/local_test.py @@ -21,132 +21,142 @@ """Unit tests for local module in dispatcher_plugins.""" +import tempfile from unittest.mock import MagicMock import pytest from requests import Response -from requests.exceptions import HTTPError +from requests.exceptions import ConnectionError, HTTPError import covalent as ct -from covalent._dispatcher_plugins.local import LocalDispatcher, get_redispatch_request_body +from covalent._dispatcher_plugins.local import LocalDispatcher, get_redispatch_request_body_v2 +from covalent._results_manager.result import Result +from covalent._shared_files.utils import format_server_url -def test_get_redispatch_request_body_null_arguments(): - """Test the get request body function with null arguments.""" +def test_dispatching_a_non_lattice(): + """test dispatching a non-lattice""" @ct.electron - def identity(a): - return a + def task(a, b, c): + return a + b + c @ct.electron - def add(a, b): - return a + b + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) - response = get_redispatch_request_body( - "mock-dispatch-id", - ) - assert response == { - "json_lattice": None, - "dispatch_id": "mock-dispatch-id", - "electron_updates": {}, - "reuse_previous_results": False, - } + with pytest.raises( + TypeError, match="Dispatcher expected a Lattice, received instead." 
+ ): + LocalDispatcher.dispatch(workflow)(1, 2) -def test_get_redispatch_request_body_args_kwargs(mocker): - """Test the get request body function when args/kwargs is not null.""" - mock_electron = MagicMock() - get_result_mock = mocker.patch("covalent._dispatcher_plugins.local.get_result") - get_result_mock().lattice.serialize_to_json.return_value = "mock-json-lattice" +def test_dispatch_when_no_server_is_running(mocker): + """test dispatching a lattice when no server is running""" - response = get_redispatch_request_body( - "mock-dispatch-id", - new_args=[1, 2], - new_kwargs={"a": 1, "b": 2}, - replace_electrons={"mock-task-id": mock_electron}, - ) - assert response == { - "json_lattice": "mock-json-lattice", - "dispatch_id": "mock-dispatch-id", - "electron_updates": {"mock-task-id": mock_electron.electron_object.as_transportable_dict}, - "reuse_previous_results": False, - } - get_result_mock().lattice.build_graph.assert_called_once_with(*[1, 2], **{"a": 1, "b": 2}) + # the test suite is using another port, thus, with the dummy address below + # the covalent server is not running in some sense. + dummy_dispatcher_addr = "http://localhost:12345" + endpoint = "/api/v1/dispatch/register" + url = dummy_dispatcher_addr + endpoint + message = f"The Covalent server cannot be reached at {url}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." 
+ @ct.electron + def task(a, b, c): + return a + b + c -@pytest.mark.parametrize("is_pending", [True, False]) -@pytest.mark.parametrize( - "replace_electrons, expected_arg", - [(None, {}), ({"mock-electron-1": "mock-electron-2"}, {"mock-electron-1": "mock-electron-2"})], -) -def test_redispatch(mocker, replace_electrons, expected_arg, is_pending): - """Test the local re-dispatch function.""" + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) - mocker.patch("covalent._dispatcher_plugins.local.get_config", return_value="mock-config") - requests_mock = mocker.patch("covalent._dispatcher_plugins.local.requests") - get_request_body_mock = mocker.patch( - "covalent._dispatcher_plugins.local.get_redispatch_request_body", - return_value={"mock-request-body"}, - ) + mock_print = mocker.patch("covalent._api.apiclient.print") - local_dispatcher = LocalDispatcher() - func = local_dispatcher.redispatch( - "mock-dispatch-id", replace_electrons=replace_electrons, is_pending=is_pending - ) - func() - requests_mock.post.assert_called_once_with( - "http://mock-config:mock-config/api/redispatch", - json={"mock-request-body"}, - params={"is_pending": is_pending}, - timeout=5, - ) - requests_mock.post().raise_for_status.assert_called_once() - requests_mock.post().content.decode().strip().replace.assert_called_once_with('"', "") + with pytest.raises(ConnectionError): + LocalDispatcher.dispatch(workflow, dispatcher_addr=dummy_dispatcher_addr)(1, 2) - get_request_body_mock.assert_called_once_with("mock-dispatch-id", (), {}, expected_arg, False) + mock_print.assert_called_once_with(message) -def test_redispatch_unreachable(mocker): - """Test the local re-dispatch function when the server is unreachable.""" +def test_dispatcher_dispatch_single(mocker): + """test dispatching a lattice with submit api""" - mock_dispatch_id = "mock-dispatch-id" - dummy_dispatcher_addr = "http://localhost:12345" + @ct.electron + def task(a, b, c): + return a + b + c - message = f"The Covalent server 
cannot be reached at {dummy_dispatcher_addr}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) - mock_print = mocker.patch("covalent._dispatcher_plugins.local.print") + # test when api raises an implicit error - LocalDispatcher.redispatch(mock_dispatch_id, dispatcher_addr=dummy_dispatcher_addr)() + dispatch_id = "test_dispatcher_dispatch_single" + # multistage = False + mocker.patch("covalent._dispatcher_plugins.local.get_config", return_value=False) - mock_print.assert_called_once_with(message) + mock_submit_callable = MagicMock(return_value=dispatch_id) + mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.submit", + return_value=mock_submit_callable, + ) + mock_reg_tr = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register_triggers" + ) + mock_start = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.start", return_value=dispatch_id + ) + + assert dispatch_id == LocalDispatcher.dispatch(workflow)(1, 2) + + mock_submit_callable.assert_called() + mock_start.assert_called() -def test_dispatching_a_non_lattice(): - """test dispatching a non-lattice""" + +def test_dispatcher_dispatch_multi(mocker): + """test dispatching a lattice with multistage api""" @ct.electron def task(a, b, c): return a + b + c - @ct.electron @ct.lattice def workflow(a, b): return task(a, b, c=4) - with pytest.raises( - TypeError, match="Dispatcher expected a Lattice, received instead." 
- ): - LocalDispatcher.dispatch(workflow)(1, 2) + dispatch_id = "test_dispatcher_dispatch_multi" + # multistage = True + mocker.patch("covalent._shared_files.config.get_config", return_value=True) + mock_register_callable = MagicMock(return_value=dispatch_id) + mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register", + return_value=mock_register_callable, + ) -def test_dispatch_when_no_server_is_running(mocker): - """test dispatching a lattice when no server is running""" + mock_submit_callable = MagicMock(return_value=dispatch_id) + mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.submit", + return_value=mock_submit_callable, + ) - # the test suite is using another port, thus, with the dummy address below - # the covalent server is not running in some sense. - dummy_dispatcher_addr = "http://localhost:12345" + mock_reg_tr = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register_triggers" + ) + mock_start = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.start", return_value=dispatch_id + ) - message = f"The Covalent server cannot be reached at {dummy_dispatcher_addr}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." 
+ assert dispatch_id == LocalDispatcher.dispatch(workflow)(1, 2) + + mock_submit_callable.assert_not_called() + mock_register_callable.assert_called() + mock_start.assert_called() + + +def test_dispatcher_dispatch_with_triggers(mocker): + """test dispatching a lattice with triggers""" @ct.electron def task(a, b, c): @@ -156,11 +166,32 @@ def task(a, b, c): def workflow(a, b): return task(a, b, c=4) - mock_print = mocker.patch("covalent._dispatcher_plugins.local.print") + dispatch_id = "test_dispatcher_dispatch_with_triggers" - LocalDispatcher.dispatch(workflow, dispatcher_addr=dummy_dispatcher_addr)(1, 2) + workflow.metadata["triggers"] = {"dir_trigger": {}} - mock_print.assert_called_once_with(message) + mock_register_callable = MagicMock(return_value=dispatch_id) + mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register", + return_value=mock_register_callable, + ) + + mock_submit_callable = MagicMock(return_value=dispatch_id) + mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.submit", + return_value=mock_submit_callable, + ) + + mock_reg_tr = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register_triggers" + ) + mock_start = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.start", return_value=dispatch_id + ) + + assert dispatch_id == LocalDispatcher.dispatch(workflow)(1, 2) + mock_reg_tr.assert_called() + mock_start.assert_not_called() def test_dispatcher_submit_api(mocker): @@ -180,10 +211,10 @@ def workflow(a, b): r.url = "http://dummy" r.reason = "dummy reason" - mocker.patch("covalent._dispatcher_plugins.local.requests.post", return_value=r) + mocker.patch("covalent._api.apiclient.requests.Session.post", return_value=r) with pytest.raises(HTTPError, match="404 Client Error: dummy reason for url: http://dummy"): - dispatch_id = LocalDispatcher.dispatch(workflow)(1, 2) + dispatch_id = LocalDispatcher.submit(workflow)(1, 2) assert dispatch_id is None # test when api doesn't raise 
an implicit error @@ -192,7 +223,364 @@ def workflow(a, b): r.url = "http://dummy" r._content = b"abcde" - mocker.patch("covalent._dispatcher_plugins.local.requests.post", return_value=r) + mocker.patch("covalent._api.apiclient.requests.Session.post", return_value=r) - dispatch_id = LocalDispatcher.dispatch(workflow)(1, 2) + dispatch_id = LocalDispatcher.submit(workflow)(1, 2) assert dispatch_id == "abcde" + + +def test_dispatcher_start(mocker): + """Test starting a dispatch""" + + dispatch_id = "test_dispatcher_start" + r = Response() + r.status_code = 404 + r.url = "http://dummy" + r.reason = "dummy reason" + + mocker.patch("covalent._api.apiclient.requests.Session.put", return_value=r) + + with pytest.raises(HTTPError, match="404 Client Error: dummy reason for url: http://dummy"): + LocalDispatcher.start(dispatch_id) + + # test when api doesn't raise an implicit error + r = Response() + r.status_code = 200 + r.url = "http://dummy" + r._content = dispatch_id.encode("utf-8") + + mocker.patch("covalent._api.apiclient.requests.Session.put", return_value=r) + + assert LocalDispatcher.start(dispatch_id) == dispatch_id + + +def test_register(mocker): + """test dispatching a lattice with register api""" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + + manifest.metadata.dispatch_id = "test_register" + + mock_upload = mocker.patch("covalent._dispatcher_plugins.local.LocalDispatcher.upload_assets") + mock_prepare_manifest = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.prepare_manifest", + return_value=manifest, + ) + mock_register_manifest = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register_manifest" + ) + + dispatch_id = LocalDispatcher.register(workflow)(1, 2) + assert dispatch_id == 
"test_register" + mock_upload.assert_called() + + +def test_redispatch(mocker): + """test redispatching a lattice with register api""" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + + dispatch_id = "test_register_redispatch" + manifest.metadata.dispatch_id = dispatch_id + parent_id = "parent_dispatch_id" + + mock_upload = mocker.patch("covalent._dispatcher_plugins.local.LocalDispatcher.upload_assets") + mock_get_redispatch_manifest = mocker.patch( + "covalent._dispatcher_plugins.local.get_redispatch_request_body_v2", return_value=manifest + ) + mock_register_derived_manifest = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register_derived_manifest" + ) + mock_start = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.start", + return_value="test_register_redispatch", + ) + + new_args = (1, 2) + new_kwargs = {} + redispatch_id = LocalDispatcher.redispatch( + dispatch_id=parent_id, replace_electrons={"f": "callable"}, reuse_previous_results=False + )(*new_args, **new_kwargs) + + assert dispatch_id == redispatch_id + mock_upload.assert_called() + + mock_start.assert_called_with(dispatch_id, format_server_url()) + + +def test_register_manifest(mocker): + """Test registering a dispatch manifest.""" + + dispatch_id = "test_register_manifest" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + + manifest.metadata.dispatch_id = dispatch_id + + r = Response() + r.status_code = 200 + r.json = MagicMock(return_value=manifest.dict()) + + 
mocker.patch("covalent._api.apiclient.requests.Session.post", return_value=r) + + mock_merge = mocker.patch( + "covalent._dispatcher_plugins.local.merge_response_manifest", return_value=manifest + ) + + return_manifest = LocalDispatcher.register_manifest(manifest) + assert return_manifest.metadata.dispatch_id == dispatch_id + mock_merge.assert_called_with(manifest, manifest) + + +def test_register_derived_manifest(mocker): + """Test registering a redispatch manifest.""" + + dispatch_id = "test_register_derived_manifest" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + + manifest.metadata.dispatch_id = dispatch_id + + r = Response() + r.status_code = 200 + r.json = MagicMock(return_value=manifest.dict()) + + mocker.patch("covalent._api.apiclient.requests.Session.post", return_value=r) + + mock_merge = mocker.patch( + "covalent._dispatcher_plugins.local.merge_response_manifest", return_value=manifest + ) + + return_manifest = LocalDispatcher.register_derived_manifest(manifest, "original_dispatch") + assert return_manifest.metadata.dispatch_id == dispatch_id + mock_merge.assert_called_with(manifest, manifest) + + +def test_upload_assets(mocker): + """Test uploading assets to HTTP endpoints""" + + dispatch_id = "test_upload_assets_http" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + + num_assets = 0 + # Populate the lattice asset schemas with dummy URLs + for key, asset in manifest.lattice.assets: + num_assets += 1 + asset.remote_uri = f"http://localhost:48008/api/v1/assets/{dispatch_id}/lattice/dummy" + + endpoint = 
f"/api/v1/assets/{dispatch_id}/lattice/dummy" + r = Response() + r.status_code = 200 + mock_post = mocker.patch("covalent._api.apiclient.requests.Session.post", return_value=r) + + LocalDispatcher.upload_assets(manifest) + + assert mock_post.call_count == num_assets + + +def test_get_redispatch_request_body_norebuild(mocker): + """Test constructing the request body for redispatch""" + + # Consider the case where the dispatch is to be retried with no + # changes to inputs or electrons. + + dispatch_id = "test_get_redispatch_request_body_norebuild" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + + # "Old" result object + res_obj = Result(workflow) + + # Mock result manager + mock_resmgr = MagicMock() + + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + mock_resmgr._manifest = manifest + mock_resmgr.result_object = res_obj + + mock_serialize = mocker.patch( + "covalent._dispatcher_plugins.local.serialize_result", return_value=manifest + ) + mocker.patch( + "covalent._dispatcher_plugins.local.ResultSchema.parse_obj", return_value=manifest + ) + mocker.patch( + "covalent._dispatcher_plugins.local.get_result_manager", return_value=mock_resmgr + ) + + with tempfile.TemporaryDirectory() as redispatch_dir: + redispatch_manifest = get_redispatch_request_body_v2( + dispatch_id, redispatch_dir, [], {}, replace_electrons={} + ) + + assert redispatch_manifest is manifest + + +def test_get_redispatch_request_body_replace_electrons(mocker): + """Test constructing the request body for redispatch""" + + # Consider the case where electrons are to be replaced but lattice + # inputs stay the same. 
+ + dispatch_id = "test_get_redispatch_request_body_replace_electrons" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.electron + def new_task(a, b, c): + return a * b * c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + + # "Old" result object + res_obj = Result(workflow) + + # Mock result manager + mock_resmgr = MagicMock() + + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + mock_resmgr._manifest = manifest + mock_resmgr.result_object = res_obj + + mock_serialize = mocker.patch( + "covalent._dispatcher_plugins.local.serialize_result", return_value=manifest + ) + mocker.patch( + "covalent._dispatcher_plugins.local.ResultSchema.parse_obj", return_value=manifest + ) + mocker.patch( + "covalent._dispatcher_plugins.local.get_result_manager", return_value=mock_resmgr + ) + + with tempfile.TemporaryDirectory() as redispatch_dir: + redispatch_manifest = get_redispatch_request_body_v2( + dispatch_id, redispatch_dir, [], {}, replace_electrons={"task": new_task} + ) + + assert redispatch_manifest is manifest + mock_resmgr.download_lattice_asset.assert_any_call("workflow_function") + mock_resmgr.download_lattice_asset.assert_any_call("workflow_function_string") + mock_resmgr.download_lattice_asset.assert_any_call("inputs") + + mock_resmgr.load_lattice_asset.assert_any_call("workflow_function") + mock_resmgr.load_lattice_asset.assert_any_call("workflow_function_string") + mock_resmgr.load_lattice_asset.assert_any_call("inputs") + + +def test_get_redispatch_request_body_replace_inputs(mocker): + """Test constructing the request body for redispatch""" + + # Consider the case where only lattice + # inputs are changed. 
+ + dispatch_id = "test_get_redispatch_request_body_replace_inputs" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + + # "Old" result object + res_obj = Result(workflow) + + # Mock result manager + mock_resmgr = MagicMock() + + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + mock_resmgr._manifest = manifest + mock_resmgr.result_object = res_obj + + mock_serialize = mocker.patch( + "covalent._dispatcher_plugins.local.serialize_result", return_value=manifest + ) + mocker.patch( + "covalent._dispatcher_plugins.local.ResultSchema.parse_obj", return_value=manifest + ) + mocker.patch( + "covalent._dispatcher_plugins.local.get_result_manager", return_value=mock_resmgr + ) + + with tempfile.TemporaryDirectory() as redispatch_dir: + redispatch_manifest = get_redispatch_request_body_v2( + dispatch_id, redispatch_dir, [3, 4], {}, replace_electrons=None + ) + + assert redispatch_manifest is manifest + mock_resmgr.download_lattice_asset.assert_any_call("workflow_function") + mock_resmgr.download_lattice_asset.assert_any_call("workflow_function_string") + + mock_resmgr.load_lattice_asset.assert_any_call("workflow_function") + mock_resmgr.load_lattice_asset.assert_any_call("workflow_function_string") diff --git a/tests/covalent_tests/results_manager_tests/__init__.py b/tests/covalent_tests/results_manager_tests/__init__.py index e69de29bb2..523f776226 100644 --- a/tests/covalent_tests/results_manager_tests/__init__.py +++ b/tests/covalent_tests/results_manager_tests/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). 
+# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. diff --git a/tests/covalent_tests/results_manager_tests/results_manager_test.py b/tests/covalent_tests/results_manager_tests/results_manager_test.py index 7134cb5551..176510af57 100644 --- a/tests/covalent_tests/results_manager_tests/results_manager_test.py +++ b/tests/covalent_tests/results_manager_tests/results_manager_test.py @@ -20,123 +20,286 @@ """Tests for results manager.""" -from http.client import HTTPMessage -from unittest.mock import ANY, MagicMock, Mock, call +import os +import tempfile +from datetime import datetime, timezone +from unittest.mock import MagicMock import pytest -import requests +from requests import Response -from covalent._results_manager import wait +import covalent as ct from covalent._results_manager.results_manager import ( - _get_result_from_dispatcher, + MissingLatticeRecordError, + Result, + ResultManager, + _get_result_export_from_dispatcher, cancel, + download_asset, get_result, ) -from covalent._shared_files.config import get_config +from covalent._serialize.result import serialize_result +from covalent._workflow.transportable_object import TransportableObject -def test_get_result_unreachable_dispatcher(mocker): - """ - Test that get_result returns None when - the dispatcher server is unreachable. 
- """ - mock_dispatch_id = "mock_dispatch_id" +def get_test_manifest(staging_dir): + @ct.electron + def identity(x): + return x - mocker.patch( - "covalent._results_manager.results_manager._get_result_from_dispatcher", - side_effect=requests.exceptions.ConnectionError, - ) + @ct.electron + def add(x, y): + return x + y - assert get_result(mock_dispatch_id) is None + @ct.lattice + def workflow(x, y): + res1 = identity(x) + res2 = identity(y) + return add(res1, res2) + workflow.build_graph(2, 3) + result_object = Result(workflow) + ts = datetime.now(timezone.utc) + result_object._start_time = ts + result_object._end_time = ts + result_object._result = TransportableObject(42) + result_object.lattice.transport_graph.set_node_value(0, "status", Result.COMPLETED) + result_object.lattice.transport_graph.set_node_value(0, "output", TransportableObject(2)) + manifest = serialize_result(result_object, staging_dir) -@pytest.mark.parametrize( - "dispatcher_addr", - [ - "http://" + get_config("dispatcher.address") + ":" + str(get_config("dispatcher.port")), - "http://localhost:48008", - ], -) -def test_get_result_from_dispatcher(mocker, dispatcher_addr): - retries = 10 - getconn_mock = mocker.patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") - mocker.patch("requests.Response.json", return_value=True) - headers = HTTPMessage() - headers.add_header("Retry-After", "2") - - mock_response = [Mock(status=503, msg=headers)] * (retries - 1) - mock_response.append(Mock(status=200, msg=HTTPMessage())) - getconn_mock.return_value.getresponse.side_effect = mock_response - dispatch_id = "9d1b308b-4763-4990-ae7f-6a6e36d35893" - _get_result_from_dispatcher( - dispatch_id, wait=wait.LONG, dispatcher_addr=dispatcher_addr, status_only=False - ) - assert ( - getconn_mock.return_value.request.mock_calls - == [ - call( - "GET", - f"/api/result/{dispatch_id}?wait=True&status_only=False", - body=None, - headers=ANY, - ), - ] - * retries - ) + # Swap asset uri and remote_uri to simulate an 
exported manifest + for key, asset in manifest.assets: + asset.remote_uri = asset.uri + asset.uri = None + for key, asset in manifest.lattice.assets: + asset.remote_uri = asset.uri + asset.uri = None -def test_get_result_from_dispatcher_unreachable(mocker): - """ - Test that _get_result_from_dispatcher raises an exception when - the dispatcher server is unreachable. - """ + for node in manifest.lattice.transport_graph.nodes: + for key, asset in node.assets: + asset.remote_uri = asset.uri + asset.uri = None - # TODO: Will need to edit this once `_get_result_from_dispatcher` is fixed - # to actually throw an exception when the dispatcher server is unreachable - # instead of just hanging. - - mock_dispatcher_addr = "mock_dispatcher_addr" - mock_dispatch_id = "mock_dispatch_id" - - message = f"The Covalent server cannot be reached at {mock_dispatcher_addr}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." 
- - mocker.patch("covalent._results_manager.results_manager.HTTPAdapter") - mock_session = mocker.patch("covalent._results_manager.results_manager.requests.Session") - mock_session.return_value.get.side_effect = requests.exceptions.ConnectionError - - mock_print = mocker.patch("covalent._results_manager.results_manager.print") - - with pytest.raises(requests.exceptions.ConnectionError): - _get_result_from_dispatcher( - mock_dispatch_id, wait=wait.LONG, dispatcher_addr=mock_dispatcher_addr - ) - - mock_print.assert_called_once_with(message) + return manifest def test_cancel_with_single_task_id(mocker): - mock_get_config = mocker.patch("covalent._results_manager.results_manager.get_config") mock_request_post = mocker.patch( - "covalent._results_manager.results_manager.requests.post", MagicMock() + "covalent._api.apiclient.requests.Session.post", ) cancel(dispatch_id="dispatch", task_ids=1) - assert mock_get_config.call_count == 2 mock_request_post.assert_called_once() mock_request_post.return_value.raise_for_status.assert_called_once() def test_cancel_with_multiple_task_ids(mocker): - mock_get_config = mocker.patch("covalent._results_manager.results_manager.get_config") mock_task_ids = [0, 1] mock_request_post = mocker.patch( - "covalent._results_manager.results_manager.requests.post", MagicMock() + "covalent._api.apiclient.requests.Session.post", ) cancel(dispatch_id="dispatch", task_ids=[1, 2, 3]) - assert mock_get_config.call_count == 2 mock_request_post.assert_called_once() mock_request_post.return_value.raise_for_status.assert_called_once() + + +def test_result_export(mocker): + with tempfile.TemporaryDirectory() as staging_dir: + test_manifest = get_test_manifest(staging_dir) + + dispatch_id = "test_result_export" + + mock_body = {"id": "test_result_export", "status": "COMPLETED"} + + mock_client = MagicMock() + mock_response = Response() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=mock_body) + + # mock_client.get = 
MagicMock(return_value=mock_response) + mocker.patch("covalent._api.apiclient.requests.Session.get", return_value=mock_response) + + # mocker.patch( + # "covalent._results_manager.results_manager.CovalentAPIClient", return_value=mock_client + # ) + + endpoint = f"/api/v1/dispatch/export/{dispatch_id}" + assert mock_body == _get_result_export_from_dispatcher( + dispatch_id, wait=False, status_only=True + ) + + +def test_result_manager_assets_local_copies(): + """Test downloading and loading assets using local asset uris.""" + dispatch_id = "test_result_manager" + with tempfile.TemporaryDirectory() as server_dir: + # This will have uri and remote_uri swapped so as to simulate + # a manifest exported from the server. All "downloads" will be + # local file copies from server_dir to results_dir. + manifest = get_test_manifest(server_dir) + with tempfile.TemporaryDirectory() as results_dir: + rm = ResultManager(manifest, results_dir) + rm.download_lattice_asset("workflow_function") + rm.load_lattice_asset("workflow_function") + rm.download_result_asset("result") + rm.load_result_asset("result") + os.makedirs(f"{results_dir}/node_0") + rm.download_node_asset(0, "output") + rm.load_node_asset(0, "output") + + res_obj = rm.result_object + assert res_obj.lattice(3, 5) == 8 + assert res_obj.result == 42 + + output = res_obj.lattice.transport_graph.get_node_value(0, "output") + assert output.get_deserialized() == 2 + + +def test_result_manager_save_manifest(): + """Test saving and loading manifests""" + dispatch_id = "test_result_manager_save_load" + with tempfile.TemporaryDirectory() as server_dir: + # This will have uri and remote_uri swapped so as to simulate + # a manifest exported from the server. All "downloads" will be + # local file copies from server_dir to results_dir. 
+ manifest = get_test_manifest(server_dir) + with tempfile.TemporaryDirectory() as results_dir: + rm = ResultManager(manifest, results_dir) + rm.save() + path = os.path.join(results_dir, "manifest.json") + rm2 = ResultManager.load(path, results_dir) + assert rm2._results_dir == results_dir + assert rm2._manifest == rm._manifest + + +def test_get_result(mocker): + dispatch_id = "test_result_manager" + with tempfile.TemporaryDirectory() as server_dir: + # This will have uri and remote_uri swapped so as to simulate + # a manifest exported from the server. All "downloads" will be + # local file copies from server_dir to results_dir. + manifest = get_test_manifest(server_dir) + + mock_result_export = { + "id": dispatch_id, + "status": "COMPLETED", + "result_export": manifest.dict(), + } + mocker.patch( + "covalent._results_manager.results_manager._get_result_export_from_dispatcher", + return_value=mock_result_export, + ) + with tempfile.TemporaryDirectory() as results_dir: + res_obj = get_result(dispatch_id, results_dir=results_dir) + + assert res_obj.result == 42 + + +def test_get_result_sublattice(mocker): + dispatch_id = "test_result_manager_sublattice" + sub_dispatch_id = "test_result_manager_sublattice_sub" + + with tempfile.TemporaryDirectory() as server_dir: + # This will have uri and remote_uri swapped so as to simulate + # a manifest exported from the server. All "downloads" will be + # local file copies from server_dir to results_dir. 
+ manifest = get_test_manifest(server_dir) + + node = manifest.lattice.transport_graph.nodes[0] + node.metadata.sub_dispatch_id = sub_dispatch_id + + with tempfile.TemporaryDirectory() as server_dir_sub: + # Sublattice manifest + sub_manifest = get_test_manifest(server_dir_sub) + + mock_result_export = { + "id": dispatch_id, + "status": "COMPLETED", + "result_export": manifest.dict(), + } + + mock_subresult_export = { + "id": sub_dispatch_id, + "status": "COMPLETED", + "result_export": sub_manifest.dict(), + } + + exports = {dispatch_id: mock_result_export, sub_dispatch_id: mock_subresult_export} + + def mock_get_export(dispatch_id, *args, **kwargs): + return exports[dispatch_id] + + mocker.patch( + "covalent._results_manager.results_manager._get_result_export_from_dispatcher", + mock_get_export, + ) + with tempfile.TemporaryDirectory() as results_dir: + res_obj = get_result(dispatch_id, results_dir=results_dir) + + assert res_obj.result == 42 + tg = res_obj.lattice.transport_graph + for node_id in tg._graph.nodes: + if node_id == 0: + assert tg.get_node_value(node_id, "sub_dispatch_id") == sub_dispatch_id + assert tg.get_node_value(node_id, "sublattice_result") is not None + + else: + assert tg.get_node_value(1, "sublattice_result") is None + + +def test_get_result_404(mocker): + """Check exception handling for invalid dispatch ids.""" + + dispatch_id = "test_get_result_404" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 404 + + mock_client.get = MagicMock(return_value=mock_response) + + mocker.patch( + "covalent._results_manager.results_manager.CovalentAPIClient", return_value=mock_client + ) + + with pytest.raises(MissingLatticeRecordError): + get_result(dispatch_id) + + +def test_get_status_only(mocker): + """Check get_result when status_only=True""" + + dispatch_id = "test_get_result_st" + mock_get_result_export = mocker.patch( + "covalent._results_manager.results_manager._get_result_export_from_dispatcher", +
return_value={"id": dispatch_id, "status": "RUNNING"}, + ) + + status_report = get_result(dispatch_id, status_only=True) + assert status_report["status"] == "RUNNING" + + +def test_download_asset(mocker): + dispatch_id = "test_download_asset" + remote_uri = f"http://localhost:48008/api/v1/assets/dispatch/{dispatch_id}/result" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + + mock_client.get = MagicMock(return_value=mock_response) + mocker.patch( + "covalent._results_manager.results_manager.CovalentAPIClient", return_value=mock_client + ) + + def mock_generator(): + yield "Hello".encode("utf-8") + + mock_response.iter_content = MagicMock(return_value=mock_generator()) + + with tempfile.NamedTemporaryFile() as local_file: + download_asset(remote_uri, local_file.name) + assert local_file.read().decode("utf-8") == "Hello" diff --git a/tests/covalent_tests/triggers/base_test.py b/tests/covalent_tests/triggers/base_test.py index 46add54a46..68e1f78974 100644 --- a/tests/covalent_tests/triggers/base_test.py +++ b/tests/covalent_tests/triggers/base_test.py @@ -65,7 +65,7 @@ def test_get_status(mocker, use_internal_func, mock_status): base_trigger.use_internal_funcs = use_internal_func if use_internal_func: - mocker.patch("covalent_dispatcher._service.app.get_result") + mocker.patch("covalent_dispatcher._service.app.export_result") mock_fut_res = mock.Mock() mock_fut_res.result.return_value = mock_status @@ -102,29 +102,38 @@ def test_do_redispatch(mocker, use_internal_func, is_pending): with the right arguments in different conditions """ - base_trigger = BaseTrigger() - base_trigger.use_internal_funcs = use_internal_func - mock_redispatch_id = "test_dispatch_id" + mock_wrapper = mock.MagicMock(return_value=mock_redispatch_id) + mock_redispatch = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.redispatch", return_value=mock_wrapper + ) + mock_start = mocker.patch( + 
"covalent._dispatcher_plugins.local.LocalDispatcher.start", return_value=mock_redispatch_id + ) - if use_internal_func: - mocker.patch("covalent_dispatcher.run_redispatch") - mock_fut_res = mock.Mock() - mock_fut_res.result.return_value = mock_redispatch_id - mock_run_coro = mocker.patch( - "covalent.triggers.base.asyncio.run_coroutine_threadsafe", return_value=mock_fut_res - ) - - redispatch_id = base_trigger._do_redispatch(is_pending) + base_trigger = BaseTrigger() + base_trigger.use_internal_funcs = use_internal_func - mock_run_coro.assert_called_once() - mock_fut_res.result.assert_called_once() + # if use_internal_func: + # mocker.patch("covalent_dispatcher.entry_point.run_redispatch") + # mocker.patch("covalent_dispatcher.entry_point.start_dispatch") + # mock_fut_res = mock.Mock() + # mock_fut_res.result.return_value = mock_redispatch_id + # mock_run_coro = mocker.patch( + # "covalent.triggers.base.asyncio.run_coroutine_threadsafe", return_value=mock_fut_res + # ) + # redispatch_id = base_trigger._do_redispatch(is_pending) + + # mock_run_coro.assert_called_once() + # mock_fut_res.result.assert_called_once() + # else: + redispatch_id = base_trigger._do_redispatch(is_pending) + + if is_pending: + mock_start.assert_called_once() + mock_wrapper.assert_not_called() else: - mock_redispatch = mocker.patch("covalent.redispatch")() - mock_redispatch.return_value = mock_redispatch_id - redispatch_id = base_trigger._do_redispatch(is_pending) - - mock_redispatch.assert_called_once() + mock_redispatch.assert_called() assert redispatch_id == mock_redispatch_id diff --git a/tests/covalent_tests/workflow/transport_graph_ops_test.py b/tests/covalent_tests/workflow/transport_graph_ops_test.py deleted file mode 100644 index 80cbbe3917..0000000000 --- a/tests/covalent_tests/workflow/transport_graph_ops_test.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright 2023 Agnostiq Inc. -# -# This file is part of Covalent. 
-# -# Licensed under the GNU Affero General Public License 3.0 (the "License"). -# A copy of the License may be obtained with this software package or at -# -# https://www.gnu.org/licenses/agpl-3.0.en.html -# -# Use of this file is prohibited except in compliance with the License. Any -# modifications or derivative works of this file must retain this copyright -# notice, and modified files must contain a notice indicating that they have -# been altered from the originals. -# -# Covalent is distributed in the hope that it will be useful, but WITHOUT -# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or -# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. -# -# Relief from the License may be granted by purchasing a commercial license. - -"""Unit tests for transport graph operations module.""" - -import pytest - -from covalent._workflow.transport import _TransportGraph -from covalent._workflow.transport_graph_ops import TransportGraphOps - - -def add(x, y): - return x + y - - -def multiply(x, y): - return x * y - - -def identity(x): - return x - - -@pytest.fixture -def tg(): - """Transport graph operations fixture.""" - tg = _TransportGraph() - tg.add_node(name="add", function=add, metadata={"0-mock-key": "0-mock-value"}) - tg.add_node(name="multiply", function=multiply, metadata={"1-mock-key": "1-mock-value"}) - tg.add_node(name="identity", function=identity, metadata={"2-mock-key": "2-mock-value"}) - return tg - - -@pytest.fixture -def tg_2(): - """Transport graph operations fixture - different from tg.""" - tg_2 = _TransportGraph() - tg_2.add_node(name="not-add", function=add, metadata={"0- mock-key": "0-mock-value"}) - tg_2.add_node(name="multiply", function=multiply, metadata={"1- mock-key": "1-mock-value"}) - tg_2.add_node(name="identity", function=identity, metadata={"2- mock-key": "2-mock-value"}) - return tg_2 - - -@pytest.fixture -def tg_ops(tg): - """Transport graph operations fixture.""" - return TransportGraphOps(tg) - 
- -def test_init(tg): - """Test initialization of transport graph operations.""" - tg_ops = TransportGraphOps(tg) - assert tg_ops.tg == tg - assert tg_ops._status_map == {1: True, -1: False} - - -def test_flag_successors_no_successors(tg, tg_ops): - """Test flagging successors of a node.""" - node_statuses = {0: 1, 1: 1, 2: 1} - tg_ops._flag_successors(tg._graph, node_statuses=node_statuses, starting_node=0) - assert node_statuses == {0: -1, 1: 1, 2: 1} - - -@pytest.mark.parametrize( - "n_1,n_2,n_start,label,new_statuses", - [ - (0, 1, 0, "01", {0: -1, 1: -1, 2: 1}), - (1, 2, 0, "12", {0: -1, 1: 1, 2: 1}), - (1, 2, 1, "12", {0: 1, 1: -1, 2: -1}), - (1, 2, 2, "12", {0: 1, 1: 1, 2: -1}), - ], -) -def test_flag_successors_with_one_successors(tg, tg_ops, n_1, n_2, n_start, label, new_statuses): - """Test flagging successors of a node.""" - tg.add_edge(n_1, n_2, label) - node_statuses = {0: 1, 1: 1, 2: 1} - tg_ops._flag_successors(tg._graph, node_statuses=node_statuses, starting_node=n_start) - assert node_statuses == new_statuses - - -@pytest.mark.parametrize( - "n_1,n_2,n_3,n_4,label_1,label_2,n_start,new_statuses", - [ - (0, 1, 1, 2, "01", "12", 0, {0: -1, 1: -1, 2: -1}), - (0, 1, 0, 2, "01", "02", 0, {0: -1, 1: -1, 2: -1}), - (0, 1, 0, 2, "01", "12", 1, {0: 1, 1: -1, 2: 1}), - ], -) -def test_flag_successors_with_successors_3( - tg, tg_ops, n_1, n_2, n_3, n_4, label_1, n_start, label_2, new_statuses -): - """Test flagging successors of a node.""" - tg.add_edge(n_1, n_2, label_1) - tg.add_edge(n_3, n_4, label_2) - node_statuses = {0: 1, 1: 1, 2: 1} - tg_ops._flag_successors(tg._graph, node_statuses=node_statuses, starting_node=n_start) - assert node_statuses == new_statuses - - -def test_is_same_node_true(tg, tg_ops): - """Test the is same node method.""" - assert tg_ops.is_same_node(tg._graph, tg._graph, 0) is True - assert tg_ops.is_same_node(tg._graph, tg._graph, 1) is True - - -def test_is_same_node_false(tg, tg_ops): - """Test the is same node method.""" - tg_2 
= _TransportGraph() - tg_2.add_node(name="multiply", function=add, metadata={"0- mock-key": "0-mock-value"}) - assert tg_ops.is_same_node(tg._graph, tg_2._graph, 0) is False - - -def test_is_same_edge_attributes_true(tg, tg_ops): - """Test the is same edge attributes method.""" - tg.add_edge(0, 1, edge_name="01", kwargs={"x": 1, "y": 2}) - assert tg_ops.is_same_edge_attributes(tg._graph, tg._graph, 0, 1) is True - - -def test_is_same_edge_attributes_false(tg, tg_ops): - """Test the is same edge attributes method.""" - tg.add_edge(0, 1, edge_name="01", kwargs={"x": 1, "y": 2}) - - tg_2 = _TransportGraph() - tg_2.add_node(name="add", function=add, metadata={"0- mock-key": "0-mock-value"}) - tg_2.add_node(name="multiply", function=multiply, metadata={"1- mock-key": "1-mock-value"}) - tg_2.add_node(name="identity", function=identity, metadata={"2- mock-key": "2-mock-value"}) - tg_2.add_edge(0, 1, edge_name="01", kwargs={"x": 1}) - - assert tg_ops.is_same_edge_attributes(tg._graph, tg_2._graph, 0, 1) is False - - -def test_copy_nodes_from(tg_ops): - """Test the node copying method.""" - - def replacement(x): - return x + 1 - - tg_new = _TransportGraph() - tg_new.add_node( - name="replacement", function=replacement, metadata={"0-mock-key": "0-mock-value"} - ) - tg_new.add_node(name="multiply", function=multiply, metadata={"1-mock-key": "1-mock-value"}) - tg_new.add_node( - name="replacement", function=replacement, metadata={"2-mock-key": "2-mock-value"} - ) - - tg_ops.copy_nodes_from(tg_new, [0, 2]) - tg_ops.tg._graph.nodes(data=True)[0]["name"] == tg_ops.tg._graph.nodes(data=True)[2][ - "name" - ] == "replacement" - tg_ops.tg._graph.nodes(data=True)[2]["name"] == "multiply" - - -def test_max_cbms(tg_ops): - """Test method for determining a largest cbms""" - import networkx as nx - - A = nx.MultiDiGraph() - B = nx.MultiDiGraph() - C = nx.MultiDiGraph() - D = nx.MultiDiGraph() - - # 0 5 6 - # / \ - # 1 2 - A.add_edge(0, 1) - A.add_edge(0, 2) - A.nodes[1]["color"] = "red" 
- A.add_node(5) - A.add_node(6) - - # 0 5 - # / \\ - # 1 2 - B.add_edge(0, 1) - B.add_edge(0, 2) - B.add_edge(0, 2) - B.nodes[1]["color"] = "black" - B.add_node(5) - - # 0 3 - # / \ / - # 1 2 - C.add_edge(0, 1) - C.add_edge(0, 2) - C.add_edge(3, 2) - - # 0 3 - # / \ / - # 1 2 - # / - # 4 - D.add_edge(0, 1) - D.add_edge(0, 2) - D.add_edge(3, 2) - D.add_edge(2, 4) - - A_node_status, B_node_status = tg_ops._max_cbms(A, B) - assert A_node_status == {0: True, 1: False, 2: False, 5: True, 6: False} - assert B_node_status == {0: True, 1: False, 2: False, 5: True} - - A_node_status, C_node_status = tg_ops._max_cbms(A, C) - assert A_node_status == {0: True, 1: False, 2: False, 5: False, 6: False} - assert C_node_status == {0: True, 1: False, 2: False, 3: False} - - C_node_status, D_node_status = tg_ops._max_cbms(C, D) - assert C_node_status == {0: True, 1: True, 2: True, 3: True} - assert D_node_status == {0: True, 1: True, 2: True, 3: True, 4: False} - - -def test_cmp_name_and_pval_true(tg, tg_ops): - """Test the name and parameter value comparison method.""" - assert tg_ops._cmp_name_and_pval(tg._graph, tg._graph, 0) is True - - -def test_cmp_name_and_pval_false(tg, tg_2, tg_ops): - """Test the name and parameter value comparison method.""" - assert tg_ops._cmp_name_and_pval(tg._graph, tg_2._graph, 0) is False - - -def test_get_reusable_nodes(mocker, tg, tg_2, tg_ops): - """Test the get reusable nodes method.""" - max_cbms_mock = mocker.patch( - "covalent._workflow.transport_graph_ops.TransportGraphOps._max_cbms", - return_value=({"mock-key-A": "mock-value-A"}, {"mock-key-B": "mock-value-B"}), - ) - reusable_nodes = tg_ops.get_reusable_nodes(tg_2) - assert reusable_nodes == ["mock-key-A"] - max_cbms_mock.assert_called_once() - - -def test_get_diff_nodes_integration_test(tg_2, tg_ops): - """Test the get reusable nodes method.""" - reusable_nodes = tg_ops.get_reusable_nodes(tg_2) - assert reusable_nodes == [1, 2] diff --git a/tests/functional_tests/__init__.py 
b/tests/functional_tests/__init__.py index e69de29bb2..523f776226 100644 --- a/tests/functional_tests/__init__.py +++ b/tests/functional_tests/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. 
diff --git a/tests/functional_tests/file_transfer_test.py b/tests/functional_tests/file_transfer_test.py index b5e8ce66e8..b785cb6a29 100644 --- a/tests/functional_tests/file_transfer_test.py +++ b/tests/functional_tests/file_transfer_test.py @@ -206,7 +206,7 @@ def test_local_file_transfer_transfer_from(tmp_path: Path, mocker): Popen.returncode = 0 mocker.patch("covalent._file_transfer.strategies.rsync_strategy.Popen", return_value=Popen) - ft = ct.fs.TransferFromRemote(str(source_file)) + ft = ct.fs.TransferFromRemote(str(source_file), strategy=ct.fs_strategies.Rsync()) @ct.electron(files=[ft]) def test_transfer(files=[]): @@ -243,7 +243,7 @@ def test_local_file_transfer_transfer_to(tmp_path: Path, mocker): Popen.returncode = 0 mocker.patch("covalent._file_transfer.strategies.rsync_strategy.Popen", return_value=Popen) - ft = ct.fs.TransferToRemote(str(dest_file)) + ft = ct.fs.TransferToRemote(str(dest_file), strategy=ct.fs_strategies.Rsync()) @ct.electron(files=[ft]) def test_transfer(files=[]): diff --git a/tests/functional_tests/local_executor_test.py b/tests/functional_tests/local_executor_test.py index a3361ea857..16322e59eb 100644 --- a/tests/functional_tests/local_executor_test.py +++ b/tests/functional_tests/local_executor_test.py @@ -20,6 +20,8 @@ import covalent as ct +import covalent._results_manager.results_manager as rm +from covalent._results_manager.result import Result def test_local_executor_returns_stdout_stderr(): @@ -45,3 +47,33 @@ def workflow(x): assert tg.get_node_value(0, "stdout") == "Hello\n" assert tg.get_node_value(0, "stderr") == "Error\n" assert tg.get_node_value(0, "output").get_deserialized() == 5 + + +def test_local_executor_build_sublattice_graph(): + """ + Check using local executor to build_sublattice_graph. + + This will exercise the /register endpoint for sublattices. 
+ """ + + def add(a, b): + return a + b + + @ct.electron(executor="local") + def identity(a): + return a + + sublattice_add = ct.lattice(add) + + @ct.lattice(executor="local", workflow_executor="local") + def workflow(a, b): + res_1 = ct.electron(sublattice_add, executor="local")(a=a, b=b) + return identity(a=res_1) + + dispatch_id = ct.dispatch(workflow)(a=1, b=2) + workflow_result = rm.get_result(dispatch_id, wait=True) + + assert workflow_result.error == "" + assert workflow_result.status == Result.COMPLETED + assert workflow_result.result == 3 + assert workflow_result.get_node_result(node_id=0)["sublattice_result"].result == 3 diff --git a/tests/functional_tests/results_manager_test.py b/tests/functional_tests/results_manager_test.py new file mode 100644 index 0000000000..b21ec7c97c --- /dev/null +++ b/tests/functional_tests/results_manager_test.py @@ -0,0 +1,81 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. 
+ +"""Testing methods to retrieve workflow artifacts""" + +import pytest + +import covalent as ct +from covalent._shared_files.exceptions import MissingLatticeRecordError + + +def test_granular_get_result(): + def add(a, b): + return a + b + + @ct.electron + def identity(a): + return a + + sublattice_add = ct.lattice(add) + + @ct.lattice + def workflow(a, b): + res_1 = ct.electron(sublattice_add)(a=a, b=b) + return identity(a=res_1) + + dispatch_id = ct.dispatch(workflow)(a=1, b=2) + res_obj = ct.get_result( + dispatch_id, + wait=True, + workflow_output=False, + intermediate_outputs=False, + sublattice_results=False, + ) + + assert res_obj.result is None + + res_obj = ct.get_result( + dispatch_id, workflow_output=True, intermediate_outputs=False, sublattice_results=False + ) + assert res_obj.result == 3 + + assert res_obj.get_node_result(0)["sublattice_result"] is None + + res_obj = ct.get_result( + dispatch_id, workflow_output=True, intermediate_outputs=False, sublattice_results=True + ) + assert res_obj.result == 3 + + assert res_obj.get_node_result(0)["sublattice_result"].result == 3 + assert res_obj.get_node_result(0)["output"] is None + + res_obj = ct.get_result( + dispatch_id, workflow_output=True, intermediate_outputs=True, sublattice_results=False + ) + assert res_obj.result == 3 + + assert res_obj.get_node_result(0)["sublattice_result"] is None + assert res_obj.get_node_result(0)["output"].get_deserialized() == 3 + + +def test_get_result_nonexistent(): + with pytest.raises(MissingLatticeRecordError): + result_object = ct.get_result("nonexistent", wait=False) diff --git a/tests/functional_tests/triggers_test.py b/tests/functional_tests/triggers_test.py index b5dc414861..035f846ae4 100644 --- a/tests/functional_tests/triggers_test.py +++ b/tests/functional_tests/triggers_test.py @@ -67,7 +67,7 @@ def dir_workflow(): with open(read_file_path, "a") as f: f.write(f"{i}\n") - time.sleep(2) + time.sleep(5) with open(write_file_path, "r") as f: actual_sums = 
f.readlines() diff --git a/tests/functional_tests/workflow_cancellation_test.py b/tests/functional_tests/workflow_cancellation_test.py index 9cbaa6ac77..76ae2fa11c 100644 --- a/tests/functional_tests/workflow_cancellation_test.py +++ b/tests/functional_tests/workflow_cancellation_test.py @@ -52,6 +52,7 @@ def workflow(x): ct.cancel(dispatch_id) result = ct.get_result(dispatch_id, wait=True) + assert result.status == ct.status.CANCELLED rm._delete_result(dispatch_id) @@ -112,7 +113,7 @@ def workflow(x): return sub_workflow(3) dispatch_id = ct.dispatch(workflow)(3) - time.sleep(0.5) + time.sleep(1) ct.cancel(dispatch_id, task_ids=[0]) @@ -120,7 +121,6 @@ def workflow(x): tg = result.lattice.transport_graph sub_dispatch_id = tg.get_node_value(0, "sub_dispatch_id") - - print("Sublattice dispatch id:", sub_dispatch_id) - sub_res = ct.get_result(sub_dispatch_id) - assert sub_res.status == ct.status.CANCELLED + if sub_dispatch_id: + sub_res = ct.get_result(sub_dispatch_id) + assert sub_res.status == ct.status.CANCELLED diff --git a/tests/functional_tests/workflow_stack_test.py b/tests/functional_tests/workflow_stack_test.py index fd2b0c4385..e8cf0190e8 100644 --- a/tests/functional_tests/workflow_stack_test.py +++ b/tests/functional_tests/workflow_stack_test.py @@ -21,13 +21,14 @@ """Workflow stack testing of TransportGraph, Lattice and Electron classes.""" import os +import tempfile import pytest import covalent as ct +import covalent._dispatcher_plugins.local as local import covalent._results_manager.results_manager as rm from covalent._results_manager.result import Result -from covalent_dispatcher._db import update def construct_temp_cache_dir(): @@ -120,7 +121,7 @@ def workflow(a, b): dispatch_id = ct.dispatch(workflow)(a=1, b=2) workflow_result = rm.get_result(dispatch_id, wait=True) - assert workflow_result.error is None + assert workflow_result.error == "" assert workflow_result.status == Result.COMPLETED assert workflow_result.result == 3 assert 
workflow_result.get_node_result(node_id=0)["sublattice_result"].result == 3 @@ -176,11 +177,11 @@ def test_parallelization(): def heavy_function(a): import time - time.sleep(1) + time.sleep(10) return a @ct.lattice -def workflow(x=10): +def workflow(x=2): for i in range(x): heavy_function(a=i) return x @@ -263,7 +264,7 @@ def workflow(file_path): dispatch_id = ct.dispatch(workflow)(file_path=tmp_path) res = ct.get_result(dispatch_id, wait=True) - assert res.error is None + assert res.error == "" assert res.result == (True, "Hello") @@ -611,8 +612,9 @@ def workflow(a, /, b, *args, c, **kwargs): result = rm.get_result(dispatch_id, wait=True) rm._delete_result(dispatch_id) - assert ct.TransportableObject.deserialize_list(result.inputs["args"]) == [1, 2, 3, 4] - assert ct.TransportableObject.deserialize_dict(result.inputs["kwargs"]) == { + workflow_inputs = result.inputs.get_deserialized() + assert workflow_inputs["args"] == (1, 2, 3, 4) + assert workflow_inputs["kwargs"] == { "c": 5, "d": 6, "e": 7, @@ -683,7 +685,6 @@ def workflow(): dispatch_id = ct.dispatch(workflow)() result = ct.get_result(dispatch_id, wait=True) - update.persist(result) assert result.status == Result.COMPLETED assert ( @@ -931,3 +932,26 @@ def failing_workflow(x, y): assert int(result.result) == 1 assert result.status == "COMPLETED" assert result.get_node_result(0)["start_time"] == result.get_node_result(0)["end_time"] + + +def test_multistage_dispatch_with_pull_assets(): + """Test submitting a dispatch with assets to be pulled.""" + + @ct.electron + def task(x): + return x**3 + + @ct.lattice + def workflow(x): + return task(x) + + workflow.build_graph(5) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = local.LocalDispatcher.prepare_manifest(workflow, staging_dir) + return_manifest = local.LocalDispatcher.register_manifest(manifest, push_assets=False) + dispatch_id = return_manifest.metadata.dispatch_id + + local.LocalDispatcher.start(dispatch_id) + + res = 
rm.get_result(dispatch_id, wait=True) + assert res.result == 125 diff --git a/tests/load_tests/locustfiles/basic.py b/tests/load_tests/locustfiles/basic.py index 639b0be9f4..f9ef0c7fcf 100644 --- a/tests/load_tests/locustfiles/basic.py +++ b/tests/load_tests/locustfiles/basic.py @@ -48,18 +48,20 @@ def serialize_workflow(workflow, lattice_args): @task def submit_identity_workflow(self): - self.client.post("/api/submit", data=self.serialize_workflow(identity_workflow, [1])) + self.client.post( + "/api/v1/dispatch/submit", data=self.serialize_workflow(identity_workflow, [1]) + ) @task def submit_horizontal_workflow(self): self.client.post( - "/api/submit", + "/api/v1/dispatch/submit", data=self.serialize_workflow(horizontal_workflow, [random.randint(5, 10)]), ) @task def submit_add_multiply_workflow(self): self.client.post( - "/api/submit", + "/api/v1/dispatch/submit", data=self.serialize_workflow(add_multiply_workflow, [1, 2]), ) diff --git a/tests/load_tests/workflows/horizontal.py b/tests/load_tests/workflows/horizontal.py index 9bdad41352..d359bc67b4 100644 --- a/tests/load_tests/workflows/horizontal.py +++ b/tests/load_tests/workflows/horizontal.py @@ -1,3 +1,23 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. 
+# +# Relief from the License may be granted by purchasing a commercial license. + import covalent as ct diff --git a/tests/stress_tests/benchmarks/__init__.py b/tests/stress_tests/benchmarks/__init__.py index e69de29bb2..523f776226 100644 --- a/tests/stress_tests/benchmarks/__init__.py +++ b/tests/stress_tests/benchmarks/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license.