diff --git a/covalent/executor/__init__.py b/covalent/executor/__init__.py index a8fb330f54..bf4bfc346c 100644 --- a/covalent/executor/__init__.py +++ b/covalent/executor/__init__.py @@ -32,7 +32,8 @@ from .._shared_files import logger from .._shared_files.config import get_config, update_config -from .base import BaseExecutor, wrapper_fn +from .base import BaseExecutor +from .utils.wrappers import wrapper_fn app_log = logger.app_log log_stack_info = logger.log_stack_info diff --git a/covalent/executor/base.py b/covalent/executor/base.py index e46f6590b1..dd7757e274 100644 --- a/covalent/executor/base.py +++ b/covalent/executor/base.py @@ -31,93 +31,24 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import ( - Any, - Callable, - ContextManager, - Dict, - Iterable, - List, - Literal, - Optional, - Tuple, - Union, -) +from typing import Any, Callable, ContextManager, Dict, Iterable, List, Optional, Union import aiofiles from covalent._shared_files.exceptions import TaskCancelledError -from covalent._workflow.depscall import RESERVED_RETVAL_KEY__FILES +from covalent.executor.schemas import ResourceMap, TaskSpec from covalent.executor.utils import Signals from .._shared_files import TaskRuntimeError, logger from .._shared_files.context_managers import active_dispatch_info_manager -from .._shared_files.util_classes import RESULT_STATUS, DispatchInfo -from .._workflow.transport import TransportableObject +from .._shared_files.util_classes import RESULT_STATUS, DispatchInfo, Status +from .schemas import TaskUpdate app_log = logger.app_log log_stack_info = logger.log_stack_info TypeJSON = Union[str, int, float, bool, None, Dict[str, Any], List[Any]] -def wrapper_fn( - function: TransportableObject, - call_before: List[Tuple[TransportableObject, TransportableObject, TransportableObject]], - call_after: List[Tuple[TransportableObject, TransportableObject, TransportableObject]], - *args, - **kwargs, -): - """Wrapper for serialized callable. - - Execute preparatory shell commands before deserializing and - running the callable. This is the actual function to be sent to - the various executors. - - """ - - cb_retvals = {} - for tup in call_before: - serialized_fn, serialized_args, serialized_kwargs, retval_key = tup - cb_fn = serialized_fn.get_deserialized() - cb_args = serialized_args.get_deserialized() - cb_kwargs = serialized_kwargs.get_deserialized() - retval = cb_fn(*cb_args, **cb_kwargs) - - # we always store cb_kwargs dict values as arrays to factor in non-unique values - if retval_key and retval_key in cb_retvals: - cb_retvals[retval_key].append(retval) - elif retval_key: - cb_retvals[retval_key] = [retval] - - # if cb_retvals key only contains one item this means it is a unique (non-repeated) retval key - # so we only return the first element however if it is a 'files' kwarg we always return as a list - cb_retvals = { - key: value[0] if len(value) == 1 and key != RESERVED_RETVAL_KEY__FILES else value - for key, value in cb_retvals.items() - } - - fn = function.get_deserialized() - - new_args = [arg.get_deserialized() for arg in args] - - new_kwargs = {k: v.get_deserialized() for k, v in kwargs.items()} - - # Inject return values into kwargs - for key, val in cb_retvals.items(): - new_kwargs[key] = val - - output = fn(*new_args, **new_kwargs) - - for tup in call_after: - serialized_fn, serialized_args, serialized_kwargs, retval_key = tup - ca_fn = serialized_fn.get_deserialized() - ca_args = serialized_args.get_deserialized() - ca_kwargs = serialized_kwargs.get_deserialized() - ca_fn(*ca_args, **ca_kwargs) - - return TransportableObject(output) - - class _AbstractBaseExecutor(ABC): """ Private parent class for BaseExecutor and AsyncBaseExecutor @@ -131,6 +62,8 @@ class _AbstractBaseExecutor(ABC): """ + SUPPORTS_MANAGED_EXECUTION = False + def __init__( self, log_stdout: str = "", @@ -289,6 +222,18 @@ def get_cancel_requested(self) -> bool: """ return self._notify_sync(Signals.GET, "cancel_requested") + def get_version_info(self) -> Dict: + """ + Query the database for the task's Python and Covalent version + + Arg: + dispatch_id: Dispatch ID of the lattice + + Returns: + {"python": python_version, "covalent": covalent_version} + """ + return self._notify_sync(Signals.GET, "version_info") + def set_job_handle(self, handle: TypeJSON) -> Any: """ Save the job_id/handle returned by the backend executing the task @@ -301,6 +246,32 @@ def set_job_handle(self, handle: TypeJSON) -> Any: """ return self._notify_sync(Signals.PUT, ("job_handle", json.dumps(handle))) + def _set_job_status_sync(self, status: Status) -> bool: + if self.validate_status(status): + return self._notify_sync(Signals.PUT, ("job_status", str(status))) + else: + return False + + def validate_status(self, status: Status) -> bool: + """Overridable filter""" + return True + + async def set_job_status(self, status: Status) -> bool: + """ + Sets the job state + + For use with send/receive API + + Return(s) + Whether the action succeeded + """ + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + self._set_job_status_sync, + status, + ) + def write_streams_to_file( self, stream_strings: Iterable[str], @@ -436,7 +407,7 @@ def run(self, function: Callable, args: List, kwargs: Dict, task_metadata: Dict) raise NotImplementedError - def cancel(self, task_metadata: Dict, job_handle: Any) -> Literal[False]: + def cancel(self, task_metadata: Dict, job_handle: Any) -> bool: """ Method to cancel the job identified uniquely by the `job_handle` (base class) @@ -450,7 +421,7 @@ def cancel(self, task_metadata: Dict, job_handle: Any) -> Literal[False]: app_log.debug(f"Cancel not implemented for executor {type(self)}") return False - async def _cancel(self, task_metadata: Dict, job_handle: Any) -> Any: + async def _cancel(self, task_metadata: Dict, job_handle: Any) -> bool: """ Cancel the task in a non-blocking manner @@ -471,6 +442,77 @@ def teardown(self, task_metadata: Dict) -> Any: """Placeholder to run any executor specific cleanup/teardown actions""" pass + async def send( + self, + task_specs: List[Dict], + resources: ResourceMap, + task_group_metadata: Dict, + ): + """Submit a list of task references to the compute backend. + + Args: + task_specs: a list of TaskSpecs + resources: a ResourceMap mapping task assets to URIs + task_group_metadata: a dictionary of metadata for the task group. + Current keys are `dispatch_id`, `node_ids`, + and `task_group_id`. + + The return value of `send()` will be passed directly into `poll()`. + """ + # Schemas: + # + # Task spec: + # { + # "function_id": int, + # "args_ids": List[int], + # "kwargs_ids": Dict[str, int], + # "deps_id": str, + # "call_before_id": str, + # "call_after_id": str, + # } + + # resources: + # { + # "functions": Dict[int, str], + # "inputs": Dict[int, str], + # "deps": Dict[str, str] + # } + + # task_group_metadata: + # { + # "dispatch_id": str, + # "node_ids": List[int], + # "task_group_id": int, + # } + + # Assets are will be accessible by the compute backend + # at the provided URIs + + # Covalent will upload all assets before invoking `send()`. + + raise NotImplementedError + + async def poll(self, task_group_metadata: Dict, data: Any): + # To be run as a background task. A callback will be + # registered with the runner to invoke the receive() + + raise NotImplementedError + + async def receive( + self, + task_group_metadata: Dict, + data: Any, + ) -> List[TaskUpdate]: + # Returns (output_uri, stdout_uri, stderr_uri, + # exception_raised) + + # Job should have reached a terminal state by the time this is invoked. + + raise NotImplementedError + + def get_upload_uri(self, task_metadata: Dict, object_key: str): + return "" + class AsyncBaseExecutor(_AbstractBaseExecutor): """Async base executor class to be used for defining any executor @@ -570,6 +612,18 @@ async def get_cancel_requested(self) -> Any: """ return await self._notify_sync(Signals.GET, "cancel_requested") + async def get_version_info(self) -> Dict: + """ + Query the database for dispatch version metadata. + + Arg: + dispatch_id: Dispatch ID of the lattice + + Returns: + {"python": python_version, "covalent": covalent_version} + """ + return await self._notify_sync(Signals.GET, "version_info") + async def set_job_handle(self, handle: TypeJSON) -> Any: """ Save the job handle to database @@ -582,6 +636,24 @@ async def set_job_handle(self, handle: TypeJSON) -> Any: """ return await self._notify_sync(Signals.PUT, ("job_handle", json.dumps(handle))) + def validate_status(self, status: Status) -> bool: + """Overridable filter""" + return True + + async def set_job_status(self, status: Status) -> bool: + """ + Validates and sets the job state + + For use with send/receive API + + Return(s) + Whether the action succeeded + """ + if self.validate_status(status): + return await self._notify_sync(Signals.PUT, ("job_status", str(status))) + else: + return False + async def write_streams_to_file( self, stream_strings: Iterable[str], @@ -699,13 +771,13 @@ async def run(self, function: Callable, args: List, kwargs: Dict, task_metadata: raise NotImplementedError - async def cancel(self, task_metadata: Dict, job_handle: Any) -> Literal[False]: + async def cancel(self, task_metadata: Dict, job_handle: Any) -> bool: """ - Executor specific task cancellation method + Method to cancel the job identified uniquely by the `job_handle` (base class) Arg(s) - task_metadata: Metadata associated with the task to be cancelled - job_handle: Unique ID assigned to the job by the backend + task_metadata: Metadata of the task to be cancelled + job_handle: Unique ID of the job assigned by the backend Return(s) False by default @@ -713,17 +785,132 @@ async def cancel(self, task_metadata: Dict, job_handle: Any) -> Literal[False]: app_log.debug(f"Cancel not implemented for executor {type(self)}") return False - async def _cancel(self, task_metadata: Dict, job_handle: Any) -> Literal[False]: + async def _cancel(self, task_metadata: Dict, job_handle: Any) -> bool: """ - Cancel the task in a non-blocking manner and teardown the infrastructure + Cancel the task in a non-blocking manner Arg(s) - task_metadata: Metadata associated with the task to be cancelled - job_handle: Unique ID assigned to the job by the backend + task_metadata: Metadata of the task to be cancelled + job_handle: Unique ID of the job assigned by the backend Return(s) - Result from cancelling the task + Result of the task cancellation """ - cancel_result = await self.cancel(task_metadata, job_handle) - await self.teardown(task_metadata) - return cancel_result + return await self.cancel(task_metadata, job_handle) + + async def send( + self, + task_specs: List[TaskSpec], + resources: ResourceMap, + task_group_metadata: Dict, + ) -> Any: + """Submit a list of task references to the compute backend. + + Args: + task_specs: a list of TaskSpecs + resources: a ResourceMap mapping task assets to URIs + task_group_metadata: A dictionary of metadata for the task group. + Current keys are `dispatch_id`, `node_ids`, + and `task_group_id`. + + The return value of `send()` will be passed directly into `poll()`. + """ + # Schemas: + # + # Task spec: + # { + # "function_id": int, + # "args_ids": List[int], + # "kwargs_ids": Dict[str, int], + # "deps_id": str, + # "call_before_id": str, + # "call_after_id": str, + # } + + # resources: + # { + # "functions": Dict[int, str], + # "inputs": Dict[int, str], + # "deps": Dict[str, str] + # } + + # task_group_metadata: + # { + # "dispatch_id": str, + # "node_ids": List[int], + # "task_group_id": int, + # } + + # Assets are assumed to be accessible by the compute backend + # at the provided URIs + + # Covalent will upload all assets before invoking send(). + + raise NotImplementedError + + async def poll(self, task_group_metadata: Dict, data: Any) -> Any: + """Block until the job has reached a terminal state. + + Args: + task_group_metadata: A dictionary of metadata for the task group. + Current keys are `dispatch_id`, `node_ids`, + and `task_group_id`. + data: The return value of send(). + + The return value of `poll()` will be passed directly into receive(). + + Raise `NotImplementedError` to indicate that the compute backend + will notify the Covalent server asynchronously of job completion. + + """ + + raise NotImplementedError + + async def receive( + self, + task_group_metadata: Dict, + data: Any, + ) -> List[TaskUpdate]: + """Return a list of task updates. + + Each task must have reached a terminal state by the time this is invoked. + + Args: + task_group_metadata: A dictionary of metadata for the task group. + Current keys are `dispatch_id`, `node_ids`, + and `task_group_id`. + data: The return value of poll() or the request body of `/jobs/update`. + + Returns: + Returns a list of task results, each a TaskUpdate dataclass + of the form + + { + "dispatch_id": dispatch_id, + "node_id": node_id, + "status": status, + "assets": { + "output": { + "remote_uri": output_uri, + }, + "stdout": { + "remote_uri": stdout_uri, + }, + "stderr": { + "remote_uri": stderr_uri, + }, + }, + } + + corresponding to the node ids (task_ids) specified in the + `task_group_metadata`. This might be a subset of the node + ids in the originally submitted task group as jobs may + notify Covalent asynchronously of completed tasks before + the entire task group finishes running. + + """ + + raise NotImplementedError + + def get_upload_uri(self, task_group_metadata: Dict, object_key: str): + return "" diff --git a/covalent/executor/executor_plugins/dask.py b/covalent/executor/executor_plugins/dask.py index 99ef4e766a..3b894c7368 100644 --- a/covalent/executor/executor_plugins/dask.py +++ b/covalent/executor/executor_plugins/dask.py @@ -25,18 +25,26 @@ This is a plugin executor module; it is loaded if found and properly structured. """ +import asyncio +import json import os -from typing import Callable, Dict, List, Literal, Optional +from enum import Enum +from typing import Any, Callable, Dict, List, Literal, Optional from dask.distributed import CancelledError, Client, Future +from pydantic import BaseModel from covalent._shared_files import TaskRuntimeError, logger # Relative imports are not allowed in executor plugins from covalent._shared_files.config import get_config from covalent._shared_files.exceptions import TaskCancelledError +from covalent._shared_files.util_classes import RESULT_STATUS, Status +from covalent._shared_files.utils import format_server_url from covalent.executor.base import AsyncBaseExecutor +from covalent.executor.schemas import ResourceMap, TaskSpec, TaskUpdate from covalent.executor.utils.wrappers import io_wrapper as dask_wrapper +from covalent.executor.utils.wrappers import run_task_from_uris_alt # The plugin class name must be given by the executor_plugin_name attribute: EXECUTOR_PLUGIN_NAME = "DaskExecutor" @@ -58,21 +66,45 @@ "create_unique_workdir": False, } -# Temporary -_address_client_mapper = {} +# See https://github.com/dask/distributed/issues/5667 +_clients = {} + +# See +# https://stackoverflow.com/questions/62164283/why-do-my-dask-futures-get-stuck-in-pending-and-never-finish +_futures = {} + + +MANAGED_EXECUTION = os.environ.get("COVALENT_USE_OLD_DASK") != "1" + +# Dictionary to map Dask clients to their scheduler addresses +_address_client_map = {} +_lock = asyncio.Lock() + + +# Valid terminal statuses +class StatusEnum(str, Enum): + CANCELLED = str(RESULT_STATUS.CANCELLED) + COMPLETED = str(RESULT_STATUS.COMPLETED) + FAILED = str(RESULT_STATUS.FAILED) + READY = "READY" + + +class ReceiveModel(BaseModel): + status: StatusEnum class DaskExecutor(AsyncBaseExecutor): """ - Dask executor class that submits the input function to a running dask cluster. + Dask executor class that submits the input function to a running LOCAL dask cluster. """ + SUPPORTS_MANAGED_EXECUTION = MANAGED_EXECUTION + def __init__( self, scheduler_address: str = "", log_stdout: str = "stdout.log", log_stderr: str = "stderr.log", - conda_env: str = "", cache_dir: str = "", current_env_on_conda_fail: bool = False, workdir: str = "", @@ -105,18 +137,12 @@ def __init__( "No dask scheduler address found in config. Address must be set manually." ) - super().__init__( - log_stdout, - log_stderr, - cache_dir, - conda_env, - current_env_on_conda_fail, - ) - self.workdir = workdir self.create_unique_workdir = create_unique_workdir self.scheduler_address = scheduler_address + super().__init__(log_stdout, log_stderr, cache_dir=cache_dir) + async def run(self, function: Callable, args: List, kwargs: Dict, task_metadata: Dict): """Submit the function and inputs to the dask cluster""" @@ -127,11 +153,11 @@ async def run(self, function: Callable, args: List, kwargs: Dict, task_metadata: dispatch_id = task_metadata["dispatch_id"] node_id = task_metadata["node_id"] - dask_client = _address_client_mapper.get(self.scheduler_address) + dask_client = _address_client_map.get(self.scheduler_address) if not dask_client: dask_client = Client(address=self.scheduler_address, asynchronous=True) - _address_client_mapper[self.scheduler_address] = dask_client + _address_client_map[self.scheduler_address] = dask_client await dask_client if self.create_unique_workdir: @@ -169,14 +195,158 @@ async def cancel(self, task_metadata: Dict, job_handle) -> Literal[True]: Return(s) True by default """ - dask_client = _address_client_mapper.get(self.scheduler_address) + dask_client = _address_client_map.get(self.scheduler_address) if not dask_client: dask_client = Client(address=self.scheduler_address, asynchronous=True) - _address_client_mapper[self.scheduler_address] = dask_client - await dask_client + await asyncio.wait_for(dask_client, timeout=5) fut: Future = Future(key=job_handle, client=dask_client) - await fut.cancel() + + # https://stackoverflow.com/questions/46278692/dask-distributed-how-to-cancel-tasks-submitted-with-fire-and-forget + await dask_client.cancel([fut], asynchronous=True, force=True) app_log.debug(f"Cancelled future with key {job_handle}") return True + + async def send( + self, + task_specs: List[TaskSpec], + resources: ResourceMap, + task_group_metadata: dict, + ): + # Assets are assumed to be accessible by the compute backend + # at the provided URIs + + # The Asset Manager is responsible for uploading all assets + # Returns a job handle (should be JSONable) + + dask_client = Client(address=self.scheduler_address, asynchronous=True) + await asyncio.wait_for(dask_client, timeout=5) + + dispatch_id = task_group_metadata["dispatch_id"] + task_ids = task_group_metadata["node_ids"] + gid = task_group_metadata["task_group_id"] + output_uris = [] + for node_id in task_ids: + result_uri = os.path.join(self.cache_dir, f"result_{dispatch_id}-{node_id}.pkl") + stdout_uri = os.path.join(self.cache_dir, f"stdout_{dispatch_id}-{node_id}.txt") + stderr_uri = os.path.join(self.cache_dir, f"stderr_{dispatch_id}-{node_id}.txt") + output_uris.append((result_uri, stdout_uri, stderr_uri)) + + server_url = format_server_url() + + key = f"dask_job_{dispatch_id}:{gid}" + + await self.set_job_handle(key) + + future = dask_client.submit( + run_task_from_uris_alt, + list(map(lambda t: t.dict(), task_specs)), + resources.dict(), + output_uris, + self.cache_dir, + task_group_metadata, + server_url, + key=key, + ) + + _clients[key] = dask_client + _futures[key] = future + + # def handle_cancelled(fut): + # import requests + + # app_log.debug(f"In done callback for {dispatch_id}:{gid}, future {fut}") + # if fut.cancelled(): + # for task_id in task_ids: + # url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/{task_id}/job" + # requests.put(url, json={"status": "CANCELLED"}) + + # future.add_done_callback(handle_cancelled) + + # fire_and_forget(future) + + # app_log.debug(f"Fire and forgetting task group {dispatch_id}:{gid}") + + return future.key + + async def poll(self, task_group_metadata: Dict, poll_data: Any): + fut = _futures.pop(poll_data) + app_log.debug(f"Future {fut}") + try: + await fut + except CancelledError: + raise TaskCancelledError() + + _clients.pop(poll_data) + + return {"status": StatusEnum.READY.value} + + # raise NotImplementedError + + async def receive(self, task_group_metadata: Dict, data: Any) -> List[TaskUpdate]: + # Job should have reached a terminal state by the time this is invoked. + dispatch_id = task_group_metadata["dispatch_id"] + task_ids = task_group_metadata["node_ids"] + + task_results = [] + + if not data: + terminal_status = RESULT_STATUS.CANCELLED + else: + received = ReceiveModel.parse_obj(data) + terminal_status = Status(received.status.value) + + for task_id in task_ids: + # TODO: Handle the case where the job was cancelled before the task started running + app_log.debug(f"Receive called for task {dispatch_id}:{task_id} with data {data}") + + if terminal_status == RESULT_STATUS.CANCELLED: + output_uri = "" + stdout_uri = "" + stderr_uri = "" + + else: + # terminal_status = data["status"] if data else RESULT_STATUS.CANCELLED + result_path = os.path.join(self.cache_dir, f"result-{dispatch_id}:{task_id}.json") + with open(result_path, "r") as f: + result_summary = json.load(f) + node_id = result_summary["node_id"] + output_uri = result_summary["output_uri"] + stdout_uri = result_summary["stdout_uri"] + stderr_uri = result_summary["stderr_uri"] + exception_raised = result_summary["exception_occurred"] + + terminal_status = ( + RESULT_STATUS.FAILED if exception_raised else RESULT_STATUS.COMPLETED + ) + + task_result = { + "dispatch_id": dispatch_id, + "node_id": task_id, + "status": terminal_status, + "assets": { + "output": { + "remote_uri": output_uri, + }, + "stdout": { + "remote_uri": stdout_uri, + }, + "stderr": { + "remote_uri": stderr_uri, + }, + }, + } + + task_results.append(TaskUpdate(**task_result)) + + app_log.debug(f"Returning results for tasks {dispatch_id}:{task_ids}") + return task_results + + def get_upload_uri(self, task_group_metadata: Dict, object_key: str): + dispatch_id = task_group_metadata["dispatch_id"] + task_group_id = task_group_metadata["task_group_id"] + + filename = f"asset_{dispatch_id}-{task_group_id}_{object_key}.pkl" + return os.path.join("file://", self.cache_dir, filename) + # return "" diff --git a/covalent/executor/executor_plugins/local.py b/covalent/executor/executor_plugins/local.py index 0a9ff3897a..1673bee637 100644 --- a/covalent/executor/executor_plugins/local.py +++ b/covalent/executor/executor_plugins/local.py @@ -24,19 +24,26 @@ This is a plugin executor module; it is loaded if found and properly structured. """ - +import asyncio import os from concurrent.futures import ProcessPoolExecutor +from enum import Enum from typing import Any, Callable, Dict, List, Optional -# Relative imports are not allowed in executor plugins +from pydantic import BaseModel + from covalent._shared_files import TaskCancelledError, TaskRuntimeError, logger from covalent._shared_files.config import get_config +from covalent._shared_files.util_classes import RESULT_STATUS, Status +from covalent._shared_files.utils import format_server_url from covalent.executor import BaseExecutor +# Relative imports are not allowed in executor plugins +from covalent.executor.schemas import ResourceMap, TaskSpec, TaskUpdate + # Store the wrapper function in an external module to avoid module # import errors during pickling -from covalent.executor.utils.wrappers import io_wrapper +from covalent.executor.utils.wrappers import io_wrapper, run_task_from_uris # The plugin class name must be given by the executor_plugin_name attribute: EXECUTOR_PLUGIN_NAME = "LocalExecutor" @@ -61,11 +68,27 @@ proc_pool = ProcessPoolExecutor() +# Valid terminal statuses +class StatusEnum(str, Enum): + CANCELLED = str(RESULT_STATUS.CANCELLED) + COMPLETED = str(RESULT_STATUS.COMPLETED) + FAILED = str(RESULT_STATUS.FAILED) + + +class ReceiveModel(BaseModel): + status: StatusEnum + + +MANAGED_EXECUTION = True + + class LocalExecutor(BaseExecutor): """ Local executor class that directly invokes the input function. """ + SUPPORTS_MANAGED_EXECUTION = MANAGED_EXECUTION + def __init__( self, workdir: str = "", create_unique_workdir: Optional[bool] = None, *args, **kwargs ) -> None: @@ -133,3 +156,129 @@ def run(self, function: Callable, args: List, kwargs: Dict, task_metadata: Dict) raise TaskRuntimeError(tb) return output + + def _send( + self, + task_specs: List[TaskSpec], + resources: ResourceMap, + task_group_metadata: dict, + ): + dispatch_id = task_group_metadata["dispatch_id"] + task_ids = task_group_metadata["node_ids"] + gid = task_group_metadata["task_group_id"] + output_uris = [] + for node_id in task_ids: + result_uri = os.path.join(self.cache_dir, f"result_{dispatch_id}-{node_id}.pkl") + stdout_uri = os.path.join(self.cache_dir, f"stdout_{dispatch_id}-{node_id}.txt") + stderr_uri = os.path.join(self.cache_dir, f"stderr_{dispatch_id}-{node_id}.txt") + output_uris.append((result_uri, stdout_uri, stderr_uri)) + # future = dask_client.submit(lambda x: x**3, 3) + + server_url = format_server_url() + + app_log.debug(f"Running task group {dispatch_id}:{task_ids}") + future = proc_pool.submit( + run_task_from_uris, + list(map(lambda t: t.dict(), task_specs)), + resources.dict(), + output_uris, + self.cache_dir, + task_group_metadata, + server_url, + ) + + def handle_cancelled(fut): + import requests + + app_log.debug(f"In done callback for {dispatch_id}:{gid}, future {fut}") + if fut.cancelled(): + for task_id in task_ids: + url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/{task_id}/job" + requests.put(url, json={"status": "CANCELLED"}) + + future.add_done_callback(handle_cancelled) + + return 42 + + def _receive(self, task_group_metadata: Dict, data: Any) -> List[TaskUpdate]: + # Returns (output_uri, stdout_uri, stderr_uri, + # exception_raised) + + # Job should have reached a terminal state by the time this is invoked. + dispatch_id = task_group_metadata["dispatch_id"] + task_ids = task_group_metadata["node_ids"] + + task_results = [] + + # if len(task_ids) > 1: + # raise RuntimeError("Task packing is not yet supported") + + for task_id in task_ids: + # Handle the case where the job was cancelled before the task started running + app_log.debug(f"Receive called for task {dispatch_id}:{task_id} with data {data}") + + if not data: + terminal_status = RESULT_STATUS.CANCELLED + else: + received = ReceiveModel.parse_obj(data) + terminal_status = Status(received.status.value) + + task_result = { + "dispatch_id": dispatch_id, + "node_id": task_id, + "status": terminal_status, + "assets": { + "output": { + "remote_uri": "", + }, + "stdout": { + "remote_uri": "", + }, + "stderr": { + "remote_uri": "", + }, + }, + } + + task_results.append(TaskUpdate(**task_result)) + + app_log.debug(f"Returning results for tasks {dispatch_id}:{task_ids}") + return task_results + + async def send( + self, + task_specs: List[TaskSpec], + resources: ResourceMap, + task_group_metadata: dict, + ): + # Assets are assumed to be accessible by the compute backend + # at the provided URIs + + # The Asset Manager is responsible for uploading all assets + # Returns a job handle (should be JSONable) + + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + self._send, + task_specs, + resources, + task_group_metadata, + ) + + async def receive(self, task_group_metadata: Dict, data: Any) -> List[TaskUpdate]: + # Returns (output_uri, stdout_uri, stderr_uri, + # exception_raised) + + # Job should have reached a terminal state by the time this is invoked. + loop = asyncio.get_running_loop() + + return await loop.run_in_executor( + None, + self._receive, + task_group_metadata, + data, + ) + + def get_upload_uri(self, task_group_metadata: Dict, object_key: str): + return "" diff --git a/covalent/executor/schemas.py b/covalent/executor/schemas.py new file mode 100644 index 0000000000..8509404ad6 --- /dev/null +++ b/covalent/executor/schemas.py @@ -0,0 +1,122 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Types defining the runner-executor interface +""" + +from typing import Dict, List + +from pydantic import BaseModel, validator + +from covalent._shared_files.schemas.asset import AssetUpdate +from covalent._shared_files.util_classes import RESULT_STATUS, Status + + +class TaskUpdate(BaseModel): + """Represents a task status update. + + Attributes: + dispatch_id: The id of the dispatch. + node_id: The id of the task. + status: A Status dataclass representing the task's terminal status. + assets: A map from asset keys to AssetUpdate objects + """ + + dispatch_id: str + node_id: int + status: Status + assets: Dict[str, AssetUpdate] + + @validator("status") + def validate_status(cls, v): + if RESULT_STATUS.is_terminal(v): + return v + else: + raise ValueError(f"Illegal status update {v}") + + +class TaskSpec(BaseModel): + """An abstract description of a runnable task. + + Attributes: + function_id: The `node_id` of the function. + args_ids: The `node_id`s of the function's args + kwargs_ids: The `node_id`s of the function's kwargs {key: node_id} + deps_id: An opaque string representing the task's deps. + call_before_id: An opaque string representing the task's call_before. + call_after_id: An opaque string representing the task's call_before. + + The attribute values can be used in conjunction with a + `ResourceMap` to locate the actual resources in the compute + environment. + """ + + function_id: int + args_ids: List[int] + kwargs_ids: Dict[str, int] + deps_id: str + call_before_id: str + call_after_id: str + + +class ResourceMap(BaseModel): + + """Map resource identifiers to URIs. + + The resources may be loaded in the compute environment from these + URIs. + + Resource identifiers are the attribute values of TaskSpecs. For + instance, if ts is a `TaskSpec` and rm is the corresponding + `ResourceMap`, then + - the serialized function has URI `rm.functions[ts.function_id]` + - the serialized args have URIs `rm.inputs[ts.args_ids[i]]` + - the call_before has URI `rm.deps[ts.call_before_id]` + + Attributes: + functions: A map from node id to the corresponding URI. + inputs: A map from node id to the corresponding URI + deps: A map from deps resource ids to their corresponding URIs. + + """ + + # Map node_id to URI + functions: Dict[int, str] + + # Map node_id to URI + inputs: Dict[int, str] + + # Includes deps, call_before, call_after + deps: Dict[str, str] + + +class TaskGroup(BaseModel): + """Description of a group of runnable graph nodes. + + Attributes: + dispatch_id: The `dispatch_id`. + task_ids: The graph nodes comprising the task group. + task_group_id: The task group identifier. + """ + + dispatch_id: str + task_ids: List[int] + task_group_id: int diff --git a/covalent/executor/utils/__init__.py b/covalent/executor/utils/__init__.py index 45949792ea..459ef87e7b 100644 --- a/covalent/executor/utils/__init__.py +++ b/covalent/executor/utils/__init__.py @@ -18,4 +18,4 @@ # # Relief from the License may be granted by purchasing a commercial license. -from .wrappers import Signals +from .enums import Signals diff --git a/covalent/executor/utils/enums.py b/covalent/executor/utils/enums.py new file mode 100644 index 0000000000..40229c5085 --- /dev/null +++ b/covalent/executor/utils/enums.py @@ -0,0 +1,35 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Helper functions for the local executor +""" + +from enum import Enum + + +class Signals(Enum): + """ + Signals to enable communication between the executors and Covalent dispatcher + """ + + GET = 0 + PUT = 1 + EXIT = 2 diff --git a/covalent/executor/utils/serialize.py b/covalent/executor/utils/serialize.py new file mode 100644 index 0000000000..c5be088088 --- /dev/null +++ b/covalent/executor/utils/serialize.py @@ -0,0 +1,37 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Functions for serializing and deserializing assets +""" + +from typing import Any + +from ..._serialize.common import deserialize_asset, serialize_asset +from ..._serialize.electron import ASSET_TYPES + + +# Convenience functions for executor plugins +def serialize_node_asset(data: Any, key: str) -> bytes: + return serialize_asset(data, ASSET_TYPES[key]) + + +def deserialize_node_asset(data: bytes, key: str) -> Any: + return deserialize_asset(data, ASSET_TYPES[key]) diff --git a/covalent/executor/utils/wrappers.py b/covalent/executor/utils/wrappers.py index f38b8cc513..04f0766caf 100644 --- a/covalent/executor/utils/wrappers.py +++ b/covalent/executor/utils/wrappers.py @@ -23,22 +23,79 @@ """ import io +import json import os +import sys import traceback from contextlib import redirect_stderr, redirect_stdout -from enum import Enum from pathlib import Path from typing import Any, Callable, Dict, List, Tuple +import requests + +from covalent._workflow.depsbash import DepsBash +from covalent._workflow.depscall import RESERVED_RETVAL_KEY__FILES, DepsCall +from covalent._workflow.depspip import DepsPip +from covalent._workflow.transport import TransportableObject +from covalent.executor.utils.serialize import deserialize_node_asset, serialize_node_asset + + +def wrapper_fn( + function: TransportableObject, + call_before: List[Tuple[TransportableObject, TransportableObject, TransportableObject]], + call_after: List[Tuple[TransportableObject, TransportableObject, TransportableObject]], + *args, + **kwargs, +): + """Wrapper for serialized callable. + + Execute preparatory shell commands before deserializing and + running the callable. This is the actual function to be sent to + the various executors. -class Signals(Enum): - """ - Signals to enable communication between the executors and Covalent dispatcher """ - GET = 0 - PUT = 1 - EXIT = 2 + cb_retvals = {} + for tup in call_before: + serialized_fn, serialized_args, serialized_kwargs, retval_key = tup + cb_fn = serialized_fn.get_deserialized() + cb_args = serialized_args.get_deserialized() + cb_kwargs = serialized_kwargs.get_deserialized() + retval = cb_fn(*cb_args, **cb_kwargs) + + # we always store cb_kwargs dict values as arrays to factor in non-unique values + if retval_key and retval_key in cb_retvals: + cb_retvals[retval_key].append(retval) + elif retval_key: + cb_retvals[retval_key] = [retval] + + # if cb_retvals key only contains one item this means it is a unique (non-repeated) retval key + # so we only return the first element however if it is a 'files' kwarg we always return as a list + cb_retvals = { + key: value[0] if len(value) == 1 and key != RESERVED_RETVAL_KEY__FILES else value + for key, value in cb_retvals.items() + } + + fn = function.get_deserialized() + + new_args = [arg.get_deserialized() for arg in args] + + new_kwargs = {k: v.get_deserialized() for k, v in kwargs.items()} + + # Inject return values into kwargs + for key, val in cb_retvals.items(): + new_kwargs[key] = val + + output = fn(*new_args, **new_kwargs) + + for tup in call_after: + serialized_fn, serialized_args, serialized_kwargs, retval_key = tup + ca_fn = serialized_fn.get_deserialized() + ca_args = serialized_args.get_deserialized() + ca_kwargs = serialized_kwargs.get_deserialized() + ca_fn(*ca_args, **ca_kwargs) + + return TransportableObject(output) def io_wrapper( @@ -62,3 +119,362 @@ def io_wrapper( finally: os.chdir(current_dir) return output, stdout.getvalue(), stderr.getvalue(), tb + + +# Copied from runner.py +def _gather_deps(deps, call_before_objs_json, call_after_objs_json) -> Tuple[List, List]: + """Assemble deps for a node into the final call_before and call_after""" + + call_before = [] + call_after = [] + + # Rehydrate deps from JSON + if "bash" in deps: + dep = DepsBash() + dep.from_dict(deps["bash"]) + call_before.append(dep.apply()) + + if "pip" in deps: + dep = DepsPip() + dep.from_dict(deps["pip"]) + call_before.append(dep.apply()) + + for dep_json in call_before_objs_json: + dep = DepsCall() + dep.from_dict(dep_json) + call_before.append(dep.apply()) + + for dep_json in call_after_objs_json: + dep = DepsCall() + dep.from_dict(dep_json) + call_after.append(dep.apply()) + + return call_before, call_after + + +# Basic wrapper for executing a topologically sorted sequence of +# tasks. For the `task_specs` and `resources` schema see the comments +# for `AsyncBaseExecutor.send()`. + + +# URIs are just file paths +def run_task_from_uris( + task_specs: List[Dict], + resources: dict, + output_uris: List[Tuple[str, str, str]], + results_dir: str, + task_group_metadata: dict, + server_url: str, +): + prefix = "file://" + prefix_len = len(prefix) + + outputs = {} + results = [] + dispatch_id = task_group_metadata["dispatch_id"] + task_ids = task_group_metadata["node_ids"] + gid = task_group_metadata["task_group_id"] + + os.environ["COVALENT_DISPATCH_ID"] = dispatch_id + os.environ["COVALENT_DISPATCHER_URL"] = server_url + + for i, task in enumerate(task_specs): + result_uri, stdout_uri, stderr_uri = output_uris[i] + + with open(stdout_uri, "w") as stdout, open(stderr_uri, "w") as stderr: + with redirect_stdout(stdout), redirect_stderr(stderr): + try: + task_id = task["function_id"] + args_ids = task["args_ids"] + kwargs_ids = task["kwargs_ids"] + + function_uri = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/{task_id}/assets/function" + + # Download function + resp = requests.get(function_uri, stream=True) + resp.raise_for_status() + serialized_fn = deserialize_node_asset(resp.content, "function") + + ser_args = [] + ser_kwargs = {} + + # Download args and kwargs + for node_id in args_ids: + url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/{node_id}/assets/output" + resp = requests.get(url, stream=True) + resp.raise_for_status() + ser_args.append(deserialize_node_asset(resp.content, "output")) + + for k, node_id in kwargs_ids.items(): + url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/{node_id}/assets/output" + resp = requests.get(url, stream=True) + resp.raise_for_status() + ser_kwargs[k] = deserialize_node_asset(resp.content, "output") + + # Download deps + deps_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/{task_id}/assets/deps" + resp = requests.get(deps_url, stream=True) + resp.raise_for_status() + deps_json = deserialize_node_asset(resp.content, "deps") + + cb_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/{task_id}/assets/call_before" + resp = requests.get(cb_url, stream=True) + resp.raise_for_status() + call_before_json = deserialize_node_asset(resp.content, "call_before") + + ca_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/{task_id}/assets/call_after" + resp = requests.get(ca_url, stream=True) + resp.raise_for_status() + call_after_json = deserialize_node_asset(resp.content, "call_after") + + # Assemble and run the task + call_before, call_after = _gather_deps( + deps_json, call_before_json, call_after_json + ) + exception_occurred = False + + transportable_output = wrapper_fn( + serialized_fn, call_before, call_after, *ser_args, **ser_kwargs + ) + ser_output = serialize_node_asset(transportable_output, "output") + with open(result_uri, "wb") as f: + f.write(ser_output) + + outputs[task_id] = result_uri + + result_summary = { + "node_id": task_id, + "output_uri": result_uri, + "stdout_uri": stdout_uri, + "stderr_uri": stderr_uri, + "exception_occurred": exception_occurred, + } + + except Exception as ex: + exception_occurred = True + tb = "".join(traceback.TracebackException.from_exception(ex).format()) + print(tb, file=sys.stderr) + result_uri = None + result_summary = { + "node_id": task_id, + "output_uri": result_uri, + "stdout_uri": stdout_uri, + "stderr_uri": stderr_uri, + "exception_occurred": exception_occurred, + } + + break + + finally: + # POST task artifacts + if result_uri: + upload_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/{task_id}/assets/output" + with open(result_uri, "rb") as f: + files = {"asset_file": f} + requests.post(upload_url, files=files) + + sys.stdout.flush() + if stdout_uri: + upload_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/{task_id}/assets/stdout" + with open(stdout_uri, "rb") as f: + files = {"asset_file": f} + requests.post(upload_url, files=files) + + sys.stderr.flush() + if stderr_uri: + upload_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/{task_id}/assets/stderr" + with open(stderr_uri, "rb") as f: + files = {"asset_file": f} + requests.post(upload_url, files=files) + + result_path = os.path.join(results_dir, f"result-{dispatch_id}:{task_id}.json") + + with open(result_path, "w") as f: + json.dump(result_summary, f) + + results.append(result_summary) + + # Notify Covalent that the task has terminated + terminal_status = "FAILED" if exception_occurred else "COMPLETED" + url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/{task_id}/job" + data = {"status": terminal_status} + requests.put(url, json=data) + + # Deal with any tasks that did not run + n = len(results) + if n < len(task_ids): + for i in range(n, len(task_ids)): + result_summary = { + "node_id": task_ids[i], + "output_uri": "", + "stdout_uri": "", + "stderr_uri": "", + "exception_occurred": True, + } + + results.append(result_summary) + + result_path = os.path.join(results_dir, f"result-{dispatch_id}:{task_id}.json") + + with open(result_path, "w") as f: + json.dump(result_summary, f) + + url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/{task_id}/job" + requests.put(url) + + +# URIs are just file paths +def run_task_from_uris_alt( + task_specs: List[Dict], + resources: dict, + output_uris: List[Tuple[str, str, str]], + results_dir: str, + task_group_metadata: dict, + server_url: str, +): + """Alternate form of run_task_from_uris for sync executors. + + This is appropriate for backends that cannot reach the Covalent + server. Covalent will push input assets to the executor's + persistent storage before invoking `Executor.send()` and pull output + artifacts after `Executor.receive()`. + + """ + + prefix = "file://" + prefix_len = len(prefix) + + outputs = {} + results = [] + dispatch_id = task_group_metadata["dispatch_id"] + task_ids = task_group_metadata["node_ids"] + gid = task_group_metadata["task_group_id"] + + os.environ["COVALENT_DISPATCH_ID"] = dispatch_id + os.environ["COVALENT_DISPATCHER_URL"] = server_url + + for i, task in enumerate(task_specs): + result_uri, stdout_uri, stderr_uri = output_uris[i] + + with open(stdout_uri, "w") as stdout, open(stderr_uri, "w") as stderr: + with redirect_stdout(stdout), redirect_stderr(stderr): + try: + task_id = task["function_id"] + args_ids = task["args_ids"] + kwargs_ids = task["kwargs_ids"] + + # Load function + function_uri = resources["functions"][task_id] + if function_uri.startswith(prefix): + function_uri = function_uri[prefix_len:] + + with open(function_uri, "rb") as f: + serialized_fn = deserialize_node_asset(f.read(), "function") + + # Load args and kwargs + ser_args = [] + ser_kwargs = {} + + args_uris = [resources["inputs"][index] for index in args_ids] + for uri in args_uris: + if uri.startswith(prefix): + uri = uri[prefix_len:] + with open(uri, "rb") as f: + ser_args.append(deserialize_node_asset(f.read(), "output")) + + kwargs_uris = {k: resources["inputs"][v] for k, v in kwargs_ids.items()} + for key, uri in kwargs_uris.items(): + if uri.startswith(prefix): + uri = uri[prefix_len:] + with open(uri, "rb") as f: + ser_kwargs[key] = deserialize_node_asset(f.read(), "output") + + # Load deps + deps_id = task["deps_id"] + deps_uri = resources["deps"][deps_id] + if deps_uri.startswith(prefix): + deps_uri = deps_uri[prefix_len:] + with open(deps_uri, "rb") as f: + deps_json = deserialize_node_asset(f.read(), "deps") + + call_before_id = task["call_before_id"] + call_before_uri = resources["deps"][call_before_id] + if call_before_uri.startswith(prefix): + call_before_uri = call_before_uri[prefix_len:] + with open(call_before_uri, "rb") as f: + call_before_json = deserialize_node_asset(f.read(), "call_before") + + call_after_id = task["call_after_id"] + call_after_uri = resources["deps"][call_after_id] + if call_after_uri.startswith(prefix): + call_after_uri = call_after_uri[prefix_len:] + with open(call_after_uri, "rb") as f: + call_after_json = deserialize_node_asset(f.read(), "call_after") + + # Assemble and invoke the task + call_before, call_after = _gather_deps( + deps_json, call_before_json, call_after_json + ) + exception_occurred = False + + transportable_output = wrapper_fn( + serialized_fn, call_before, call_after, *ser_args, **ser_kwargs + ) + ser_output = serialize_node_asset(transportable_output, "output") + + # Save output + with open(result_uri, "wb") as f: + f.write(ser_output) + + resources["inputs"][task_id] = result_uri + + result_summary = { + "node_id": task_id, + "output_uri": result_uri, + "stdout_uri": stdout_uri, + "stderr_uri": stderr_uri, + "exception_occurred": exception_occurred, + } + + except Exception as ex: + exception_occurred = True + tb = "".join(traceback.TracebackException.from_exception(ex).format()) + print(tb, file=sys.stderr) + result_uri = None + result_summary = { + "node_id": task_id, + "output_uri": result_uri, + "stdout_uri": stdout_uri, + "stderr_uri": stderr_uri, + "exception_occurred": exception_occurred, + } + + break + + finally: + results.append(result_summary) + result_path = os.path.join(results_dir, f"result-{dispatch_id}:{task_id}.json") + + # Write the summary file containing the URIs for + # the serialized result, stdout, and stderr + with open(result_path, "w") as f: + json.dump(result_summary, f) + + # Deal with any tasks that did not run + n = len(results) + if n < len(task_ids): + for i in range(n, len(task_ids)): + result_summary = { + "node_id": task_ids[i], + "output_uri": "", + "stdout_uri": "", + "stderr_uri": "", + "exception_occurred": True, + } + + results.append(result_summary) + + result_path = os.path.join(results_dir, f"result-{dispatch_id}:{task_id}.json") + + with open(result_path, "w") as f: + json.dump(result_summary, f) diff --git a/covalent_dispatcher/_core/runner_ng.py b/covalent_dispatcher/_core/runner_ng.py new file mode 100644 index 0000000000..c9a8357591 --- /dev/null +++ b/covalent_dispatcher/_core/runner_ng.py @@ -0,0 +1,461 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Defines the core functionality of the runner +""" + +import asyncio +import traceback +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timezone +from typing import Any, Dict + +from covalent._shared_files import logger +from covalent._shared_files.config import get_config +from covalent._shared_files.exceptions import TaskCancelledError +from covalent._shared_files.util_classes import RESULT_STATUS +from covalent.executor.base import AsyncBaseExecutor +from covalent.executor.schemas import ResourceMap, TaskSpec +from covalent.executor.utils import Signals + +from . import data_manager as datamgr +from . import runner as runner_legacy +from .data_modules import asset_manager as am +from .runner_modules import executor_proxy, jobs +from .runner_modules.cancel import cancel_tasks # nopycln: import +from .runner_modules.utils import get_executor + +app_log = logger.app_log +log_stack_info = logger.log_stack_info +debug_mode = get_config("sdk.log_level") == "debug" + +# Dedicated thread pool for invoking non-async Executor.cancel() +_cancel_threadpool = ThreadPoolExecutor() + +# Asyncio Queue +_job_events = None + +_job_event_listener = None + +_futures = set() + +# message format: +# Ready for retrieve result +# {"task_group_metadata": dict, "event": "READY"} +# +# Unable to retrieve result (e.g. credentials expired) +# +# {"task_group_metadata": dict, "event": "FAILED", "detail": str} + + +# Domain: runner +async def _submit_abstract_task_group( + dispatch_id: str, + task_group_id: int, + task_seq: list, + known_nodes: list, + executor: AsyncBaseExecutor, +) -> None: + # Task sequence of the form {"function_id": task_id, "args_ids": + # [node_ids], "kwargs_ids": {key: node_id}} + task_ids = [task["function_id"] for task in task_seq] + task_specs = [] + task_group_metadata = { + "dispatch_id": dispatch_id, + "task_group_id": task_group_id, + "node_ids": task_ids, + } + + try: + if not type(executor).SUPPORTS_MANAGED_EXECUTION: + raise NotImplementedError("Executor does not support managed execution") + + resources = {"functions": {}, "inputs": {}, "deps": {}} + + # Get upload URIs + for task_spec in task_seq: + task_id = task_spec["function_id"] + + function_uri = executor.get_upload_uri(task_group_metadata, f"function-{task_id}") + deps_uri = executor.get_upload_uri(task_group_metadata, f"deps-{task_id}") + call_before_uri = executor.get_upload_uri( + task_group_metadata, f"call_before-{task_id}" + ) + call_after_uri = executor.get_upload_uri(task_group_metadata, f"call_after-{task_id}") + + await am.upload_asset_for_nodes(dispatch_id, "function", {task_id: function_uri}) + await am.upload_asset_for_nodes(dispatch_id, "deps", {task_id: deps_uri}) + await am.upload_asset_for_nodes(dispatch_id, "call_before", {task_id: call_before_uri}) + await am.upload_asset_for_nodes(dispatch_id, "call_after", {task_id: call_after_uri}) + + deps_id = f"deps-{task_id}" + call_before_id = f"call_before-{task_id}" + call_after_id = f"call_after-{task_id}" + task_spec["deps_id"] = deps_id + task_spec["call_before_id"] = call_before_id + task_spec["call_after_id"] = call_after_id + + resources["functions"][task_id] = function_uri + resources["deps"][deps_id] = deps_uri + resources["deps"][call_before_id] = call_before_uri + resources["deps"][call_after_id] = call_after_uri + + task_specs.append(TaskSpec(**task_spec)) + + node_upload_uris = { + node_id: executor.get_upload_uri(task_group_metadata, f"node_{node_id}") + for node_id in known_nodes + } + resources["inputs"] = node_upload_uris + + app_log.debug( + f"Uploading known nodes {known_nodes} for task group {dispatch_id}:{task_group_id}" + ) + await am.upload_asset_for_nodes(dispatch_id, "output", node_upload_uris) + + ts = datetime.now(timezone.utc) + node_results = [ + datamgr.generate_node_result( + node_id=task_id, + start_time=ts, + status=RESULT_STATUS.RUNNING, + ) + for task_id in task_ids + ] + + # Use one proxy for the task group; handles the following requests: + # - check if the job has a pending cancel request + # - set the job handle + # - set job status + + # Watch the task group + fut = asyncio.create_task(executor_proxy.watch(dispatch_id, task_ids[0], executor)) + _futures.add(fut) + fut.add_done_callback(_futures.discard) + + send_retval = await executor.send( + task_specs, + ResourceMap(**resources), + task_group_metadata, + ) + + app_log.debug(f"Submitted task group {dispatch_id}:{task_group_id}") + + except TaskCancelledError: + app_log.debug(f"Task group {dispatch_id}:{task_group_id} cancelled") + + send_retval = None + + node_results = [ + datamgr.generate_node_result( + node_id=task_id, + end_time=datetime.now(timezone.utc), + status=RESULT_STATUS.CANCELLED, + ) + for task_id in task_ids + ] + + except Exception as ex: + tb = "".join(traceback.TracebackException.from_exception(ex).format()) + app_log.debug(f"Exception occurred when running task group {task_group_id}:") + app_log.debug(tb) + error_msg = tb if debug_mode else str(ex) + ts = datetime.now(timezone.utc) + + send_retval = None + + node_results = [ + datamgr.generate_node_result( + node_id=task_id, + end_time=datetime.now(timezone.utc), + status=RESULT_STATUS.FAILED, + error=error_msg, + ) + for task_id in task_ids + ] + + return node_results, send_retval + + +async def _get_task_result(task_group_metadata: Dict, data: Any): + """Retrieve task results from executor. + + Parameters: + task_group_metadata: metadata about the task group + data: task execution information (such as status) + + Both `task_group_metadata` and `data` will be passed directly to + Executor.receive(). + + """ + + dispatch_id = task_group_metadata["dispatch_id"] + task_ids = task_group_metadata["node_ids"] + gid = task_group_metadata["task_group_id"] + app_log.debug(f"Pulling job artifacts for task group {dispatch_id}:{gid}") + try: + executor_attrs = await datamgr.electron.get( + dispatch_id, gid, ["executor", "executor_data"] + ) + executor_name = executor_attrs["executor"] + executor_data = executor_attrs["executor_data"] + + executor = get_executor( + node_id=gid, + selected_executor=[executor_name, executor_data], + loop=asyncio.get_running_loop(), + pool=None, + ) + + # Expects a list of TaskUpdates + task_group_results = await executor.receive(task_group_metadata, data) + + node_results = [] + for task_result in task_group_results: + task_id = task_result.node_id + status = task_result.status + await am.download_assets_for_node(dispatch_id, task_id, task_result.assets) + + node_result = datamgr.generate_node_result( + node_id=task_id, end_time=datetime.now(timezone.utc), status=status + ) + node_results.append(node_result) + + except Exception as ex: + tb = "".join(traceback.TracebackException.from_exception(ex).format()) + app_log.debug(f"Exception occurred when receiving task group {gid}:") + app_log.debug(tb) + error_msg = tb if debug_mode else str(ex) + ts = datetime.now(timezone.utc) + + node_results = [ + datamgr.generate_node_result( + node_id=node_id, + end_time=ts, + status=RESULT_STATUS.FAILED, + error=error_msg, + ) + for node_id in task_ids + ] + + for node_result in node_results: + await datamgr.update_node_result(dispatch_id, node_result) + + +async def run_abstract_task_group( + dispatch_id: str, + task_group_id: int, + task_seq: list, + known_nodes: list, + selected_executor: Any, +) -> None: + executor = None + + try: + app_log.debug(f"Attempting to instantiate executor {selected_executor}") + task_ids = [task["function_id"] for task in task_seq] + app_log.debug(f"Running task group {dispatch_id}:{task_group_id}") + executor = get_executor( + node_id=task_group_id, + selected_executor=selected_executor, + loop=asyncio.get_running_loop(), + pool=None, + ) + + # Check if the job should be cancelled + if await jobs.get_cancel_requested(dispatch_id, task_ids[0]): + await jobs.put_job_status(dispatch_id, task_ids[0], RESULT_STATUS.CANCELLED) + + for task_id in task_ids: + task_metadata = {"dispatch_id": dispatch_id, "node_id": task_id} + app_log.debug(f"Refusing to execute cancelled task {dispatch_id}:{task_id}") + await mark_task_ready(task_metadata, None) + + return + + # Legacy runner doesn't yet support task packing + if not type(executor).SUPPORTS_MANAGED_EXECUTION: + if len(task_seq) == 1: + task_spec = task_seq[0] + node_id = task_spec["function_id"] + name = task_spec["name"] + abstract_inputs = { + "args": task_spec["args_ids"], + "kwargs": task_spec["kwargs_ids"], + } + app_log.debug(f"Reverting to legacy runner for task {task_group_id}") + coro = runner_legacy.run_abstract_task( + dispatch_id, + node_id, + name, + abstract_inputs, + selected_executor, + ) + fut = asyncio.create_task(coro) + _futures.add(fut) + fut.add_done_callback(_futures.discard) + return + + else: + raise RuntimeError("Task packing not supported by executor plugin") + node_results, send_retval = await _submit_abstract_task_group( + dispatch_id, + task_group_id, + task_seq, + known_nodes, + executor, + ) + + except Exception as ex: + tb = "".join(traceback.TracebackException.from_exception(ex).format()) + app_log.debug("Exception when trying to instantiate executor:") + app_log.debug(tb) + error_msg = tb if debug_mode else str(ex) + ts = datetime.now(timezone.utc) + + send_retval = None + node_results = [ + datamgr.generate_node_result( + node_id=node_id, + start_time=ts, + end_time=ts, + status=RESULT_STATUS.FAILED, + error=error_msg, + ) + for node_id in task_ids + ] + + for node_result in node_results: + await datamgr.update_node_result(dispatch_id, node_result) + + if node_results[0]["status"] == RESULT_STATUS.RUNNING: + task_group_metadata = { + "dispatch_id": dispatch_id, + "node_ids": task_ids, + "task_group_id": task_group_id, + } + await _poll_task_status(task_group_metadata, executor, send_retval) + + # Terminate proxy + if executor: + executor._notify(Signals.EXIT) + app_log.debug(f"Stopping proxy for task group {dispatch_id}:{task_group_id}") + + +async def _listen_for_job_events(): + app_log.debug("Starting event listener") + while True: + msg = await _job_events.get() + try: + event = msg["event"] + app_log.debug(f"Received job event {event}") + if event == "BYE": + app_log.debug("Terminating job event listener") + break + + # job has reached a terminal state + if event == "READY": + task_group_metadata = msg["task_group_metadata"] + detail = msg["detail"] + fut = asyncio.create_task(_get_task_result(task_group_metadata, detail)) + _futures.add(fut) + fut.add_done_callback(_futures.discard) + continue + + if event == "FAILED": + task_group_metadata = msg["task_group_metadata"] + dispatch_id = task_group_metadata["dispatch_id"] + gid = task_group_metadata["task_group_id"] + task_ids = task_group_metadata["node_ids"] + detail = msg["detail"] + ts = datetime.now(timezone.utc) + for task_id in task_ids: + node_result = datamgr.generate_node_result( + node_id=task_id, + end_time=ts, + status=RESULT_STATUS.FAILED, + error=detail, + ) + await datamgr.update_node_result(dispatch_id, node_result) + + except Exception as ex: + app_log.exception("Error reading message: {ex}") + + +async def _mark_ready(task_group_metadata: dict, detail: Any): + await _job_events.put( + {"task_group_metadata": task_group_metadata, "event": "READY", "detail": detail} + ) + + +async def _mark_failed(task_group_metadata: dict, detail: str): + await _job_events.put( + {"task_group_metadata": task_group_metadata, "event": "FAILED", "detail": detail} + ) + + +async def _poll_task_status( + task_group_metadata: Dict, executor: AsyncBaseExecutor, poll_data: Any +): + """Polls a group of tasks until it terminates. + + `poll_data` is the return value of `Executor.send()` and will be + passed directly to `Executor.poll()`. + + """ + # Return immediately if no polling logic (default return value is -1) + + dispatch_id = task_group_metadata["dispatch_id"] + task_group_id = task_group_metadata["task_group_id"] + task_ids = task_group_metadata["node_ids"] + + try: + app_log.debug(f"Polling status for task group {dispatch_id}:{task_group_id}") + receive_data = await executor.poll(task_group_metadata, poll_data) + await _mark_ready(task_group_metadata, receive_data) + + except NotImplementedError: + app_log.debug(f"Executor {executor.short_name()} is async.") + + except TaskCancelledError: + app_log.debug(f"Task group {dispatch_id}:{task_group_id} cancelled") + await _mark_ready(task_group_metadata, None) + + except Exception as ex: + task_group_id = task_group_metadata["task_group_id"] + tb = "".join(traceback.TracebackException.from_exception(ex).format()) + app_log.debug(f"Exception occurred when polling task {task_group_id}:") + app_log.debug(tb) + error_msg = tb if debug_mode else str(ex) + await _mark_failed(task_group_metadata, error_msg) + + +async def mark_task_ready(task_metadata: dict, detail: Any): + dispatch_id = task_metadata["dispatch_id"] + node_id = task_metadata["node_id"] + gid = (await datamgr.electron.get(dispatch_id, node_id, ["task_group_id"]))["task_group_id"] + task_group_metadata = { + "dispatch_id": dispatch_id, + "node_ids": [node_id], + "task_group_id": gid, + } + + await _mark_ready(task_group_metadata, detail) diff --git a/covalent_dispatcher/_service/runnersvc.py b/covalent_dispatcher/_service/runnersvc.py new file mode 100644 index 0000000000..fec773f0df --- /dev/null +++ b/covalent_dispatcher/_service/runnersvc.py @@ -0,0 +1,57 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +"""Endpoints to update status of running tasks.""" + + +from fastapi import APIRouter, Request + +from covalent._shared_files import logger + +app_log = logger.app_log +log_stack_info = logger.log_stack_info + +router: APIRouter = APIRouter() + + +@router.put("/dispatch/{dispatch_id}/electron/{node_id}/job") +async def update_task_status(dispatch_id: str, node_id: int, request: Request): + """Updates the status of a running task. + + The request JSON will be passed to the task executor plugin's + `receive()` method together with `dispatch_id` and `node_id`. + """ + + from .._core import runner_ng + + task_metadata = { + "dispatch_id": dispatch_id, + "node_id": node_id, + } + try: + detail = await request.json() + f"Task {task_metadata} marked ready with detail {detail}" + # detail = {"status": Status(status.value.upper())} + await runner_ng.mark_task_ready(task_metadata, detail) + # app_log.debug(f"Marked task {dispatch_id}:{node_id} with status {status}") + return f"Task {task_metadata} marked ready with detail {detail}" + except Exception as e: + app_log.debug(f"Exception in update_task_status: {e}") + raise diff --git a/tests/covalent_dispatcher_tests/_core/runner_ng_test.py b/tests/covalent_dispatcher_tests/_core/runner_ng_test.py new file mode 100644 index 0000000000..785f2355b6 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/runner_ng_test.py @@ -0,0 +1,712 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +""" +Tests for the core functionality of the runner. +""" + + +import asyncio +from unittest.mock import AsyncMock + +import pytest +from sqlalchemy.pool import StaticPool + +import covalent as ct +from covalent._results_manager import Result +from covalent._shared_files.exceptions import TaskCancelledError +from covalent._shared_files.util_classes import RESULT_STATUS +from covalent._workflow.lattice import Lattice +from covalent.executor.base import AsyncBaseExecutor +from covalent.executor.schemas import ResourceMap, TaskSpec, TaskUpdate +from covalent_dispatcher._core.runner_ng import ( + _get_task_result, + _listen_for_job_events, + _mark_failed, + _mark_ready, + _poll_task_status, + _submit_abstract_task_group, + run_abstract_task_group, +) +from covalent_dispatcher._dal.result import Result as SRVResult +from covalent_dispatcher._dal.result import get_result_object +from covalent_dispatcher._db import update +from covalent_dispatcher._db.datastore import DataStore + +TEST_RESULTS_DIR = "/tmp/results" + + +class MockExecutor(AsyncBaseExecutor): + async def run(self, function, args, kwargs, task_metadata): + pass + + +class MockManagedExecutor(AsyncBaseExecutor): + SUPPORTS_MANAGED_EXECUTION = True + + async def run(self, function, args, kwargs, task_metadata): + pass + + def get_upload_uri(self, task_metadata, object_key): + return f"file:///tmp/{object_key}" + + +@pytest.fixture +def test_db(): + """Instantiate and return an in-memory database.""" + + return DataStore( + db_URL="sqlite+pysqlite:///:memory:", + initialize_db=True, + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + + +def get_mock_result() -> Result: + """Construct a mock result object corresponding to a lattice.""" + + import sys + + @ct.electron(executor="local") + def task(x): + print(f"stdout: {x}") + print("Error!", file=sys.stderr) + return x + + @ct.lattice(deps_bash=ct.DepsBash(["ls"])) + def pipeline(x): + res1 = task(x) + res2 = task(res1) + return res2 + + pipeline.build_graph(x="absolute") + received_workflow = Lattice.deserialize_from_json(pipeline.serialize_to_json()) + result_object = Result(received_workflow, "pipeline_workflow") + + return result_object + + +def get_mock_srvresult(sdkres, test_db) -> SRVResult: + sdkres._initialize_nodes() + + update.persist(sdkres) + + return get_result_object(sdkres.dispatch_id) + + +@pytest.mark.asyncio +async def test_submit_abstract_task_group(mocker): + import datetime + + me = MockManagedExecutor() + me.send = AsyncMock(return_value="42") + + mocker.patch( + "covalent_dispatcher._core.runner_ng.datamgr.electron.get", + return_value={"executor": "managed_dask", "executor_data": {}}, + ) + + mocker.patch( + "covalent_dispatcher._core.runner_ng.get_executor", + return_value=me, + ) + + ts = datetime.datetime.now() + + node_result = { + "node_id": 0, + "start_time": ts, + "status": "RUNNING", + } + + mocker.patch( + "covalent_dispatcher._core.runner_ng.datamgr.generate_node_result", + return_value=node_result, + ) + + mock_upload = mocker.patch( + "covalent_dispatcher._core.data_modules.asset_manager.upload_asset_for_nodes", + ) + + dispatch_id = "dispatch" + name = "task" + abstract_inputs = {"args": [1], "kwargs": {"key": 2}} + task_group_metadata = { + "dispatch_id": dispatch_id, + "node_ids": [0, 3], + "task_group_id": 0, + } + + mock_function_uri_0 = me.get_upload_uri(task_group_metadata, "function-0") + mock_deps_uri_0 = me.get_upload_uri(task_group_metadata, "deps-0") + mock_cb_uri_0 = me.get_upload_uri(task_group_metadata, "call_before-0") + mock_ca_uri_0 = me.get_upload_uri(task_group_metadata, "call_after-0") + + mock_function_uri_3 = me.get_upload_uri(task_group_metadata, "function-3") + mock_deps_uri_3 = me.get_upload_uri(task_group_metadata, "deps-3") + mock_cb_uri_3 = me.get_upload_uri(task_group_metadata, "call_before-3") + mock_ca_uri_3 = me.get_upload_uri(task_group_metadata, "call_after-3") + + mock_node_upload_uri_1 = me.get_upload_uri(task_group_metadata, "node_1") + mock_node_upload_uri_2 = me.get_upload_uri(task_group_metadata, "node_2") + + mock_function_id_0 = 0 + mock_args_ids = abstract_inputs["args"] + mock_kwargs_ids = abstract_inputs["kwargs"] + mock_deps_id_0 = "deps-0" + mock_cb_id_0 = "call_before-0" + mock_ca_id_0 = "call_after-0" + + mock_function_id_3 = 3 + mock_deps_id_3 = "deps-3" + mock_cb_id_3 = "call_before-3" + mock_ca_id_3 = "call_after-3" + + resources = { + "functions": { + 0: mock_function_uri_0, + 3: mock_function_uri_3, + }, + "inputs": { + 1: mock_node_upload_uri_1, + 2: mock_node_upload_uri_2, + }, + "deps": { + mock_deps_id_0: mock_deps_uri_0, + mock_cb_id_0: mock_cb_uri_0, + mock_ca_id_0: mock_ca_uri_0, + mock_deps_id_3: mock_deps_uri_3, + mock_cb_id_3: mock_cb_uri_3, + mock_ca_id_3: mock_ca_uri_3, + }, + } + + mock_task_spec_0 = { + "function_id": mock_function_id_0, + "args_ids": mock_args_ids, + "kwargs_ids": mock_kwargs_ids, + "deps_id": mock_deps_id_0, + "call_before_id": mock_cb_id_0, + "call_after_id": mock_ca_id_0, + } + + mock_task_spec_3 = { + "function_id": mock_function_id_3, + "args_ids": mock_args_ids, + "kwargs_ids": mock_kwargs_ids, + "deps_id": mock_deps_id_3, + "call_before_id": mock_cb_id_3, + "call_after_id": mock_ca_id_3, + } + + mock_task_0 = { + "function_id": mock_function_id_0, + "args_ids": mock_args_ids, + "kwargs_ids": mock_kwargs_ids, + } + + mock_task_3 = { + "function_id": mock_function_id_3, + "args_ids": mock_args_ids, + "kwargs_ids": mock_kwargs_ids, + } + + known_nodes = [1, 2] + + node_result, send_retval = await _submit_abstract_task_group( + dispatch_id=dispatch_id, + task_group_id=0, + task_seq=[mock_task_0, mock_task_3], + known_nodes=known_nodes, + executor=me, + ) + + mock_upload.assert_awaited() + + me.send.assert_awaited_with( + [TaskSpec(**mock_task_spec_0), TaskSpec(**mock_task_spec_3)], + ResourceMap(**resources), + task_group_metadata, + ) + assert send_retval == "42" + + +@pytest.mark.asyncio +async def test_submit_requires_opt_in(mocker): + """Checks submit rejects old-style executors""" + + import datetime + + me = MockExecutor() + me.send = AsyncMock(return_value="42") + ts = datetime.datetime.now() + dispatch_id = "dispatch" + task_id = 0 + name = "task" + abstract_inputs = {"args": [1], "kwargs": {"key": 2}} + + mocker.patch( + "covalent_dispatcher._core.runner_ng.datamgr.electron.get", + return_value={"executor": "managed_dask", "executor_data": {}}, + ) + + mocker.patch( + "covalent_dispatcher._core.runner_ng.get_executor", + return_value=me, + ) + + error_msg = str(NotImplementedError("Executor does not support managed execution")) + + node_result = { + "node_id": 0, + "end_time": ts, + "status": RESULT_STATUS.FAILED, + "error": error_msg, + } + + mocker.patch( + "covalent_dispatcher._core.runner_ng.datamgr.generate_node_result", + return_value=node_result, + ) + mock_function_id = task_id + mock_args_ids = abstract_inputs["args"] + mock_kwargs_ids = abstract_inputs["kwargs"] + + mock_task = { + "function_id": mock_function_id, + "args_ids": mock_args_ids, + "kwargs_ids": mock_kwargs_ids, + } + known_nodes = [1, 2] + + assert ([node_result], None) == await _submit_abstract_task_group( + dispatch_id, task_id, [mock_task], known_nodes, me + ) + + +@pytest.mark.asyncio +async def test_get_task_result(mocker): + import datetime + + me = MockManagedExecutor() + asset_uri = "file:///tmp/asset.pkl" + mock_task_result = { + "dispatch_id": "dispatch", + "node_id": 0, + "assets": { + "output": { + "remote_uri": asset_uri, + }, + "stdout": { + "remote_uri": asset_uri, + }, + "stderr": { + "remote_uri": asset_uri, + }, + }, + "status": RESULT_STATUS.COMPLETED, + } + me.receive = AsyncMock(return_value=[TaskUpdate(**mock_task_result)]) + + mocker.patch( + "covalent_dispatcher._core.runner_ng.datamgr.electron.get", + return_value={"executor": "managed_dask", "executor_data": {}}, + ) + + mocker.patch( + "covalent_dispatcher._core.runner_ng.get_executor", + return_value=me, + ) + + ts = datetime.datetime.now() + + node_result = { + "node_id": 0, + "start_time": ts, + "end_time": ts, + "status": RESULT_STATUS.COMPLETED, + } + + expected_node_result = { + "node_id": 0, + "start_time": ts, + "end_time": ts, + "status": RESULT_STATUS.COMPLETED, + } + + mocker.patch( + "covalent_dispatcher._core.runner_ng.datamgr.generate_node_result", + return_value=node_result, + ) + + mock_update = mocker.patch( + "covalent_dispatcher._core.runner_ng.datamgr.update_node_result", + ) + mock_download = mocker.patch( + "covalent_dispatcher._core.data_modules.asset_manager.download_assets_for_node", + ) + + dispatch_id = "dispatch" + task_id = 0 + name = "task" + + task_group_metadata = { + "dispatch_id": dispatch_id, + "node_ids": [task_id], + "task_group_id": task_id, + } + job_meta = [{"job_handle": "42", "status": "COMPLETED"}] + + await _get_task_result(task_group_metadata, job_meta) + + me.receive.assert_awaited_with(task_group_metadata, job_meta) + + mock_update.assert_awaited_with(dispatch_id, expected_node_result) + mock_download.assert_awaited() + # Test exception during get + me.receive = AsyncMock(side_effect=RuntimeError()) + mock_update.reset_mock() + + await _get_task_result(task_group_metadata, job_meta) + mock_update.assert_awaited() + + +@pytest.mark.asyncio +async def test_poll_status(mocker): + me = MockManagedExecutor() + me.poll = AsyncMock(return_value=0) + mocker.patch( + "covalent_dispatcher._core.runner_ng.get_executor", + return_value=me, + ) + mock_mark_ready = mocker.patch( + "covalent_dispatcher._core.runner_ng._mark_ready", + ) + + dispatch_id = "dispatch" + task_id = 1 + task_group_metadata = { + "dispatch_id": dispatch_id, + "node_ids": [task_id], + "task_group_id": task_id, + } + + await _poll_task_status(task_group_metadata, me, "42") + + mock_mark_ready.assert_awaited_with(task_group_metadata, 0) + + me.poll = AsyncMock(side_effect=NotImplementedError()) + mock_mark_ready.reset_mock() + + await _poll_task_status(task_group_metadata, me, "42") + mock_mark_ready.assert_not_awaited() + + me.poll = AsyncMock(side_effect=RuntimeError()) + mock_mark_ready.reset_mock() + mock_mark_failed = mocker.patch( + "covalent_dispatcher._core.runner_ng._mark_failed", + ) + + await _poll_task_status(task_group_metadata, me, "42") + mock_mark_ready.assert_not_awaited() + mock_mark_failed.assert_awaited() + mock_mark_ready.reset_mock() + mock_mark_failed.reset_mock() + + me.poll = AsyncMock(side_effect=TaskCancelledError()) + + await _poll_task_status(task_group_metadata, me, "42") + mock_mark_ready.assert_awaited_with(task_group_metadata, None) + mock_mark_failed.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_event_listener(mocker): + import datetime + + ts = datetime.datetime.now() + node_result = { + "node_id": 0, + "start_time": ts, + "end_time": ts, + "status": RESULT_STATUS.FAILED, + "error": "error", + } + + mocker.patch( + "covalent_dispatcher._core.runner_ng.datamgr.generate_node_result", + return_value=node_result, + ) + + mock_update = mocker.patch( + "covalent_dispatcher._core.runner_ng.datamgr.update_node_result", + ) + + mock_get = mocker.patch("covalent_dispatcher._core.runner_ng._get_task_result") + + task_group_metadata = {"dispatch_id": "dispatch", "task_group_id": 1, "node_ids": [1]} + + job_events = [{"event": "READY", "task_group_metadata": task_group_metadata}, {"event": "BYE"}] + + mock_event_queue = asyncio.Queue() + + mocker.patch( + "covalent_dispatcher._core.runner_ng._job_events", + mock_event_queue, + ) + fut = asyncio.create_task(_listen_for_job_events()) + await _mark_ready(task_group_metadata, "RUNNING") + await _mark_ready(task_group_metadata, "COMPLETED") + await mock_event_queue.put({"event": "BYE"}) + + await asyncio.wait_for(fut, 1) + + assert mock_get.call_count == 2 + + mock_get.reset_mock() + + fut = asyncio.create_task(_listen_for_job_events()) + + await _mark_failed(task_group_metadata, "error") + await mock_event_queue.put({"event": "BYE"}) + + await asyncio.wait_for(fut, 1) + + mock_update.assert_awaited_with(task_group_metadata["dispatch_id"], node_result) + + await mock_event_queue.put({"BAD_EVENT": "asdf"}) + await mock_event_queue.put({"event": "BYE"}) + mock_log = mocker.patch("covalent_dispatcher._core.runner_ng.app_log.exception") + + fut = asyncio.create_task(_listen_for_job_events()) + + await _mark_failed(task_group_metadata, "error") + await mock_event_queue.put({"event": "BYE"}) + + await asyncio.wait_for(fut, 1) + + +@pytest.mark.asyncio +async def test_run_abstract_task_group(mocker): + mock_listen = AsyncMock() + me = MockManagedExecutor() + me._init_runtime() + + me.poll = AsyncMock(return_value=0) + mocker.patch( + "covalent_dispatcher._core.runner_ng.get_executor", + return_value=me, + ) + + mocker.patch( + "covalent_dispatcher._core.runner_modules.jobs.get_cancel_requested", return_value=False + ) + + mock_poll = mocker.patch( + "covalent_dispatcher._core.runner_ng._poll_task_status", + ) + + node_result = {"node_id": 0, "status": RESULT_STATUS.RUNNING} + + mock_submit = mocker.patch( + "covalent_dispatcher._core.runner_ng._submit_abstract_task_group", + return_value=([node_result], 42), + ) + + mock_update = mocker.patch( + "covalent_dispatcher._core.runner_ng.datamgr.update_node_result", + ) + + dispatch_id = "dispatch" + node_id = 0 + node_name = "task" + abstract_inputs = {"args": [], "kwargs": {}} + selected_executor = ["local", {}] + mock_function_id = node_id + mock_args_ids = abstract_inputs["args"] + mock_kwargs_ids = abstract_inputs["kwargs"] + + mock_task = { + "function_id": mock_function_id, + "args_ids": mock_args_ids, + "kwargs_ids": mock_kwargs_ids, + } + known_nodes = [1, 2] + task_group_metadata = { + "dispatch_id": dispatch_id, + "node_ids": [node_id], + "task_group_id": node_id, + } + + await run_abstract_task_group( + dispatch_id, + node_id, + [mock_task], + known_nodes, + selected_executor, + ) + + mock_submit.assert_awaited() + mock_update.assert_awaited() + mock_poll.assert_awaited_with(task_group_metadata, me, 42) + + +@pytest.mark.asyncio +async def test_run_abstract_task_group_handles_old_execs(mocker): + mock_listen = AsyncMock() + me = MockExecutor() + me._init_runtime() + + mocker.patch( + "covalent_dispatcher._core.runner_ng.get_executor", + return_value=me, + ) + mocker.patch( + "covalent_dispatcher._core.runner_modules.jobs.get_cancel_requested", return_value=False + ) + + mock_legacy_run = mocker.patch("covalent_dispatcher._core.runner.run_abstract_task") + + mock_submit = mocker.patch("covalent_dispatcher._core.runner_ng._submit_abstract_task_group") + + dispatch_id = "dispatch" + node_id = 0 + node_name = "task" + abstract_inputs = {"args": [], "kwargs": {}} + selected_executor = ["local", {}] + mock_function_id = node_id + mock_args_ids = abstract_inputs["args"] + mock_kwargs_ids = abstract_inputs["kwargs"] + + mock_task = { + "function_id": mock_function_id, + "name": node_name, + "args_ids": mock_args_ids, + "kwargs_ids": mock_kwargs_ids, + } + known_nodes = [1, 2] + + await run_abstract_task_group( + dispatch_id, + node_id, + [mock_task], + known_nodes, + selected_executor, + ) + + mock_legacy_run.assert_called() + mock_submit.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_run_abstract_task_group_handles_bad_executors(mocker): + """Check handling of executors during get_executor""" + + from covalent._shared_files.defaults import sublattice_prefix + + mocker.patch("covalent_dispatcher._core.runner_ng.get_executor", side_effect=RuntimeError()) + + mock_update = mocker.patch( + "covalent_dispatcher._core.runner_ng.datamgr.update_node_result", + ) + dispatch_id = "dispatch" + node_id = 0 + node_name = sublattice_prefix + abstract_inputs = {"args": [], "kwargs": {}} + selected_executor = ["local", {}] + mock_function_id = node_id + mock_args_ids = abstract_inputs["args"] + mock_kwargs_ids = abstract_inputs["kwargs"] + + mock_task = { + "function_id": mock_function_id, + "args_ids": mock_args_ids, + "kwargs_ids": mock_kwargs_ids, + } + known_nodes = [1, 2] + + await run_abstract_task_group( + dispatch_id, + node_id, + [mock_task], + known_nodes, + selected_executor, + ) + + mock_update.assert_awaited() + + +@pytest.mark.asyncio +async def test_run_abstract_task_group_handles_cancelled_tasks(mocker): + """Check handling of cancelled tasks""" + + mock_listen = AsyncMock() + me = MockManagedExecutor() + me._init_runtime() + + me.poll = AsyncMock(return_value=0) + + mocker.patch( + "covalent_dispatcher._core.runner_modules.jobs.get_cancel_requested", return_value=True + ) + + mock_jobs_put = mocker.patch( + "covalent_dispatcher._core.runner_modules.jobs.put_job_status", return_value=True + ) + + mock_submit = mocker.patch( + "covalent_dispatcher._core.runner_ng._submit_abstract_task_group", + ) + + mock_update = mocker.patch( + "covalent_dispatcher._core.runner_ng.datamgr.update_node_result", + ) + mock_mark_ready = mocker.patch( + "covalent_dispatcher._core.runner_ng.mark_task_ready", + ) + + dispatch_id = "dispatch" + node_id = 0 + node_name = "task" + abstract_inputs = {"args": [], "kwargs": {}} + selected_executor = ["local", {}] + mock_function_id = node_id + mock_args_ids = abstract_inputs["args"] + mock_kwargs_ids = abstract_inputs["kwargs"] + + mock_task = { + "function_id": mock_function_id, + "args_ids": mock_args_ids, + "kwargs_ids": mock_kwargs_ids, + } + known_nodes = [1, 2] + + await run_abstract_task_group( + dispatch_id, + node_id, + [mock_task], + known_nodes, + selected_executor, + ) + + mock_submit.assert_not_awaited() + mock_update.assert_not_awaited() + mock_mark_ready.assert_awaited() diff --git a/tests/covalent_dispatcher_tests/_service/runnersvc_test.py b/tests/covalent_dispatcher_tests/_service/runnersvc_test.py new file mode 100644 index 0000000000..558f9f205a --- /dev/null +++ b/tests/covalent_dispatcher_tests/_service/runnersvc_test.py @@ -0,0 +1,73 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. + +"""Unit tests for the FastAPI runner endpoints""" + + +import pytest +from fastapi.testclient import TestClient + +from covalent_ui.app import fastapi_app as fast_app + +DISPATCH_ID = "f34671d1-48f2-41ce-89d9-9a8cb5c60e5d" + + +@pytest.fixture +def app(): + yield fast_app + + +@pytest.fixture +def client(): + with TestClient(fast_app) as c: + yield c + + +def test_update_node_asset(mocker, client): + """ + Test update task status + """ + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + node_id = 0 + dispatch_id = "test_update_task_status" + body = {"status": "COMPLETED"} + + mock_mark_task_ready = mocker.patch("covalent_dispatcher._core.runner_ng.mark_task_ready") + resp = client.put(f"/api/v1/dispatch/{dispatch_id}/electron/{node_id}/job", json=body) + + task_metadata = {"dispatch_id": dispatch_id, "node_id": node_id} + mock_mark_task_ready.assert_called_with(task_metadata, body) + + +def test_update_node_asset_exception(mocker, client): + """ + Test update task status + """ + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + node_id = 0 + dispatch_id = "test_update_task_status" + body = {"status": "COMPLETED"} + + mock_mark_task_ready = mocker.patch( + "covalent_dispatcher._core.runner_ng.mark_task_ready", side_effect=KeyError() + ) + with pytest.raises(KeyError): + client.put(f"/api/v1/dispatch/{dispatch_id}/electron/{node_id}/job", json=body) diff --git a/tests/covalent_tests/executor/__init__.py b/tests/covalent_tests/executor/__init__.py index e69de29bb2..523f776226 100644 --- a/tests/covalent_tests/executor/__init__.py +++ b/tests/covalent_tests/executor/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. diff --git a/tests/covalent_tests/executor/base_test.py b/tests/covalent_tests/executor/base_test.py index 0aebc0a498..ec4f8655eb 100644 --- a/tests/covalent_tests/executor/base_test.py +++ b/tests/covalent_tests/executor/base_test.py @@ -32,7 +32,7 @@ from covalent._shared_files.exceptions import TaskCancelledError, TaskRuntimeError from covalent.executor import BaseExecutor, wrapper_fn from covalent.executor.base import AsyncBaseExecutor -from covalent.executor.utils.wrappers import Signals +from covalent.executor.utils.enums import Signals class MockExecutor(BaseExecutor): @@ -612,6 +612,21 @@ def test_base_executor_get_cancel_requested(mocker): mock_notify.assert_called_once() +def test_base_executor_get_version_info(mocker): + """ + Test executor invoking get cancel requested + """ + me = MockExecutor() + me._init_runtime() + recv_queue = me._recv_queue + mock_version_info = {"python": "3.8", "covalent": "1.0"} + recv_queue.put_nowait((True, mock_version_info)) + mock_notify = mocker.patch("covalent.executor.base.BaseExecutor._notify") + + assert me.get_version_info() == mock_version_info + mock_notify.assert_called_once() + + @pytest.mark.asyncio async def test_async_base_executor_get_cancel_requested(mocker): me = MockAsyncExecutor() @@ -622,6 +637,17 @@ async def test_async_base_executor_get_cancel_requested(mocker): assert await me.get_cancel_requested() is True +@pytest.mark.asyncio +async def test_async_base_executor_get_version_info(mocker): + me = MockAsyncExecutor() + me._init_runtime() + send_queue = me._send_queue + recv_queue = me._recv_queue + mock_version_info = {"python": "3.8", "covalent": "1.0"} + recv_queue.put_nowait((True, mock_version_info)) + assert await me.get_version_info() == mock_version_info + + def test_base_executor_set_job_handle(mocker): me = MockExecutor() me._init_runtime() @@ -803,5 +829,5 @@ async def test_base_async_executor_private_cancel(mocker): cancel_result = await me._cancel(task_metadata=task_metadata, job_handle=job_handle) mock_app_log.assert_called_with(f"Cancel not implemented for executor {type(me)}") - me.teardown.assert_awaited_with(task_metadata) + # me.teardown.assert_awaited_with(task_metadata) assert cancel_result is False diff --git a/tests/covalent_tests/executor/executor_plugins/__init__.py b/tests/covalent_tests/executor/executor_plugins/__init__.py index e69de29bb2..523f776226 100644 --- a/tests/covalent_tests/executor/executor_plugins/__init__.py +++ b/tests/covalent_tests/executor/executor_plugins/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the GNU Affero General Public License 3.0 (the "License"). +# A copy of the License may be obtained with this software package or at +# +# https://www.gnu.org/licenses/agpl-3.0.en.html +# +# Use of this file is prohibited except in compliance with the License. Any +# modifications or derivative works of this file must retain this copyright +# notice, and modified files must contain a notice indicating that they have +# been altered from the originals. +# +# Covalent is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details. +# +# Relief from the License may be granted by purchasing a commercial license. diff --git a/tests/covalent_tests/executor/executor_plugins/dask_test.py b/tests/covalent_tests/executor/executor_plugins/dask_test.py index 979249bfb5..1545a9f59a 100644 --- a/tests/covalent_tests/executor/executor_plugins/dask_test.py +++ b/tests/covalent_tests/executor/executor_plugins/dask_test.py @@ -21,6 +21,7 @@ """Tests for Covalent dask executor.""" import asyncio +import json import os import tempfile from unittest.mock import AsyncMock @@ -30,7 +31,15 @@ import covalent as ct from covalent._shared_files import TaskRuntimeError from covalent._shared_files.exceptions import TaskCancelledError -from covalent.executor.executor_plugins.dask import _EXECUTOR_PLUGIN_DEFAULTS, DaskExecutor +from covalent._workflow.transportable_object import TransportableObject +from covalent.executor.executor_plugins.dask import ( + _EXECUTOR_PLUGIN_DEFAULTS, + DaskExecutor, + ResourceMap, + TaskSpec, + run_task_from_uris_alt, +) +from covalent.executor.utils.serialize import serialize_node_asset def test_dask_executor_init(mocker): @@ -288,3 +297,320 @@ def test_dask_task_cancel(mocker): result = asyncio.run(dask_exec.cancel(task_metadata, job_handle)) mock_app_log.assert_called_with(f"Cancelled future with key {job_handle}") assert result is True + + +def test_dask_send_poll_receive(mocker): + """Test running a task using send + poll + receive.""" + from dask.distributed import LocalCluster + + from covalent.executor import DaskExecutor + + cluster = LocalCluster() + dask_exec = DaskExecutor(cluster.scheduler_address) + mock_get_cancel_requested = mocker.patch.object( + dask_exec, "get_cancel_requested", AsyncMock(return_value=False) + ) + mock_set_job_handle = mocker.patch.object(dask_exec, "set_job_handle", AsyncMock()) + + def task(x, y): + return x + y + + dispatch_id = "test_dask_send_receive" + node_id = 0 + task_group_id = 0 + + x = TransportableObject(1) + y = TransportableObject(2) + deps = {} + call_before = [] + call_after = [] + + ser_task = serialize_node_asset(TransportableObject(task), "function") + ser_deps = serialize_node_asset(deps, "deps") + ser_cb = serialize_node_asset(deps, "call_before") + ser_ca = serialize_node_asset(deps, "call_after") + ser_x = serialize_node_asset(x, "output") + ser_y = serialize_node_asset(y, "output") + + node_0_file = tempfile.NamedTemporaryFile("wb") + node_0_file.write(ser_task) + node_0_file.flush() + + deps_file = tempfile.NamedTemporaryFile("wb") + deps_file.write(ser_deps) + deps_file.flush() + + cb_file = tempfile.NamedTemporaryFile("wb") + cb_file.write(ser_cb) + cb_file.flush() + + ca_file = tempfile.NamedTemporaryFile("wb") + ca_file.write(ser_ca) + ca_file.flush() + + node_1_file = tempfile.NamedTemporaryFile("wb") + node_1_file.write(ser_x) + node_1_file.flush() + + node_2_file = tempfile.NamedTemporaryFile("wb") + node_2_file.write(ser_y) + node_2_file.flush() + + task_spec = TaskSpec( + function_id=0, + args_ids=[1, 2], + kwargs_ids={}, + deps_id="deps", + call_before_id="call_before", + call_after_id="call_after", + ) + + resources = ResourceMap( + functions={ + 0: node_0_file.name, + }, + inputs={ + 1: node_1_file.name, + 2: node_2_file.name, + }, + deps={ + "deps": deps_file.name, + "call_before": cb_file.name, + "call_after": ca_file.name, + }, + ) + + task_group_metadata = { + "dispatch_id": dispatch_id, + "node_ids": [node_id], + "task_group_id": task_group_id, + } + + async def run_dask_job(task_specs, resources, task_group_metadata): + job_id = await dask_exec.send(task_specs, resources, task_group_metadata) + job_status = await dask_exec.poll(task_group_metadata, job_id) + task_updates = await dask_exec.receive(task_group_metadata, job_status) + return job_status, task_updates + + job_status, task_updates = asyncio.run( + run_dask_job([task_spec], resources, task_group_metadata) + ) + + assert job_status["status"] == "READY" + assert len(task_updates) == 1 + task_update = task_updates[0] + assert str(task_update.status) == (ct.status.COMPLETED) + output_uri = task_update.assets["output"].remote_uri + + with open(output_uri, "rb") as f: + output = TransportableObject.deserialize(f.read()) + assert output.get_deserialized() == 3 + + +def test_run_task_from_uris_alt(): + """Test the wrapper submitted to dask""" + + def task(x, y): + return x + y + + dispatch_id = "test_dask_send_receive" + node_id = 0 + task_group_id = 0 + + x = TransportableObject(1) + y = TransportableObject(2) + deps = {} + + cb_tmpfile = tempfile.NamedTemporaryFile() + ca_tmpfile = tempfile.NamedTemporaryFile() + + call_before = [ct.DepsBash([f"echo Hello > {cb_tmpfile.name}"]).to_dict()] + call_after = [ct.DepsBash(f"echo Bye > {ca_tmpfile.name}").to_dict()] + + ser_task = serialize_node_asset(TransportableObject(task), "function") + ser_deps = serialize_node_asset(deps, "deps") + ser_cb = serialize_node_asset(call_before, "call_before") + ser_ca = serialize_node_asset(call_after, "call_after") + ser_x = serialize_node_asset(x, "output") + ser_y = serialize_node_asset(y, "output") + + node_0_file = tempfile.NamedTemporaryFile("wb") + node_0_file.write(ser_task) + node_0_file.flush() + + deps_file = tempfile.NamedTemporaryFile("wb") + deps_file.write(ser_deps) + deps_file.flush() + + cb_file = tempfile.NamedTemporaryFile("wb") + cb_file.write(ser_cb) + cb_file.flush() + + ca_file = tempfile.NamedTemporaryFile("wb") + ca_file.write(ser_ca) + ca_file.flush() + + node_1_file = tempfile.NamedTemporaryFile("wb") + node_1_file.write(ser_x) + node_1_file.flush() + + node_2_file = tempfile.NamedTemporaryFile("wb") + node_2_file.write(ser_y) + node_2_file.flush() + + task_spec = TaskSpec( + function_id=0, + args_ids=[1, 2], + kwargs_ids={}, + deps_id="deps", + call_before_id="call_before", + call_after_id="call_after", + ) + + resources = ResourceMap( + functions={ + 0: node_0_file.name, + }, + inputs={ + 1: node_1_file.name, + 2: node_2_file.name, + }, + deps={ + "deps": deps_file.name, + "call_before": cb_file.name, + "call_after": ca_file.name, + }, + ) + + task_group_metadata = { + "dispatch_id": dispatch_id, + "node_ids": [node_id], + "task_group_id": task_group_id, + } + + result_file = tempfile.NamedTemporaryFile() + stdout_file = tempfile.NamedTemporaryFile() + stderr_file = tempfile.NamedTemporaryFile() + + results_dir = tempfile.TemporaryDirectory() + + run_task_from_uris_alt( + task_specs=[task_spec.dict()], + resources=resources.dict(), + output_uris=[(result_file.name, stdout_file.name, stderr_file.name)], + results_dir=results_dir.name, + task_group_metadata=task_group_metadata, + server_url="http://localhost:48008", + ) + + with open(result_file.name, "rb") as f: + output = TransportableObject.deserialize(f.read()) + assert output.get_deserialized() == 3 + + with open(cb_tmpfile.name, "r") as f: + assert f.read() == "Hello\n" + + with open(ca_tmpfile.name, "r") as f: + assert f.read() == "Bye\n" + + +def test_run_task_from_uris_alt_exception(): + """Test the wrapper submitted to dask""" + + def task(x, y): + assert False + + dispatch_id = "test_dask_send_receive" + node_id = 0 + task_group_id = 0 + + x = TransportableObject(1) + y = TransportableObject(2) + deps = {} + + cb_tmpfile = tempfile.NamedTemporaryFile() + ca_tmpfile = tempfile.NamedTemporaryFile() + + call_before = [ct.DepsBash([f"echo Hello > {cb_tmpfile.name}"]).to_dict()] + call_after = [ct.DepsBash(f"echo Bye > {ca_tmpfile.name}").to_dict()] + + ser_task = serialize_node_asset(TransportableObject(task), "function") + ser_deps = serialize_node_asset(deps, "deps") + ser_cb = serialize_node_asset(call_before, "call_before") + ser_ca = serialize_node_asset(call_after, "call_after") + ser_x = serialize_node_asset(x, "output") + ser_y = serialize_node_asset(y, "output") + + node_0_file = tempfile.NamedTemporaryFile("wb") + node_0_file.write(ser_task) + node_0_file.flush() + + deps_file = tempfile.NamedTemporaryFile("wb") + deps_file.write(ser_deps) + deps_file.flush() + + cb_file = tempfile.NamedTemporaryFile("wb") + cb_file.write(ser_cb) + cb_file.flush() + + ca_file = tempfile.NamedTemporaryFile("wb") + ca_file.write(ser_ca) + ca_file.flush() + + node_1_file = tempfile.NamedTemporaryFile("wb") + node_1_file.write(ser_x) + node_1_file.flush() + + node_2_file = tempfile.NamedTemporaryFile("wb") + node_2_file.write(ser_y) + node_2_file.flush() + + task_spec = TaskSpec( + function_id=0, + args_ids=[1], + kwargs_ids={"y": 2}, + deps_id="deps", + call_before_id="call_before", + call_after_id="call_after", + ) + + resources = ResourceMap( + functions={ + 0: f"file://{node_0_file.name}", + }, + inputs={ + 1: f"file://{node_1_file.name}", + 2: f"file://{node_2_file.name}", + }, + deps={ + "deps": f"file://{deps_file.name}", + "call_before": f"file://{cb_file.name}", + "call_after": f"file://{ca_file.name}", + }, + ) + + task_group_metadata = { + "dispatch_id": dispatch_id, + "node_ids": [node_id], + "task_group_id": task_group_id, + } + + result_file = tempfile.NamedTemporaryFile() + stdout_file = tempfile.NamedTemporaryFile() + stderr_file = tempfile.NamedTemporaryFile() + + results_dir = tempfile.TemporaryDirectory() + + run_task_from_uris_alt( + task_specs=[task_spec.dict()], + resources=resources.dict(), + output_uris=[(result_file.name, stdout_file.name, stderr_file.name)], + results_dir=results_dir.name, + task_group_metadata=task_group_metadata, + server_url="http://localhost:48008", + ) + summary_file_path = f"{results_dir.name}/result-{dispatch_id}:{node_id}.json" + + with open(summary_file_path, "r") as f: + summary = json.load(f) + assert summary["exception_occurred"] is True diff --git a/tests/covalent_tests/executor/executor_plugins/local_test.py b/tests/covalent_tests/executor/executor_plugins/local_test.py index 82ce2d76c2..9a9202ad68 100644 --- a/tests/covalent_tests/executor/executor_plugins/local_test.py +++ b/tests/covalent_tests/executor/executor_plugins/local_test.py @@ -21,6 +21,7 @@ """Tests for Covalent local executor.""" import io +import json import os import tempfile from functools import partial @@ -32,8 +33,14 @@ from covalent._shared_files import TaskRuntimeError from covalent._shared_files.exceptions import TaskCancelledError from covalent._workflow.transport import TransportableObject -from covalent.executor.base import wrapper_fn -from covalent.executor.executor_plugins.local import _EXECUTOR_PLUGIN_DEFAULTS, LocalExecutor +from covalent.executor.executor_plugins.local import ( + _EXECUTOR_PLUGIN_DEFAULTS, + LocalExecutor, + TaskSpec, + run_task_from_uris, +) +from covalent.executor.utils.serialize import serialize_node_asset +from covalent.executor.utils.wrappers import wrapper_fn def test_local_executor_init(mocker): @@ -231,3 +238,251 @@ def test_local_executor_get_cancel_requested(mocker): le.run(local_executor_run__mock_task, args, kwargs, task_metadata) le.get_cancel_requested.assert_called_once() assert mock_app_log.call_count == 2 + + +def test_run_task_from_uris(mocker): + """Test the wrapper submitted to local""" + + def task(x, y): + return x + y + + dispatch_id = "test_dask_send_receive" + node_id = 0 + task_group_id = 0 + server_url = "http://localhost:48008" + + x = TransportableObject(1) + y = TransportableObject(2) + deps = {} + + cb_tmpfile = tempfile.NamedTemporaryFile() + ca_tmpfile = tempfile.NamedTemporaryFile() + + deps = { + "bash": ct.DepsBash([f"echo Hello > {cb_tmpfile.name}"]).to_dict(), + "pip": ct.DepsBash(f"echo Bye > {ca_tmpfile.name}").to_dict(), + } + + call_before = [] + call_after = [] + + ser_task = serialize_node_asset(TransportableObject(task), "function") + ser_deps = serialize_node_asset(deps, "deps") + ser_cb = serialize_node_asset(call_before, "call_before") + ser_ca = serialize_node_asset(call_after, "call_after") + ser_x = serialize_node_asset(x, "output") + ser_y = serialize_node_asset(y, "output") + + node_0_file = tempfile.NamedTemporaryFile("wb") + node_0_file.write(ser_task) + node_0_file.flush() + node_0_function_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/0/assets/function" + + deps_file = tempfile.NamedTemporaryFile("wb") + deps_file.write(ser_deps) + deps_file.flush() + deps_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/0/assets/deps" + + cb_file = tempfile.NamedTemporaryFile("wb") + cb_file.write(ser_cb) + cb_file.flush() + cb_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/0/assets/call_before" + + ca_file = tempfile.NamedTemporaryFile("wb") + ca_file.write(ser_ca) + ca_file.flush() + ca_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/0/assets/call_after" + + node_1_file = tempfile.NamedTemporaryFile("wb") + node_1_file.write(ser_x) + node_1_file.flush() + node_1_output_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/1/assets/output" + + node_2_file = tempfile.NamedTemporaryFile("wb") + node_2_file.write(ser_y) + node_2_file.flush() + node_2_output_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/2/assets/output" + + task_spec = TaskSpec( + function_id=0, + args_ids=[1, 2], + kwargs_ids={}, + deps_id="deps", + call_before_id="call_before", + call_after_id="call_after", + ) + + resources = { + node_0_function_url: ser_task, + node_1_output_url: ser_x, + node_2_output_url: ser_y, + deps_url: ser_deps, + cb_url: ser_cb, + ca_url: ser_ca, + } + + def mock_req_get(url, stream): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.content = resources[url] + return mock_resp + + def mock_req_post(url, files): + resources[url] = files["asset_file"].read() + + mocker.patch("requests.get", mock_req_get) + mocker.patch("requests.post", mock_req_post) + mock_put = mocker.patch("requests.put") + task_group_metadata = { + "dispatch_id": dispatch_id, + "node_ids": [node_id], + "task_group_id": task_group_id, + } + + result_file = tempfile.NamedTemporaryFile() + stdout_file = tempfile.NamedTemporaryFile() + stderr_file = tempfile.NamedTemporaryFile() + + results_dir = tempfile.TemporaryDirectory() + + run_task_from_uris( + task_specs=[task_spec.dict()], + resources={}, + output_uris=[(result_file.name, stdout_file.name, stderr_file.name)], + results_dir=results_dir.name, + task_group_metadata=task_group_metadata, + server_url=server_url, + ) + + with open(result_file.name, "rb") as f: + output = TransportableObject.deserialize(f.read()) + assert output.get_deserialized() == 3 + + with open(cb_tmpfile.name, "r") as f: + assert f.read() == "Hello\n" + + with open(ca_tmpfile.name, "r") as f: + assert f.read() == "Bye\n" + + mock_put.assert_called() + + +def test_run_task_from_uris_exception(mocker): + """Test the wrapper submitted to local""" + + def task(x, y): + assert False + + dispatch_id = "test_dask_send_receive" + node_id = 0 + task_group_id = 0 + server_url = "http://localhost:48008" + + x = TransportableObject(1) + y = TransportableObject(2) + deps = {} + + cb_tmpfile = tempfile.NamedTemporaryFile() + ca_tmpfile = tempfile.NamedTemporaryFile() + + deps = { + "bash": ct.DepsBash([f"echo Hello > {cb_tmpfile.name}"]).to_dict(), + "pip": ct.DepsBash(f"echo Bye > {ca_tmpfile.name}").to_dict(), + } + + call_before = [] + call_after = [] + + ser_task = serialize_node_asset(TransportableObject(task), "function") + ser_deps = serialize_node_asset(deps, "deps") + ser_cb = serialize_node_asset(call_before, "call_before") + ser_ca = serialize_node_asset(call_after, "call_after") + ser_x = serialize_node_asset(x, "output") + ser_y = serialize_node_asset(y, "output") + + node_0_file = tempfile.NamedTemporaryFile("wb") + node_0_file.write(ser_task) + node_0_file.flush() + node_0_function_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/0/assets/function" + + deps_file = tempfile.NamedTemporaryFile("wb") + deps_file.write(ser_deps) + deps_file.flush() + deps_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/0/assets/deps" + + cb_file = tempfile.NamedTemporaryFile("wb") + cb_file.write(ser_cb) + cb_file.flush() + cb_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/0/assets/call_before" + + ca_file = tempfile.NamedTemporaryFile("wb") + ca_file.write(ser_ca) + ca_file.flush() + ca_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/0/assets/call_after" + + node_1_file = tempfile.NamedTemporaryFile("wb") + node_1_file.write(ser_x) + node_1_file.flush() + node_1_output_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/1/assets/output" + + node_2_file = tempfile.NamedTemporaryFile("wb") + node_2_file.write(ser_y) + node_2_file.flush() + node_2_output_url = f"{server_url}/api/v1/dispatch/{dispatch_id}/electron/2/assets/output" + + task_spec = TaskSpec( + function_id=0, + args_ids=[1], + kwargs_ids={"y": 2}, + deps_id="deps", + call_before_id="call_before", + call_after_id="call_after", + ) + + resources = { + node_0_function_url: ser_task, + node_1_output_url: ser_x, + node_2_output_url: ser_y, + deps_url: ser_deps, + cb_url: ser_cb, + ca_url: ser_ca, + } + + def mock_req_get(url, stream): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.content = resources[url] + return mock_resp + + def mock_req_post(url, files): + resources[url] = files["asset_file"].read() + + mocker.patch("requests.get", mock_req_get) + mocker.patch("requests.post", mock_req_post) + mocker.patch("requests.put") + task_group_metadata = { + "dispatch_id": dispatch_id, + "node_ids": [node_id], + "task_group_id": task_group_id, + } + + result_file = tempfile.NamedTemporaryFile() + stdout_file = tempfile.NamedTemporaryFile() + stderr_file = tempfile.NamedTemporaryFile() + + results_dir = tempfile.TemporaryDirectory() + + run_task_from_uris( + task_specs=[task_spec.dict()], + resources={}, + output_uris=[(result_file.name, stdout_file.name, stderr_file.name)], + results_dir=results_dir.name, + task_group_metadata=task_group_metadata, + server_url=server_url, + ) + + summary_file_path = f"{results_dir.name}/result-{dispatch_id}:{node_id}.json" + + with open(summary_file_path, "r") as f: + summary = json.load(f) + assert summary["exception_occurred"] is True