Skip to content

Commit

Permalink
Move execute_task() to a dedicated module
Browse files Browse the repository at this point in the history
Multiple executors use the `execute_task()` function, so moving it to
its own module improves code organization and reusability.

Also removed MPI-related code from `execute_task()`, as it's specific to
the HTEX.
  • Loading branch information
rjmello committed Nov 19, 2024
1 parent 9fb5269 commit ca0bbbe
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 43 deletions.
36 changes: 36 additions & 0 deletions parsl/executors/execute_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os

from parsl.serialize import unpack_res_spec_apply_message


def execute_task(bufs: bytes):
"""Deserialize the buffer and execute the task.
Returns the result or throws exception.
"""
user_ns = locals()
user_ns.update({'__builtins__': __builtins__})

f, args, kwargs, resource_spec = unpack_res_spec_apply_message(bufs, copy=False)

for varname in resource_spec:
envname = "PARSL_" + str(varname).upper()
os.environ[envname] = str(resource_spec[varname])

# We might need to look into callability of the function from itself
# since we change it's name in the new namespace
prefix = "parsl_"
fname = prefix + "f"
argname = prefix + "args"
kwargname = prefix + "kwargs"
resultname = prefix + "result"

user_ns.update({fname: f,
argname: args,
kwargname: kwargs,
resultname: resultname})

code = "{0} = {1}(*{2}, **{3})".format(resultname, fname,
argname, kwargname)
exec(code, user_ns, user_ns)
return user_ns.get(resultname)
2 changes: 1 addition & 1 deletion parsl/executors/flux/execute_parsl_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import logging
import os

from parsl.executors.execute_task import execute_task
from parsl.executors.flux import TaskResult
from parsl.executors.high_throughput.process_worker_pool import execute_task
from parsl.serialize import serialize


Expand Down
54 changes: 13 additions & 41 deletions parsl/executors/high_throughput/process_worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from parsl import curvezmq
from parsl.app.errors import RemoteExceptionWrapper
from parsl.executors.execute_task import execute_task
from parsl.executors.high_throughput.errors import WorkerLost
from parsl.executors.high_throughput.mpi_prefix_composer import (
VALID_LAUNCHERS,
Expand All @@ -35,7 +36,7 @@
from parsl.executors.high_throughput.probe import probe_addresses
from parsl.multiprocessing import SpawnContext
from parsl.process_loggers import wrap_with_logs
from parsl.serialize import serialize, unpack_res_spec_apply_message
from parsl.serialize import serialize
from parsl.version import VERSION as PARSL_VERSION

HEARTBEAT_CODE = (2 ** 32) - 1
Expand Down Expand Up @@ -590,45 +591,13 @@ def update_resource_spec_env_vars(mpi_launcher: str, resource_spec: Dict, node_i
os.environ[key] = prefix_table[key]


def execute_task(bufs, mpi_launcher: Optional[str] = None):
"""Deserialize the buffer and execute the task.
Returns the result or throws exception.
"""
user_ns = locals()
user_ns.update({'__builtins__': __builtins__})

f, args, kwargs, resource_spec = unpack_res_spec_apply_message(bufs, user_ns, copy=False)

for varname in resource_spec:
envname = "PARSL_" + str(varname).upper()
os.environ[envname] = str(resource_spec[varname])

if resource_spec.get("MPI_NODELIST"):
worker_id = os.environ['PARSL_WORKER_RANK']
nodes_for_task = resource_spec["MPI_NODELIST"].split(',')
logger.info(f"Launching task on provisioned nodes: {nodes_for_task}")
assert mpi_launcher
update_resource_spec_env_vars(mpi_launcher,
resource_spec=resource_spec,
node_info=nodes_for_task)
# We might need to look into callability of the function from itself
# since we change it's name in the new namespace
prefix = "parsl_"
fname = prefix + "f"
argname = prefix + "args"
kwargname = prefix + "kwargs"
resultname = prefix + "result"

user_ns.update({fname: f,
argname: args,
kwargname: kwargs,
resultname: resultname})

code = "{0} = {1}(*{2}, **{3})".format(resultname, fname,
argname, kwargname)
exec(code, user_ns, user_ns)
return user_ns.get(resultname)
def _init_mpi_env(mpi_launcher: str, resource_spec: Dict):
node_list = resource_spec.get("MPI_NODELIST")
if node_list is None:
return
nodes_for_task = node_list.split(',')
logger.info(f"Launching task on provisioned nodes: {nodes_for_task}")
update_resource_spec_env_vars(mpi_launcher=mpi_launcher, resource_spec=resource_spec, node_info=nodes_for_task)


@wrap_with_logs(target="worker_log")
Expand Down Expand Up @@ -786,8 +755,11 @@ def manager_is_alive():
ready_worker_count.value -= 1
worker_enqueued = False

resource_spec = req["resource_spec"]
_init_mpi_env(mpi_launcher=mpi_launcher, resource_spec=resource_spec)

try:
result = execute_task(req['buffer'], mpi_launcher=mpi_launcher)
result = execute_task(req['buffer'])
serialized_result = serialize(result, buffer_threshold=1000000)
except Exception as e:
logger.info('Caught an exception: {}'.format(e))
Expand Down
2 changes: 1 addition & 1 deletion parsl/executors/radical/rpex_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import parsl.app.errors as pe
from parsl.app.bash import remote_side_bash_executor
from parsl.executors.high_throughput.process_worker_pool import execute_task
from parsl.executors.execute_task import execute_task
from parsl.serialize import serialize, unpack_res_spec_apply_message


Expand Down
29 changes: 29 additions & 0 deletions parsl/tests/test_execute_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os

import pytest

from parsl.executors.execute_task import execute_task
from parsl.serialize.facade import pack_res_spec_apply_message


def addemup(*args: int, name: str = "apples"):
total = sum(args)
return f"{total} {name}"


@pytest.mark.local
def test_execute_task():
args = (1, 2, 3)
kwargs = {"name": "boots"}
buff = pack_res_spec_apply_message(addemup, args, kwargs, {})
res = execute_task(buff)
assert res == addemup(*args, **kwargs)


@pytest.mark.local
def test_execute_task_resource_spec():
resource_spec = {"num_nodes": 2, "ranks_per_node": 2, "num_ranks": 4}
buff = pack_res_spec_apply_message(addemup, (1, 2), {}, resource_spec)
execute_task(buff)
for key, val in resource_spec.items():
assert os.environ[f"PARSL_{key.upper()}"] == str(val)

0 comments on commit ca0bbbe

Please sign in to comment.