Reorg BaseExecutors (#1169)
* Reorganize BaseExecutors

* Move `log_stdout`, `log_stderr`, `cache_dir` to
`_AbstractBaseExecutor`

* Add `time_limit` and `retries` to `_AbstractBaseExecutor`

* Remove conda env logic from `BaseExecutor`

* Update changelog
cjao authored Sep 6, 2022
1 parent efb2ab2 commit d6a6a9b
Showing 4 changed files with 85 additions and 341 deletions.
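In short, the constructor state shared by both executor flavors now lives in the private base class. The following condensed sketch of the resulting hierarchy is inferred from the diff below, with method bodies elided; it is not verbatim source.

from abc import ABC

class _AbstractBaseExecutor(ABC):
    # Shared attributes for both executor flavors after this commit
    def __init__(self, log_stdout="", log_stderr="", cache_dir="",
                 time_limit=-1, retries=0, *args, **kwargs):
        self.log_stdout = log_stdout
        self.log_stderr = log_stderr
        self.cache_dir = cache_dir
        self.time_limit = time_limit
        self.retries = retries

class BaseExecutor(_AbstractBaseExecutor):
    def run(self, function, args, kwargs, task_metadata): ...        # sync plugins

class AsyncBaseExecutor(_AbstractBaseExecutor):
    async def run(self, function, args, kwargs, task_metadata): ...  # async plugins
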
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [UNRELEASED]

### Changed

- Refactored executor base classes

## [0.192.0] - 2022-09-02

### Authors
276 changes: 47 additions & 229 deletions covalent/executor/base.py
@@ -106,10 +106,33 @@ def wrapper_fn(

class _AbstractBaseExecutor(ABC):
    """
    Private class that contains attributes and methods common to both
    BaseExecutor and AsyncBaseExecutor
    Private parent class for BaseExecutor and AsyncBaseExecutor

    Attributes:
        log_stdout: The path to the file to be used for redirecting stdout.
        log_stderr: The path to the file to be used for redirecting stderr.
        cache_dir: The location used for cached files in the executor.
        time_limit: time limit for the task
        retries: Number of times to retry execution upon failure
    """

    def __init__(
        self,
        log_stdout: str = "",
        log_stderr: str = "",
        cache_dir: str = "",
        time_limit: int = -1,
        retries: int = 0,
        *args,
        **kwargs,
    ):
        self.log_stdout = log_stdout
        self.log_stderr = log_stderr
        self.cache_dir = cache_dir
        self.time_limit = time_limit
        self.retries = retries

    def get_dispatch_context(self, dispatch_info: DispatchInfo) -> ContextManager[DispatchInfo]:
        """
        Start a context manager that will be used to
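A practical consequence of the hunk above: plugin constructors no longer need to re-declare the common attributes, since forwarding to super().__init__ is enough. A minimal sketch, assuming a hypothetical plugin class and the run() signature that execute() uses further down:

class MyExecutorPlugin(BaseExecutor):
    def __init__(self, device: str = "cpu", *args, **kwargs):
        self.device = device               # plugin-specific option (hypothetical)
        super().__init__(*args, **kwargs)  # inherits log_stdout, cache_dir, time_limit, retries, ...

    def run(self, function, args, kwargs, task_metadata):
        # A real plugin would dispatch the task to its backend here
        return function(*args, **kwargs)

ex = MyExecutorPlugin(device="gpu", cache_dir="/tmp/covalent", retries=2)
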
@@ -158,36 +181,22 @@ class BaseExecutor(_AbstractBaseExecutor):
    plugin. Subclassing this class will allow you to define
    your own executor plugin which can be used in covalent.

    Note: When using a conda environment, it is assumed that
    covalent, with all its dependencies, is also installed in
    that environment.

    Attributes:
        log_stdout: The path to the file to be used for redirecting stdout.
        log_stderr: The path to the file to be used for redirecting stderr.
        conda_env: The name of the Conda environment to be used.
        cache_dir: The location used for cached files in the executor.
        current_env_on_conda_fail: If True, the current environment will be used
            if conda fails to activate the specified env.
        time_limit: time limit for the task
        retries: Number of times to retry execution upon failure
    """

    def __init__(
        self,
        log_stdout: str = "",
        log_stderr: str = "",
        conda_env: str = "",
        cache_dir: str = "",
        current_env_on_conda_fail: bool = False,
        *args,
        **kwargs,
    ) -> None:

        self.log_stdout = log_stdout
        self.log_stderr = log_stderr
        self.conda_env = conda_env
        self.cache_dir = cache_dir
        self.current_env_on_conda_fail = current_env_on_conda_fail
        self.current_env = ""
        super().__init__(*args, **kwargs)

    def write_streams_to_file(
        self,
@@ -261,23 +270,9 @@ def execute(
            io.StringIO()
        ) as stdout, redirect_stderr(io.StringIO()) as stderr:

            if self.conda_env != "":
                result = None

                result = self.execute_in_conda_env(
                    function,
                    fn_version,
                    args,
                    kwargs,
                    self.conda_env,
                    self.cache_dir,
                    node_id,
                )

            else:
                self.setup(task_metadata=task_metadata)
                result = self.run(function, args, kwargs, task_metadata)
                self.teardown(task_metadata=task_metadata)
            self.setup(task_metadata=task_metadata)
            result = self.run(function, args, kwargs, task_metadata)
            self.teardown(task_metadata=task_metadata)

            self.write_streams_to_file(
                (stdout.getvalue(), stderr.getvalue()),
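This hunk is the heart of the refactor: execute() now runs the setup()/run()/teardown() lifecycle unconditionally instead of branching on conda_env. A standalone sketch of the simplified control flow, with a hypothetical helper name:

import io
from contextlib import redirect_stderr, redirect_stdout

def run_with_lifecycle(executor, function, args, kwargs, task_metadata):
    # Capture the task's stdout/stderr, as execute() does above
    with redirect_stdout(io.StringIO()) as stdout, redirect_stderr(io.StringIO()) as stderr:
        executor.setup(task_metadata=task_metadata)
        result = executor.run(function, args, kwargs, task_metadata)
        executor.teardown(task_metadata=task_metadata)
    return result, stdout.getvalue(), stderr.getvalue()
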
@@ -313,209 +308,32 @@ def teardown(self, task_metadata: Dict) -> Any:
"""Placeholder to run nay executor specific cleanup/teardown actions"""
pass

    def execute_in_conda_env(
        self,
        fn: Callable,
        fn_version: str,
        args: List,
        kwargs: Dict,
        conda_env: str,
        cache_dir: str,
        node_id: int,
    ) -> Tuple[bool, Any]:
        """
        Execute the function with the given arguments, in a Conda environment.

        Args:
            fn: The input python function which will be executed and whose result
                is ultimately returned by this function.
            fn_version: The python version the function was created with.
            args: List of positional arguments to be used by the function.
            kwargs: Dictionary of keyword arguments to be used by the function.
            conda_env: Name of a Conda environment in which to execute the task.
            cache_dir: The directory where temporary files and logs (if any) are stored.
            node_id: The integer identifier for the current node.

        Returns:
            output: The result of the function execution.
        """

        if not self.get_conda_path():
            return self._on_conda_env_fail(fn, args, kwargs, node_id)

        # Pickle the function
        temp_filename = ""
        with tempfile.NamedTemporaryFile(dir=cache_dir, delete=False) as f:
            pickle.dump(fn, f)
            temp_filename = f.name

        result_filename = os.path.join(cache_dir, f'result_{temp_filename.split("/")[-1]}')

        # Write a bash script to activate the environment
        shell_commands = "#!/bin/bash\n"

        # Add commands to initialize the Conda shell and activate the environment:
        conda_sh = os.path.join(
            os.path.dirname(self.conda_path), "..", "etc", "profile.d", "conda.sh"
        )
        conda_sh = os.environ.get("CONDA_SHELL", conda_sh)
        if os.path.exists(conda_sh):
            shell_commands += f"source {conda_sh}\n"
        else:
            message = "No Conda installation found on this compute node."
            app_log.warning(message)
            return self._on_conda_env_fail(fn, args, kwargs, node_id)

        shell_commands += f"conda activate {conda_env}\n"
        shell_commands += "retval=$?\n"
        shell_commands += "if [ $retval -ne 0 ]; then\n"
        shell_commands += (
            f'  echo "Conda environment {conda_env} is not present on this compute node."\n'
        )
        shell_commands += '  echo "Please create that environment (or use an existing environment) and try again."\n'
        shell_commands += "  exit 99\n"
        shell_commands += "fi\n\n"

        # Check Python version and give a warning if there is a mismatch:
        shell_commands += "py_version=`python -V | awk '{{print $2}}'`\n"
        shell_commands += f'if [[ "{fn_version}" != "$py_version" ]]; then\n'
        shell_commands += '  echo "Warning: Python version mismatch:"\n'
        shell_commands += f'  echo "Workflow version is {fn_version}. Conda environment version is $py_version."\n'
        shell_commands += "fi\n\n"

        shell_commands += "python - <<EOF\n"
        shell_commands += "import cloudpickle as pickle\n"
        shell_commands += "import os\n\n"

        # Add Python commands to run the pickled function:
        shell_commands += f'with open("{temp_filename}", "rb") as f:\n'
        shell_commands += "    fn = pickle.load(f)\n\n"

        shell_commands += f"result = fn(*{args}, **{kwargs})\n\n"

        shell_commands += f'with open("{result_filename}", "wb") as f:\n'
        shell_commands += "    pickle.dump(result, f)\n"
        shell_commands += "EOF\n"

        # Remove the temp file
        os.remove(temp_filename)

        # Run the script and unpickle the result
        with tempfile.NamedTemporaryFile(dir=cache_dir, mode="w") as f:
            f.write(shell_commands)
            f.flush()

            out = subprocess.run(["bash", f.name], capture_output=True, encoding="utf-8")

        if len(out.stdout) != 0:
            # These are print/log statements from the task.
            print(out.stdout)

        if out.returncode != 0:
            app_log.warning(out.stderr)
            return self._on_conda_env_fail(fn, args, kwargs, node_id)

        with open(result_filename, "rb") as f:
            result = pickle.load(f)

        message = f"Executed node {node_id} on Conda environment {self.conda_env}."
        app_log.debug(message)
        return result
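Stripped of the Conda activation, the removed helper reduces to a pickle round-trip through a subprocess: serialize the function, run it in another interpreter, deserialize the result. A self-contained sketch of that pattern, assuming cloudpickle is installed; the helper and file naming are illustrative, not the original's:

import subprocess
import sys
import tempfile

import cloudpickle as pickle

def run_in_subprocess(fn, args, kwargs, cache_dir="/tmp"):
    # Serialize the function to a temp file the child process can read
    with tempfile.NamedTemporaryFile(dir=cache_dir, delete=False, suffix=".pkl") as f:
        pickle.dump(fn, f)
        fn_path = f.name
    result_path = fn_path + ".result"
    script = (
        "import cloudpickle as pickle\n"
        f"fn = pickle.load(open({fn_path!r}, 'rb'))\n"
        f"result = fn(*{args!r}, **{kwargs!r})\n"
        f"pickle.dump(result, open({result_path!r}, 'wb'))\n"
    )
    subprocess.run([sys.executable, "-c", script], check=True)
    with open(result_path, "rb") as f:
        return pickle.load(f)

assert run_in_subprocess(lambda x, y: x + y, (1, 2), {}) == 3
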

    def _on_conda_env_fail(self, fn: Callable, args: List, kwargs: Dict, node_id: int):
        """
        Args:
            fn: The input python function which will be executed and
                whose result may be returned by this function.
            args: List of positional arguments to be used by the function.
            kwargs: Dictionary of keyword arguments to be used by the function.
            node_id: The integer identifier for the current node.

        Returns:
            output: The result of the function execution if
                self.current_env_on_conda_fail == True; otherwise, the return value is None.
        """

        result = None
        message = f"Failed to execute node {node_id} on Conda environment {self.conda_env}."
        if self.current_env_on_conda_fail:
            message += "\nExecuting on the current Conda environment."
            app_log.warning(message)
            result = fn(*args, **kwargs)

        else:
            app_log.error(message)
            raise RuntimeError

        return result

    def get_conda_envs(self) -> None:
        """
        Collect the list of Conda environments detected on the system.

        Args:
            None

        Returns:
            None
        """

        self.conda_envs = []

        env_output = subprocess.run(
            ["conda", "env", "list"], capture_output=True, encoding="utf-8"
        )

        if len(env_output.stderr) > 0:
            message = f"Problem in listing Conda environments:\n{env_output.stderr}"
            app_log.warning(message)
            return

        for line in env_output.stdout.split("\n"):
            if not line.startswith("#"):
                row = line.split()
                if len(row) > 1:
                    if "*" in row:
                        self.current_env = row[0]
                    self.conda_envs.append(row[0])

        app_log.debug(f"Conda environments:\n{self.conda_envs}")
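For reference, the loop above targets the stock `conda env list` layout, where an asterisk marks the active environment. A quick check of the same parsing logic against a representative sample (the sample text and paths are illustrative):

sample = """\
# conda environments:
#
base                  *  /opt/conda
covalent                 /opt/conda/envs/covalent
"""

conda_envs, current_env = [], ""
for line in sample.split("\n"):
    if not line.startswith("#"):
        row = line.split()
        if len(row) > 1:
            if "*" in row:
                current_env = row[0]
            conda_envs.append(row[0])

assert conda_envs == ["base", "covalent"]
assert current_env == "base"
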

    def get_conda_path(self) -> bool:
        """
        Query the path where the conda executable can be found.

        Args:
            None

        Returns:
            found: True if Conda is found on the system.
        """

        self.conda_path = ""
        which_conda = subprocess.run(
            ["which", "conda"], capture_output=True, encoding="utf-8"
        ).stdout
        if which_conda == "":
            message = "No Conda installation found on this compute node."
            app_log.warning(message)
            return False
        self.conda_path = which_conda
        return True


class AsyncBaseExecutor(_AbstractBaseExecutor):
    """Async base executor class to be used for defining any executor
    plugin. Subclassing this class will allow you to define
    your own executor plugin which can be used in covalent.

    This is analogous to `BaseExecutor` except the `run()` method,
    together with the optional `setup()` and `teardown()` methods, are
    coroutines.

    Attributes:
        log_stdout: The path to the file to be used for redirecting stdout.
        log_stderr: The path to the file to be used for redirecting stderr.
        cache_dir: The location used for cached files in the executor.
        time_limit: time limit for the task
        retries: Number of times to retry execution upon failure
    """

    def __init__(
        self,
        log_stdout: str = "",
        log_stderr: str = "",
        *args,
        **kwargs,
    ) -> None:

        self.log_stdout = log_stdout
        self.log_stderr = log_stderr
        super().__init__(*args, **kwargs)

    async def write_streams_to_file(
        self,
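To round off the async half of the refactor: an AsyncBaseExecutor plugin now looks just like its sync counterpart except that the lifecycle hooks are coroutines, per the docstring above. A minimal sketch with a hypothetical plugin name; the run() signature mirrors the sync diff:

import asyncio

class MyAsyncExecutorPlugin(AsyncBaseExecutor):
    async def run(self, function, args, kwargs, task_metadata):
        # Offload the blocking call so the event loop stays responsive
        return await asyncio.get_running_loop().run_in_executor(
            None, lambda: function(*args, **kwargs)
        )

async def main():
    ex = MyAsyncExecutorPlugin(log_stdout="stdout.log", retries=1)
    print(await ex.run(sum, ([1, 2, 3],), {}, task_metadata={}))

asyncio.run(main())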