From 6248f0a520d222a11faad8add80dce682a66eaea Mon Sep 17 00:00:00 2001
From: Faiyaz Hasan
Date: Thu, 20 Oct 2022 19:23:18 -0400
Subject: [PATCH] Fix async bug (#43)

* bug fix

* Update to ensure that default values are read in from the config file
---
 CHANGELOG.md                 |  4 ++++
 covalent_ecs_plugin/ecs.py   | 26 +++++++++++++++-----------
 covalent_ecs_plugin/utils.py | 11 +++++++++++
 tests/test_ecs.py            |  2 +-
 tests/test_utils.py          | 16 +++++++++++++++-
 5 files changed, 46 insertions(+), 13 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a3be911..b56f9e2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [UNRELEASED]
 
+### Fixed
+
+- Ensure that async functions are not being passed off to the threadpool.
+
 ## [0.16.0] - 2022-10-14
 
 ### Changed
diff --git a/covalent_ecs_plugin/ecs.py b/covalent_ecs_plugin/ecs.py
index b09d7cc..518345e 100644
--- a/covalent_ecs_plugin/ecs.py
+++ b/covalent_ecs_plugin/ecs.py
@@ -34,7 +34,7 @@
 from covalent._shared_files.logger import app_log
 from covalent_aws_plugins import AWSExecutor
 
-from .utils import _execute_partial_in_threadpool
+from .utils import _execute_partial_in_threadpool, _load_pickle_file
 
 _EXECUTOR_PLUGIN_DEFAULTS = {
     "credentials": os.environ.get("AWS_SHARED_CREDENTIALS_FILE")
@@ -85,8 +85,8 @@ class ECSExecutor(AWSExecutor):
 
     def __init__(
         self,
-        s3_bucket_name: str,
-        ecs_task_security_group_id: str,
+        s3_bucket_name: str = None,
+        ecs_task_security_group_id: str = None,
         ecs_cluster_name: str = None,
         ecs_task_family_name: str = None,
         ecs_task_execution_role_name: str = None,
@@ -148,7 +148,7 @@ def __init__(
                 f"{self.ecs_task_security_group_id} is not a valid security group id. Please set a valid security group id either in the ECS executor definition or in the Covalent config file."
             )
 
-    async def _upload_task_to_s3(self, dispatch_id, node_id, function, args, kwargs) -> None:
+    def _upload_task_to_s3(self, dispatch_id, node_id, function, args, kwargs) -> None:
         """Upload task to S3."""
         s3 = boto3.Session(**self.boto_session_options()).client("s3")
         s3_object_filename = FUNC_FILENAME.format(dispatch_id=dispatch_id, node_id=node_id)
@@ -279,8 +279,9 @@ async def run(self, function: Callable, args: List, kwargs: Dict, task_metadata:
         self._debug_log(f"Successfully submitted task with ARN: {task_arn}")
 
         await self._poll_task(task_arn)
-        partial_func = partial(self.query_result, task_metadata)
-        return await _execute_partial_in_threadpool(partial_func)
+
+        self._debug_log("Querying result...")
+        return await self.query_result(task_metadata)
 
     async def get_status(self, task_arn: str) -> Tuple[str, int]:
         """Query the status of a previously submitted ECS task.
@@ -364,7 +365,6 @@ async def query_result(self, task_metadata: Dict) -> Tuple[Any, str, str]:
         Returns:
             result: The task's result, as a Python object.
""" - s3 = boto3.Session(**self.boto_session_options()).client("s3") dispatch_id = task_metadata["dispatch_id"] @@ -375,10 +375,14 @@ async def query_result(self, task_metadata: Dict) -> Tuple[Any, str, str]: self._debug_log( f"Downloading {result_filename} from bucket {self.s3_bucket_name} to local path ${local_result_filename}" ) - s3.download_file(self.s3_bucket_name, result_filename, local_result_filename) - with open(local_result_filename, "rb") as f: - result = pickle.load(f) - os.remove(local_result_filename) + partial_func = partial( + s3.download_file, self.s3_bucket_name, result_filename, local_result_filename + ) + await _execute_partial_in_threadpool(partial_func) + + result = await _execute_partial_in_threadpool( + partial(_load_pickle_file, local_result_filename) + ) return result async def cancel(self, task_arn: str, reason: str = "None") -> None: diff --git a/covalent_ecs_plugin/utils.py b/covalent_ecs_plugin/utils.py index c933398..740ebcc 100644 --- a/covalent_ecs_plugin/utils.py +++ b/covalent_ecs_plugin/utils.py @@ -21,8 +21,19 @@ """Helper methods for ECS executor plugin.""" import asyncio +import os + +import cloudpickle as pickle async def _execute_partial_in_threadpool(partial_func): loop = asyncio.get_running_loop() return await loop.run_in_executor(None, partial_func) + + +def _load_pickle_file(filename): + """Method to load the pickle file.""" + with open(filename, "rb") as f: + result = pickle.load(f) + os.remove(filename) + return result diff --git a/tests/test_ecs.py b/tests/test_ecs.py index 1c028e9..e3473ad 100644 --- a/tests/test_ecs.py +++ b/tests/test_ecs.py @@ -118,7 +118,7 @@ async def test_upload_file_to_s3(self, mock_executor, mocker): def some_function(): pass - await mock_executor._upload_task_to_s3( + mock_executor._upload_task_to_s3( some_function, self.MOCK_DISPATCH_ID, self.MOCK_NODE_ID, diff --git a/tests/test_utils.py b/tests/test_utils.py index 754d424..fa2f806 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -21,10 +21,12 @@ """Unit tests for AWS ECS executor utils file.""" from functools import partial +from pathlib import Path +import cloudpickle as pickle import pytest -from covalent_ecs_plugin.utils import _execute_partial_in_threadpool +from covalent_ecs_plugin.utils import _execute_partial_in_threadpool, _load_pickle_file @pytest.mark.asyncio @@ -37,3 +39,15 @@ def test_func(x): partial_func = partial(test_func, x=1) future = await _execute_partial_in_threadpool(partial_func) assert future == 1 + + +def test_load_pickle_file(mocker): + """Test the method used to load the pickled file and delete the file afterwards.""" + temp_fp = "/tmp/test.pkl" + with open(temp_fp, "wb") as f: + pickle.dump("test success", f) + + assert Path(temp_fp).exists() + res = _load_pickle_file(temp_fp) + assert res == "test success" + assert not Path(temp_fp).exists()