Skip to content

Commit

Permalink
Fix async bug (#43)
Browse files Browse the repository at this point in the history
* bug fix

* Update to ensure that default values are read in from the config file
  • Loading branch information
FyzHsn authored Oct 20, 2022
1 parent 3cc962b commit 6248f0a
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 13 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [UNRELEASED]

### Fixed

- Ensure that async functions are not being passed off to the threadpool.

## [0.16.0] - 2022-10-14

### Changed
Expand Down
26 changes: 15 additions & 11 deletions covalent_ecs_plugin/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from covalent._shared_files.logger import app_log
from covalent_aws_plugins import AWSExecutor

from .utils import _execute_partial_in_threadpool
from .utils import _execute_partial_in_threadpool, _load_pickle_file

_EXECUTOR_PLUGIN_DEFAULTS = {
"credentials": os.environ.get("AWS_SHARED_CREDENTIALS_FILE")
Expand Down Expand Up @@ -85,8 +85,8 @@ class ECSExecutor(AWSExecutor):

def __init__(
self,
s3_bucket_name: str,
ecs_task_security_group_id: str,
s3_bucket_name: str = None,
ecs_task_security_group_id: str = None,
ecs_cluster_name: str = None,
ecs_task_family_name: str = None,
ecs_task_execution_role_name: str = None,
Expand Down Expand Up @@ -148,7 +148,7 @@ def __init__(
f"{self.ecs_task_security_group_id} is not a valid security group id. Please set a valid security group id either in the ECS executor definition or in the Covalent config file."
)

async def _upload_task_to_s3(self, dispatch_id, node_id, function, args, kwargs) -> None:
def _upload_task_to_s3(self, dispatch_id, node_id, function, args, kwargs) -> None:
"""Upload task to S3."""
s3 = boto3.Session(**self.boto_session_options()).client("s3")
s3_object_filename = FUNC_FILENAME.format(dispatch_id=dispatch_id, node_id=node_id)
Expand Down Expand Up @@ -279,8 +279,9 @@ async def run(self, function: Callable, args: List, kwargs: Dict, task_metadata:
self._debug_log(f"Successfully submitted task with ARN: {task_arn}")

await self._poll_task(task_arn)
partial_func = partial(self.query_result, task_metadata)
return await _execute_partial_in_threadpool(partial_func)

self._debug_log("Querying result...")
return await self.query_result(task_metadata)

async def get_status(self, task_arn: str) -> Tuple[str, int]:
"""Query the status of a previously submitted ECS task.
Expand Down Expand Up @@ -364,7 +365,6 @@ async def query_result(self, task_metadata: Dict) -> Tuple[Any, str, str]:
Returns:
result: The task's result, as a Python object.
"""

s3 = boto3.Session(**self.boto_session_options()).client("s3")

dispatch_id = task_metadata["dispatch_id"]
Expand All @@ -375,10 +375,14 @@ async def query_result(self, task_metadata: Dict) -> Tuple[Any, str, str]:
self._debug_log(
f"Downloading {result_filename} from bucket {self.s3_bucket_name} to local path ${local_result_filename}"
)
s3.download_file(self.s3_bucket_name, result_filename, local_result_filename)
with open(local_result_filename, "rb") as f:
result = pickle.load(f)
os.remove(local_result_filename)
partial_func = partial(
s3.download_file, self.s3_bucket_name, result_filename, local_result_filename
)
await _execute_partial_in_threadpool(partial_func)

result = await _execute_partial_in_threadpool(
partial(_load_pickle_file, local_result_filename)
)
return result

async def cancel(self, task_arn: str, reason: str = "None") -> None:
Expand Down
11 changes: 11 additions & 0 deletions covalent_ecs_plugin/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,19 @@
"""Helper methods for ECS executor plugin."""

import asyncio
import os

import cloudpickle as pickle


async def _execute_partial_in_threadpool(partial_func):
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, partial_func)


def _load_pickle_file(filename):
"""Method to load the pickle file."""
with open(filename, "rb") as f:
result = pickle.load(f)
os.remove(filename)
return result
2 changes: 1 addition & 1 deletion tests/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def test_upload_file_to_s3(self, mock_executor, mocker):
def some_function():
pass

await mock_executor._upload_task_to_s3(
mock_executor._upload_task_to_s3(
some_function,
self.MOCK_DISPATCH_ID,
self.MOCK_NODE_ID,
Expand Down
16 changes: 15 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
"""Unit tests for AWS ECS executor utils file."""

from functools import partial
from pathlib import Path

import cloudpickle as pickle
import pytest

from covalent_ecs_plugin.utils import _execute_partial_in_threadpool
from covalent_ecs_plugin.utils import _execute_partial_in_threadpool, _load_pickle_file


@pytest.mark.asyncio
Expand All @@ -37,3 +39,15 @@ def test_func(x):
partial_func = partial(test_func, x=1)
future = await _execute_partial_in_threadpool(partial_func)
assert future == 1


def test_load_pickle_file(mocker):
    """Check that a pickled payload is loaded and the source file is deleted."""
    pickle_path = Path("/tmp/test.pkl")
    with open(pickle_path, "wb") as handle:
        pickle.dump("test success", handle)

    assert pickle_path.exists()

    loaded = _load_pickle_file("/tmp/test.pkl")

    assert loaded == "test success"
    assert not pickle_path.exists()

0 comments on commit 6248f0a

Please sign in to comment.