diff --git a/parsl/dataflow/dflow.py b/parsl/dataflow/dflow.py
index 83ea2e31cf..6cec168b5d 100644
--- a/parsl/dataflow/dflow.py
+++ b/parsl/dataflow/dflow.py
@@ -111,8 +111,6 @@ def __init__(self, config: Config) -> None:
         self.monitoring = config.monitoring
 
         if self.monitoring:
-            if self.monitoring.logdir is None:
-                self.monitoring.logdir = self.run_dir
             self.monitoring.start(self.run_dir, self.config.run_dir)
 
         self.time_began = datetime.datetime.now()
diff --git a/parsl/executors/execute_task.py b/parsl/executors/execute_task.py
new file mode 100644
index 0000000000..41c2f3cc9b
--- /dev/null
+++ b/parsl/executors/execute_task.py
@@ -0,0 +1,37 @@
+import os
+
+from parsl.serialize import unpack_res_spec_apply_message
+
+
+def execute_task(bufs: bytes):
+    """Deserialize the buffer and execute the task.
+    Returns the result or throws exception.
+    """
+    f, args, kwargs, resource_spec = unpack_res_spec_apply_message(bufs)
+
+    for varname in resource_spec:
+        envname = "PARSL_" + str(varname).upper()
+        os.environ[envname] = str(resource_spec[varname])
+
+    # We might need to look into callability of the function from itself
+    # since we change its name in the new namespace
+    prefix = "parsl_"
+    fname = prefix + "f"
+    argname = prefix + "args"
+    kwargname = prefix + "kwargs"
+    resultname = prefix + "result"
+
+    code = "{0} = {1}(*{2}, **{3})".format(resultname, fname,
+                                           argname, kwargname)
+
+    user_ns = locals()
+    user_ns.update({
+        '__builtins__': __builtins__,
+        fname: f,
+        argname: args,
+        kwargname: kwargs,
+        resultname: resultname
+    })
+
+    exec(code, user_ns, user_ns)
+    return user_ns.get(resultname)
diff --git a/parsl/executors/flux/execute_parsl_task.py b/parsl/executors/flux/execute_parsl_task.py
index ddf3c67e14..4372578e26 100644
--- a/parsl/executors/flux/execute_parsl_task.py
+++ b/parsl/executors/flux/execute_parsl_task.py
@@ -4,8 +4,8 @@
 import logging
 import os
 
+from parsl.executors.execute_task import execute_task
 from parsl.executors.flux import TaskResult
-from parsl.executors.high_throughput.process_worker_pool import execute_task
 from parsl.serialize import serialize
 
 
diff --git a/parsl/executors/high_throughput/mpi_resource_management.py b/parsl/executors/high_throughput/mpi_resource_management.py
index 3f3fc33ea4..ac02e0c419 100644
--- a/parsl/executors/high_throughput/mpi_resource_management.py
+++ b/parsl/executors/high_throughput/mpi_resource_management.py
@@ -160,9 +160,7 @@ def put_task(self, task_package: dict):
         """Schedule task if resources are available otherwise backlog the task"""
         user_ns = locals()
         user_ns.update({"__builtins__": __builtins__})
-        _f, _args, _kwargs, resource_spec = unpack_res_spec_apply_message(
-            task_package["buffer"], user_ns, copy=False
-        )
+        _f, _args, _kwargs, resource_spec = unpack_res_spec_apply_message(task_package["buffer"])
 
         nodes_needed = resource_spec.get("num_nodes")
         if nodes_needed:
@@ -177,6 +175,7 @@ def put_task(self, task_package: dict):
             self._map_tasks_to_nodes[task_package["task_id"]] = allocated_nodes
             buffer = pack_res_spec_apply_message(_f, _args, _kwargs, resource_spec)
             task_package["buffer"] = buffer
+            task_package["resource_spec"] = resource_spec
 
         self.pending_task_q.put(task_package)
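For context, a minimal sketch of how the relocated `execute_task` helper is exercised end to end; it mirrors the tests added at the bottom of this patch, and the task function and values here are illustrative rather than part of the change:

```python
import os

from parsl.executors.execute_task import execute_task
from parsl.serialize.facade import pack_res_spec_apply_message


def addemup(*args: int, name: str = "apples"):
    # Illustrative task function; any serializable callable works.
    return f"{sum(args)} {name}"


# Pack the callable, its arguments and a resource specification into a buffer,
# as an executor does before shipping the task to a worker.
buf = pack_res_spec_apply_message(addemup, (1, 2, 3), {"name": "boots"}, {"num_nodes": 2})

# execute_task unpacks the buffer, exports PARSL_NUM_NODES=2 into the
# environment, runs the function and returns its result.
assert execute_task(buf) == "6 boots"
assert os.environ["PARSL_NUM_NODES"] == "2"
```

MPI-specific environment setup is deliberately absent from this helper: it moves into `_init_mpi_env` in the process worker pool, in the next file.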
diff --git a/parsl/executors/high_throughput/process_worker_pool.py b/parsl/executors/high_throughput/process_worker_pool.py
index e75af86743..a8bbaa9be8 100755
--- a/parsl/executors/high_throughput/process_worker_pool.py
+++ b/parsl/executors/high_throughput/process_worker_pool.py
@@ -23,6 +23,7 @@
 
 from parsl import curvezmq
 from parsl.app.errors import RemoteExceptionWrapper
+from parsl.executors.execute_task import execute_task
 from parsl.executors.high_throughput.errors import WorkerLost
 from parsl.executors.high_throughput.mpi_prefix_composer import (
     VALID_LAUNCHERS,
@@ -35,7 +36,7 @@
 from parsl.executors.high_throughput.probe import probe_addresses
 from parsl.multiprocessing import SpawnContext
 from parsl.process_loggers import wrap_with_logs
-from parsl.serialize import serialize, unpack_res_spec_apply_message
+from parsl.serialize import serialize
 from parsl.version import VERSION as PARSL_VERSION
 
 HEARTBEAT_CODE = (2 ** 32) - 1
@@ -590,45 +591,13 @@ def update_resource_spec_env_vars(mpi_launcher: str, resource_spec: Dict, node_i
         os.environ[key] = prefix_table[key]
 
 
-def execute_task(bufs, mpi_launcher: Optional[str] = None):
-    """Deserialize the buffer and execute the task.
-
-    Returns the result or throws exception.
-    """
-    user_ns = locals()
-    user_ns.update({'__builtins__': __builtins__})
-
-    f, args, kwargs, resource_spec = unpack_res_spec_apply_message(bufs, user_ns, copy=False)
-
-    for varname in resource_spec:
-        envname = "PARSL_" + str(varname).upper()
-        os.environ[envname] = str(resource_spec[varname])
-
-    if resource_spec.get("MPI_NODELIST"):
-        worker_id = os.environ['PARSL_WORKER_RANK']
-        nodes_for_task = resource_spec["MPI_NODELIST"].split(',')
-        logger.info(f"Launching task on provisioned nodes: {nodes_for_task}")
-        assert mpi_launcher
-        update_resource_spec_env_vars(mpi_launcher,
-                                      resource_spec=resource_spec,
-                                      node_info=nodes_for_task)
-    # We might need to look into callability of the function from itself
-    # since we change it's name in the new namespace
-    prefix = "parsl_"
-    fname = prefix + "f"
-    argname = prefix + "args"
-    kwargname = prefix + "kwargs"
-    resultname = prefix + "result"
-
-    user_ns.update({fname: f,
-                    argname: args,
-                    kwargname: kwargs,
-                    resultname: resultname})
-
-    code = "{0} = {1}(*{2}, **{3})".format(resultname, fname,
-                                           argname, kwargname)
-    exec(code, user_ns, user_ns)
-    return user_ns.get(resultname)
+def _init_mpi_env(mpi_launcher: str, resource_spec: Dict):
+    node_list = resource_spec.get("MPI_NODELIST")
+    if node_list is None:
+        return
+    nodes_for_task = node_list.split(',')
+    logger.info(f"Launching task on provisioned nodes: {nodes_for_task}")
+    update_resource_spec_env_vars(mpi_launcher=mpi_launcher, resource_spec=resource_spec, node_info=nodes_for_task)
 
 
 @wrap_with_logs(target="worker_log")
@@ -786,8 +755,10 @@ def manager_is_alive():
                     ready_worker_count.value -= 1
                     worker_enqueued = False
 
+                _init_mpi_env(mpi_launcher=mpi_launcher, resource_spec=req["resource_spec"])
+
                 try:
-                    result = execute_task(req['buffer'], mpi_launcher=mpi_launcher)
+                    result = execute_task(req['buffer'])
                     serialized_result = serialize(result, buffer_threshold=1000000)
                 except Exception as e:
                     logger.info('Caught an exception: {}'.format(e))
diff --git a/parsl/executors/radical/rpex_worker.py b/parsl/executors/radical/rpex_worker.py
index db1b7d2bea..90243b558a 100644
--- a/parsl/executors/radical/rpex_worker.py
+++ b/parsl/executors/radical/rpex_worker.py
@@ -4,7 +4,7 @@
 
 import parsl.app.errors as pe
 from parsl.app.bash import remote_side_bash_executor
-from parsl.executors.high_throughput.process_worker_pool import execute_task
+from parsl.executors.execute_task import execute_task
 from parsl.serialize import serialize, unpack_res_spec_apply_message
 
 
@@ -33,7 +33,7 @@ def _dispatch_proc(self, task):
 
         try:
             buffer = rp.utils.deserialize_bson(task['description']['executable'])
-            func, args, kwargs, _resource_spec = unpack_res_spec_apply_message(buffer, {}, copy=False)
+            func, args, kwargs, _resource_spec = unpack_res_spec_apply_message(buffer)
             ret = remote_side_bash_executor(func, *args, **kwargs)
             exc = (None, None)
             val = None
diff --git a/parsl/executors/workqueue/exec_parsl_function.py b/parsl/executors/workqueue/exec_parsl_function.py
index 06a86a5c8d..d19d92efe6 100644
--- a/parsl/executors/workqueue/exec_parsl_function.py
+++ b/parsl/executors/workqueue/exec_parsl_function.py
@@ -94,7 +94,7 @@ def unpack_source_code_function(function_info, user_namespace):
 
 
 def unpack_byte_code_function(function_info, user_namespace):
     from parsl.serialize import unpack_apply_message
-    func, args, kwargs = unpack_apply_message(function_info["byte code"], user_namespace, copy=False)
+    func, args, kwargs = unpack_apply_message(function_info["byte code"])
     return (func, 'parsl_function_name', args, kwargs)
 
diff --git a/parsl/monitoring/db_manager.py b/parsl/monitoring/db_manager.py
index abdb038e79..8c1d76dbc6 100644
--- a/parsl/monitoring/db_manager.py
+++ b/parsl/monitoring/db_manager.py
@@ -279,7 +279,7 @@ class Resource(Base):
 class DatabaseManager:
     def __init__(self,
                  db_url: str = 'sqlite:///runinfo/monitoring.db',
-                 logdir: str = '.',
+                 run_dir: str = '.',
                  logging_level: int = logging.INFO,
                  batching_interval: float = 1,
                  batching_threshold: float = 99999,
@@ -287,12 +287,12 @@ def __init__(self,
 
         self.workflow_end = False
         self.workflow_start_message: Optional[MonitoringMessage] = None
-        self.logdir = logdir
-        os.makedirs(self.logdir, exist_ok=True)
+        self.run_dir = run_dir
+        os.makedirs(self.run_dir, exist_ok=True)
 
         logger.propagate = False
 
-        set_file_logger("{}/database_manager.log".format(self.logdir), level=logging_level,
+        set_file_logger(f"{self.run_dir}/database_manager.log", level=logging_level,
                         format_string="%(asctime)s.%(msecs)03d %(name)s:%(lineno)d [%(levelname)s] [%(threadName)s %(thread)d] %(message)s",
                         name="database_manager")
 
@@ -681,7 +681,7 @@ def close(self) -> None:
 def dbm_starter(exception_q: mpq.Queue,
                 resource_msgs: mpq.Queue,
                 db_url: str,
-                logdir: str,
+                run_dir: str,
                 logging_level: int) -> None:
     """Start the database manager process
 
@@ -692,7 +692,7 @@ def dbm_starter(exception_q: mpq.Queue,
     try:
         dbm = DatabaseManager(db_url=db_url,
-                              logdir=logdir,
+                              run_dir=run_dir,
                               logging_level=logging_level)
         logger.info("Starting dbm in dbm starter")
         dbm.start(resource_msgs)
diff --git a/parsl/monitoring/monitoring.py b/parsl/monitoring/monitoring.py
index a1b20f2705..e82c8fb688 100644
--- a/parsl/monitoring/monitoring.py
+++ b/parsl/monitoring/monitoring.py
@@ -44,7 +44,6 @@ def __init__(self,
                  workflow_name: Optional[str] = None,
                  workflow_version: Optional[str] = None,
                  logging_endpoint: Optional[str] = None,
-                 logdir: Optional[str] = None,
                  monitoring_debug: bool = False,
                  resource_monitoring_enabled: bool = True,
                  resource_monitoring_interval: float = 30):  # in seconds
@@ -73,8 +72,6 @@ def __init__(self,
            The database connection url for monitoring to log the information.
            These URLs follow RFC-1738, and can include username, password, hostname, database name.
            Default: sqlite, in the configured run_dir.
-        logdir : str
-            Parsl log directory paths. Logs and temp files go here. Default: '.'
         monitoring_debug : Bool
             Enable monitoring debug logging. Default: False
         resource_monitoring_enabled : boolean
@@ -96,7 +93,6 @@ def __init__(self,
         self.hub_port_range = hub_port_range
 
         self.logging_endpoint = logging_endpoint
-        self.logdir = logdir
         self.monitoring_debug = monitoring_debug
 
         self.workflow_name = workflow_name
@@ -109,13 +105,10 @@ def start(self, dfk_run_dir: str, config_run_dir: Union[str, os.PathLike]) -> No
 
         logger.debug("Starting MonitoringHub")
 
-        if self.logdir is None:
-            self.logdir = "."
-
         if self.logging_endpoint is None:
             self.logging_endpoint = f"sqlite:///{os.fspath(config_run_dir)}/monitoring.db"
 
-        os.makedirs(self.logdir, exist_ok=True)
+        os.makedirs(dfk_run_dir, exist_ok=True)
 
         self.monitoring_hub_active = True
 
@@ -151,7 +144,7 @@ def start(self, dfk_run_dir: str, config_run_dir: Union[str, os.PathLike]) -> No
                                                "hub_address": self.hub_address,
                                                "udp_port": self.hub_port,
                                                "zmq_port_range": self.hub_port_range,
-                                               "logdir": self.logdir,
+                                               "run_dir": dfk_run_dir,
                                                "logging_level": logging.DEBUG if self.monitoring_debug else logging.INFO,
                                                },
                                        name="Monitoring-Router-Process",
@@ -161,7 +154,7 @@ def start(self, dfk_run_dir: str, config_run_dir: Union[str, os.PathLike]) -> No
 
         self.dbm_proc = ForkProcess(target=dbm_starter,
                                     args=(self.exception_q, self.resource_msgs,),
-                                    kwargs={"logdir": self.logdir,
+                                    kwargs={"run_dir": dfk_run_dir,
                                             "logging_level": logging.DEBUG if self.monitoring_debug else logging.INFO,
                                             "db_url": self.logging_endpoint,
                                             },
@@ -172,7 +165,7 @@ def start(self, dfk_run_dir: str, config_run_dir: Union[str, os.PathLike]) -> No
         logger.info("Started the router process %s and DBM process %s", self.router_proc.pid, self.dbm_proc.pid)
 
         self.filesystem_proc = ForkProcess(target=filesystem_receiver,
-                                           args=(self.logdir, self.resource_msgs, dfk_run_dir),
+                                           args=(self.resource_msgs, dfk_run_dir),
                                            name="Monitoring-Filesystem-Process",
                                            daemon=True
                                            )
@@ -258,8 +251,8 @@ def close(self) -> None:
 
 
 @wrap_with_logs
-def filesystem_receiver(logdir: str, q: Queue[TaggedMonitoringMessage], run_dir: str) -> None:
-    logger = set_file_logger("{}/monitoring_filesystem_radio.log".format(logdir),
+def filesystem_receiver(q: Queue[TaggedMonitoringMessage], run_dir: str) -> None:
+    logger = set_file_logger(f"{run_dir}/monitoring_filesystem_radio.log",
                              name="monitoring_filesystem_radio",
                              level=logging.INFO)
 
@@ -270,6 +263,8 @@ def filesystem_receiver(logdir: str, q: Queue[TaggedMonitoringMessage], run_dir:
     new_dir = f"{base_path}/new/"
     logger.debug("Creating new and tmp paths under %s", base_path)
 
+    target_radio = MultiprocessingQueueRadioSender(q)
+
     os.makedirs(tmp_dir, exist_ok=True)
     os.makedirs(new_dir, exist_ok=True)
 
@@ -285,7 +280,7 @@ def filesystem_receiver(logdir: str, q: Queue[TaggedMonitoringMessage], run_dir:
                         message = pickle.load(f)
                     logger.debug("Message received is: %s", message)
                     assert isinstance(message, tuple)
-                    q.put(cast(TaggedMonitoringMessage, message))
+                    target_radio.send(cast(TaggedMonitoringMessage, message))
                     os.remove(full_path_filename)
                 except Exception:
                     logger.exception("Exception processing %s - probably will be retried next iteration", filename)
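With the `logdir` knob removed, monitoring output follows the DataFlowKernel run directory. A sketch of a configuration under the new behaviour, with addresses and intervals chosen purely for illustration:

```python
from parsl.config import Config
from parsl.monitoring.monitoring import MonitoringHub

# No logdir argument any more: the router, database manager and filesystem
# radio logs are written under the DFK run_dir, and the monitoring database
# defaults to sqlite:///<run_dir>/monitoring.db.
config = Config(
    run_dir="runinfo",
    monitoring=MonitoringHub(
        hub_address="127.0.0.1",          # illustrative
        resource_monitoring_interval=60,  # seconds; illustrative
    ),
)
```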
diff --git a/parsl/monitoring/router.py b/parsl/monitoring/router.py
index 1d4b522e82..04e7480a7a 100644
--- a/parsl/monitoring/router.py
+++ b/parsl/monitoring/router.py
@@ -14,6 +14,7 @@
 import zmq
 
 from parsl.log_utils import set_file_logger
+from parsl.monitoring.radios import MultiprocessingQueueRadioSender
 from parsl.monitoring.types import TaggedMonitoringMessage
 from parsl.process_loggers import wrap_with_logs
 from parsl.utils import setproctitle
@@ -30,7 +31,7 @@ def __init__(self,
                  zmq_port_range: Tuple[int, int] = (55050, 56000),
 
                  monitoring_hub_address: str = "127.0.0.1",
-                 logdir: str = ".",
+                 run_dir: str = ".",
                  logging_level: int = logging.INFO,
                  atexit_timeout: int = 3,  # in seconds
                  resource_msgs: mpq.Queue,
@@ -47,7 +48,7 @@ def __init__(self,
         zmq_port_range : tuple(int, int)
             The MonitoringHub picks ports at random from the range which will be used by Hub.
             Default: (55050, 56000)
-        logdir : str
+        run_dir : str
             Parsl log directory paths. Logs and temp files go here. Default: '.'
         logging_level : int
             Logging level as defined in the logging module. Default: logging.INFO
@@ -55,12 +56,11 @@ def __init__(self,
             The amount of time in seconds to terminate the hub without receiving any messages, after the last dfk workflow message is received.
         resource_msgs : multiprocessing.Queue
             A multiprocessing queue to receive messages to be routed onwards to the database process
-
         exit_event : Event
             An event that the main Parsl process will set to signal that the monitoring router should shut down.
         """
-        os.makedirs(logdir, exist_ok=True)
-        self.logger = set_file_logger("{}/monitoring_router.log".format(logdir),
+        os.makedirs(run_dir, exist_ok=True)
+        self.logger = set_file_logger(f"{run_dir}/monitoring_router.log",
                                       name="monitoring_router",
                                       level=logging_level)
         self.logger.debug("Monitoring router starting")
@@ -98,7 +98,7 @@ def __init__(self,
                                               min_port=zmq_port_range[0],
                                               max_port=zmq_port_range[1])
 
-        self.resource_msgs = resource_msgs
+        self.target_radio = MultiprocessingQueueRadioSender(resource_msgs)
         self.exit_event = exit_event
 
     @wrap_with_logs(target="monitoring_router")
@@ -125,7 +125,7 @@ def start_udp_listener(self) -> None:
                 data, addr = self.udp_sock.recvfrom(2048)
                 resource_msg = pickle.loads(data)
                 self.logger.debug("Got UDP Message from {}: {}".format(addr, resource_msg))
-                self.resource_msgs.put(resource_msg)
+                self.target_radio.send(resource_msg)
             except socket.timeout:
                 pass
 
@@ -136,7 +136,7 @@ def start_udp_listener(self) -> None:
                 data, addr = self.udp_sock.recvfrom(2048)
                 msg = pickle.loads(data)
                 self.logger.debug("Got UDP Message from {}: {}".format(addr, msg))
-                self.resource_msgs.put(msg)
+                self.target_radio.send(msg)
                 last_msg_received_time = time.time()
             except socket.timeout:
                 pass
@@ -160,7 +160,7 @@ def start_zmq_listener(self) -> None:
                     assert len(msg) >= 1, "ZMQ Receiver expects tuples of length at least 1, got {}".format(msg)
                     assert len(msg) == 2, "ZMQ Receiver expects message tuples of exactly length 2, got {}".format(msg)
 
-                    self.resource_msgs.put(msg)
+                    self.target_radio.send(msg)
                 except zmq.Again:
                     pass
                 except Exception:
@@ -187,14 +187,14 @@ def router_starter(*,
                    udp_port: Optional[int],
                    zmq_port_range: Tuple[int, int],
-                   logdir: str,
+                   run_dir: str,
                    logging_level: int) -> None:
     setproctitle("parsl: monitoring router")
     try:
         router = MonitoringRouter(hub_address=hub_address,
                                   udp_port=udp_port,
                                   zmq_port_range=zmq_port_range,
-                                  logdir=logdir,
+                                  run_dir=run_dir,
                                   logging_level=logging_level,
                                   resource_msgs=resource_msgs,
                                   exit_event=exit_event)
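The router and the filesystem receiver now hand messages to a `MultiprocessingQueueRadioSender` rather than calling `put()` on the queue directly. A small sketch of the pattern; the message payload is illustrative:

```python
from multiprocessing import Queue

from parsl.monitoring.radios import MultiprocessingQueueRadioSender

# The sender wraps the resource-message queue; every ingest path in the router
# (UDP, ZMQ, filesystem radio) now goes through the same send() call that used
# to be a bare q.put().
resource_msgs = Queue()
target_radio = MultiprocessingQueueRadioSender(resource_msgs)
target_radio.send(("illustrative-tag", {"example": "payload"}))
```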
diff --git a/parsl/providers/slurm/slurm.py b/parsl/providers/slurm/slurm.py
index 865ca6a52d..9b6f38b9d9 100644
--- a/parsl/providers/slurm/slurm.py
+++ b/parsl/providers/slurm/slurm.py
@@ -70,6 +70,9 @@ class SlurmProvider(ClusterProvider, RepresentationMixin):
         Slurm queue to place job in. If unspecified or ``None``, no queue slurm directive will be specified.
     constraint : str
         Slurm job constraint, often used to choose cpu or gpu type. If unspecified or ``None``, no constraint slurm directive will be added.
+    clusters : str
+        Slurm cluster name, or comma separated cluster list, used to choose between different clusters in a federated Slurm instance.
+        If unspecified or ``None``, no slurm directive for clusters will be added.
     channel : Channel
         Channel for accessing this provider.
     nodes_per_block : int
@@ -115,6 +118,7 @@ def __init__(self,
                  account: Optional[str] = None,
                  qos: Optional[str] = None,
                  constraint: Optional[str] = None,
+                 clusters: Optional[str] = None,
                  channel: Channel = LocalChannel(),
                  nodes_per_block: int = 1,
                  cores_per_node: Optional[int] = None,
@@ -149,6 +153,7 @@ def __init__(self,
         self.account = account
         self.qos = qos
         self.constraint = constraint
+        self.clusters = clusters
         self.scheduler_options = scheduler_options + '\n'
         if exclusive:
             self.scheduler_options += "#SBATCH --exclusive\n"
@@ -160,6 +165,8 @@ def __init__(self,
             self.scheduler_options += "#SBATCH --qos={}\n".format(qos)
         if constraint:
             self.scheduler_options += "#SBATCH --constraint={}\n".format(constraint)
+        if clusters:
+            self.scheduler_options += "#SBATCH --clusters={}\n".format(clusters)
 
         self.regex_job_id = regex_job_id
         self.worker_init = worker_init + '\n'
@@ -171,14 +178,22 @@ def __init__(self,
         logger.debug(f"sacct returned retcode={retcode} stderr={stderr}")
         if retcode == 0:
             logger.debug("using sacct to get job status")
+            _cmd = "sacct"
+            # Add clusters option to sacct if provided
+            if self.clusters:
+                _cmd += f" --clusters={self.clusters}"
             # Using state%20 to get enough characters to not truncate output
             # of the state. Without output can look like " CANCELLED+"
-            self._cmd = "sacct -X --noheader --format=jobid,state%20 --job '{0}'"
+            self._cmd = _cmd + " -X --noheader --format=jobid,state%20 --job '{0}'"
             self._translate_table = sacct_translate_table
         else:
             logger.debug(f"sacct failed with retcode={retcode}")
             logger.debug("falling back to using squeue to get job status")
-            self._cmd = "squeue --noheader --format='%i %t' --job '{0}'"
+            _cmd = "squeue"
+            # Add clusters option to squeue if provided
+            if self.clusters:
+                _cmd += f" --clusters={self.clusters}"
+            self._cmd = _cmd + " --noheader --format='%i %t' --job '{0}'"
             self._translate_table = squeue_translate_table
 
     def _status(self):
@@ -334,7 +349,14 @@ def cancel(self, job_ids):
         '''
 
         job_id_list = ' '.join(job_ids)
-        retcode, stdout, stderr = self.execute_wait("scancel {0}".format(job_id_list))
+
+        # Make the command to cancel jobs
+        _cmd = "scancel"
+        if self.clusters:
+            _cmd += f" --clusters={self.clusters}"
+        _cmd += " {0}"
+
+        retcode, stdout, stderr = self.execute_wait(_cmd.format(job_id_list))
         rets = None
         if retcode == 0:
             for jid in job_ids:
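A sketch of the new `clusters` option in use; every other value in this configuration is illustrative:

```python
from parsl.config import Config
from parsl.executors import HighThroughputExecutor
from parsl.providers import SlurmProvider

# clusters is emitted as "#SBATCH --clusters=..." at submission time and is
# also appended to the sacct/squeue/scancel invocations, so jobs routed to a
# sibling cluster in a federation can still be polled and cancelled.
config = Config(
    executors=[
        HighThroughputExecutor(
            label="htex_slurm",                # illustrative
            provider=SlurmProvider(
                partition="compute",           # illustrative
                clusters="cluster1,cluster2",  # new option from this patch
                nodes_per_block=1,
            ),
        ),
    ],
)
```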
diff --git a/parsl/serialize/facade.py b/parsl/serialize/facade.py
index f8e76f174b..2e02e2b983 100644
--- a/parsl/serialize/facade.py
+++ b/parsl/serialize/facade.py
@@ -87,16 +87,16 @@ def pack_res_spec_apply_message(func: Any, args: Any, kwargs: Any, resource_spec
     return pack_apply_message(func, args, (kwargs, resource_specification), buffer_threshold=buffer_threshold)
 
 
-def unpack_apply_message(packed_buffer: bytes, user_ns: Any = None, copy: Any = False) -> List[Any]:
+def unpack_apply_message(packed_buffer: bytes) -> List[Any]:
     """ Unpack and deserialize function and parameters
     """
     return [deserialize(buf) for buf in unpack_buffers(packed_buffer)]
 
 
-def unpack_res_spec_apply_message(packed_buffer: bytes, user_ns: Any = None, copy: Any = False) -> List[Any]:
+def unpack_res_spec_apply_message(packed_buffer: bytes) -> List[Any]:
     """ Unpack and deserialize function, parameters, and resource_specification
     """
-    func, args, (kwargs, resource_spec) = unpack_apply_message(packed_buffer, user_ns=user_ns, copy=copy)
+    func, args, (kwargs, resource_spec) = unpack_apply_message(packed_buffer)
     return [func, args, kwargs, resource_spec]
diff --git a/parsl/tests/test_execute_task.py b/parsl/tests/test_execute_task.py
new file mode 100644
index 0000000000..42fb59c5c1
--- /dev/null
+++ b/parsl/tests/test_execute_task.py
@@ -0,0 +1,29 @@
+import os
+
+import pytest
+
+from parsl.executors.execute_task import execute_task
+from parsl.serialize.facade import pack_res_spec_apply_message
+
+
+def addemup(*args: int, name: str = "apples"):
+    total = sum(args)
+    return f"{total} {name}"
+
+
+@pytest.mark.local
+def test_execute_task():
+    args = (1, 2, 3)
+    kwargs = {"name": "boots"}
+    buff = pack_res_spec_apply_message(addemup, args, kwargs, {})
+    res = execute_task(buff)
+    assert res == addemup(*args, **kwargs)
+
+
+@pytest.mark.local
+def test_execute_task_resource_spec():
+    resource_spec = {"num_nodes": 2, "ranks_per_node": 2, "num_ranks": 4}
+    buff = pack_res_spec_apply_message(addemup, (1, 2), {}, resource_spec)
+    execute_task(buff)
+    for key, val in resource_spec.items():
+        assert os.environ[f"PARSL_{key.upper()}"] == str(val)
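Finally, a round-trip sketch of the simplified serialize facade: callers now pass only the packed buffer, with no `user_ns` or `copy` arguments. The function and values are illustrative:

```python
from parsl.serialize import unpack_res_spec_apply_message
from parsl.serialize.facade import pack_res_spec_apply_message


def scale(x, factor=2):
    # Illustrative function; stands in for a Parsl app body.
    return x * factor


buf = pack_res_spec_apply_message(scale, (21,), {"factor": 2}, {"num_nodes": 1})

# Deserialization no longer touches a caller-supplied namespace.
func, args, kwargs, resource_spec = unpack_res_spec_apply_message(buf)
assert func(*args, **kwargs) == 42
assert resource_spec == {"num_nodes": 1}
```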