diff --git a/.github/workflows/parsl+flux.yaml b/.github/workflows/parsl+flux.yaml new file mode 100644 index 0000000000..8b8c43d8b2 --- /dev/null +++ b/.github/workflows/parsl+flux.yaml @@ -0,0 +1,42 @@ +name: Test Flux Scheduler +on: + pull_request: [] + +jobs: + build: + runs-on: ubuntu-22.04 + permissions: + packages: read + strategy: + fail-fast: false + matrix: + container: ['fluxrm/flux-sched:jammy'] + timeout-minutes: 30 + + container: + image: ${{ matrix.container }} + options: "--platform=linux/amd64 --user root -it --init" + + name: ${{ matrix.container }} + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Install Dependencies and Parsl + run: | + apt-get update && apt-get install -y python3-pip curl + pip3 install . -r test-requirements.txt + + - name: Verify Parsl Installation + run: | + pytest parsl/tests/ -k "not cleannet and not unix_filesystem_permissions_required" --config parsl/tests/configs/local_threads.py --random-order --durations 10 + + - name: Test Parsl with Flux + run: | + pytest parsl/tests/test_flux.py --config local --random-order + + - name: Test Parsl with Flux Config + run: | + pytest parsl/tests/ -k "not cleannet and not unix_filesystem_permissions_required" --config parsl/tests/configs/flux_local.py --random-order --durations 10 + + diff --git a/Makefile b/Makefile index 0d368f4c59..90f20601e9 100644 --- a/Makefile +++ b/Makefile @@ -127,3 +127,8 @@ coverage: ## show the coverage report .PHONY: clean clean: ## clean up the environment by deleting the .venv, dist, eggs, mypy caches, coverage info, etc rm -rf .venv $(DEPS) dist *.egg-info .mypy_cache build .pytest_cache .coverage runinfo $(WORKQUEUE_INSTALL) + +.PHONY: flux_local_test +flux_local_test: ## Test Parsl with Flux Executor + pip3 install . + pytest parsl/tests/ -k "not cleannet" --config parsl/tests/configs/flux_local.py --random-order --durations 10 diff --git a/README.rst b/README.rst index fb1070e7d7..72048d39f4 100644 --- a/README.rst +++ b/README.rst @@ -109,7 +109,7 @@ For Developers 3. Install:: - $ cd parsl + $ cd parsl # only if you didn't enter the top-level directory in step 2 above $ python3 setup.py install 4. Use Parsl! diff --git a/docs/faq.rst b/docs/faq.rst index f427db82f9..a03287c378 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -13,6 +13,7 @@ Alternatively, you can configure the file logger to write to an output file. .. code-block:: python + import logging import parsl # Emit log lines to the screen diff --git a/docs/historical/changelog.rst b/docs/historical/changelog.rst index 18fe6ca5b1..931998f93d 100644 --- a/docs/historical/changelog.rst +++ b/docs/historical/changelog.rst @@ -334,7 +334,7 @@ New Functionality * New launcher: `parsl.launchers.WrappedLauncher` for launching tasks inside containers. -* `parsl.channels.SSHChannel` now supports a ``key_filename`` kwarg `issue#1639 `_ +* ``parsl.channels.SSHChannel`` now supports a ``key_filename`` kwarg `issue#1639 `_ * Newly added Makefile wraps several frequent developer operations such as: @@ -442,7 +442,7 @@ New Functionality module, parsl.data_provider.globus * `parsl.executors.WorkQueueExecutor`: a new executor that integrates functionality from `Work Queue `_ is now available. -* New provider to support for Ad-Hoc clusters `parsl.providers.AdHocProvider` +* New provider to support for Ad-Hoc clusters ``parsl.providers.AdHocProvider`` * New provider added to support LSF on Summit `parsl.providers.LSFProvider` * Support for CPU and Memory resource hints to providers `(github) `_. 
* The ``logging_level=logging.INFO`` in `parsl.monitoring.MonitoringHub` is replaced with ``monitoring_debug=False``: @@ -468,7 +468,7 @@ New Functionality * Several test-suite improvements that have dramatically reduced test duration. * Several improvements to the Monitoring interface. -* Configurable port on `parsl.channels.SSHChannel`. +* Configurable port on ``parsl.channels.SSHChannel``. * ``suppress_failure`` now defaults to True. * `parsl.executors.HighThroughputExecutor` is the recommended executor, and ``IPyParallelExecutor`` is deprecated. * `parsl.executors.HighThroughputExecutor` will expose worker information via environment variables: ``PARSL_WORKER_RANK`` and ``PARSL_WORKER_COUNT`` @@ -532,7 +532,7 @@ New Functionality * Cleaner user app file log management. * Updated configurations using `parsl.executors.HighThroughputExecutor` in the configuration section of the userguide. -* Support for OAuth based SSH with `parsl.channels.OAuthSSHChannel`. +* Support for OAuth based SSH with ``parsl.channels.OAuthSSHChannel``. Bug Fixes ^^^^^^^^^ diff --git a/docs/index.rst b/docs/index.rst index 980cf598f8..65696ec048 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -23,8 +23,8 @@ Parsl lets you chain functions together and will launch each function as inputs return x + 1 @python_app - def g(x): - return x * 2 + def g(x, y): + return x + y # These functions now return Futures, and can be chained future = f(1) diff --git a/docs/reference.rst b/docs/reference.rst index 1af850792c..d8e18bd244 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -38,15 +38,9 @@ Configuration Channels ======== -.. autosummary:: - :toctree: stubs - :nosignatures: - - parsl.channels.base.Channel - parsl.channels.LocalChannel - parsl.channels.SSHChannel - parsl.channels.OAuthSSHChannel - parsl.channels.SSHInteractiveLoginChannel +Channels are deprecated in Parsl. See +`issue 3515 `_ +for further discussion. Data management =============== @@ -109,7 +103,6 @@ Providers :toctree: stubs :nosignatures: - parsl.providers.AdHocProvider parsl.providers.AWSProvider parsl.providers.CobaltProvider parsl.providers.CondorProvider diff --git a/docs/userguide/checkpoints.rst b/docs/userguide/checkpoints.rst index dbcfcfc760..8867107b7a 100644 --- a/docs/userguide/checkpoints.rst +++ b/docs/userguide/checkpoints.rst @@ -49,15 +49,17 @@ during development. Using app caching will ensure that only modified apps are re App equivalence ^^^^^^^^^^^^^^^ -Parsl determines app equivalence by storing the hash -of the app function. Thus, any changes to the app code (e.g., -its signature, its body, or even the docstring within the body) -will invalidate cached values. +Parsl determines app equivalence using the name of the app function: +if two apps have the same name, then they are equivalent under this +relation. -However, Parsl does not traverse the call graph of the app function, -so changes inside functions called by an app will not invalidate +Changes inside the app, or by functions called by an app will not invalidate cached values. +There are lots of other ways functions might be compared for equivalence, +and `parsl.dataflow.memoization.id_for_memo` provides a hook to plug in +alternate application-specific implementations. + Invocation equivalence ^^^^^^^^^^^^^^^^^^^^^^ @@ -92,7 +94,7 @@ Attempting to cache apps invoked with other, non-hashable, data types will lead to an exception at invocation. 
In that case, mechanisms to hash new types can be registered by a program by -implementing the ``parsl.dataflow.memoization.id_for_memo`` function for +implementing the `parsl.dataflow.memoization.id_for_memo` function for the new type. Ignoring arguments diff --git a/docs/userguide/configuring.rst b/docs/userguide/configuring.rst index b4165411dd..f3fe5cc407 100644 --- a/docs/userguide/configuring.rst +++ b/docs/userguide/configuring.rst @@ -15,7 +15,7 @@ queues, durations, and data management options. The following example shows a basic configuration object (:class:`~parsl.config.Config`) for the Frontera supercomputer at TACC. This config uses the `parsl.executors.HighThroughputExecutor` to submit -tasks from a login node (`parsl.channels.LocalChannel`). It requests an allocation of +tasks from a login node. It requests an allocation of 128 nodes, deploying 1 worker for each of the 56 cores per node, from the normal partition. To limit network connections to just the internal network the config specifies the address used by the infiniband interface with ``address_by_interface('ib0')`` @@ -23,7 +23,6 @@ used by the infiniband interface with ``address_by_interface('ib0')`` .. code-block:: python from parsl.config import Config - from parsl.channels import LocalChannel from parsl.providers import SlurmProvider from parsl.executors import HighThroughputExecutor from parsl.launchers import SrunLauncher @@ -36,7 +35,6 @@ used by the infiniband interface with ``address_by_interface('ib0')`` address=address_by_interface('ib0'), max_workers_per_node=56, provider=SlurmProvider( - channel=LocalChannel(), nodes_per_block=128, init_blocks=1, partition='normal', @@ -197,22 +195,6 @@ Stepping through the following question should help formulate a suitable configu are on a **native Slurm** system like :ref:`configuring_nersc_cori` -4) Where will the main Parsl program run and how will it communicate with the apps? - -+------------------------+--------------------------+---------------------------------------------------+ -| Parsl program location | App execution target | Suitable channel | -+========================+==========================+===================================================+ -| Laptop/Workstation | Laptop/Workstation | `parsl.channels.LocalChannel` | -+------------------------+--------------------------+---------------------------------------------------+ -| Laptop/Workstation | Cloud Resources | No channel is needed | -+------------------------+--------------------------+---------------------------------------------------+ -| Laptop/Workstation | Clusters with no 2FA | `parsl.channels.SSHChannel` | -+------------------------+--------------------------+---------------------------------------------------+ -| Laptop/Workstation | Clusters with 2FA | `parsl.channels.SSHInteractiveLoginChannel` | -+------------------------+--------------------------+---------------------------------------------------+ -| Login node | Cluster/Supercomputer | `parsl.channels.LocalChannel` | -+------------------------+--------------------------+---------------------------------------------------+ - Heterogeneous Resources ----------------------- @@ -324,9 +306,13 @@ and Work Queue does not require Python to run. Accelerators ------------ -Many modern clusters provide multiple accelerators per compute note, yet many applications are best suited to using a single accelerator per task. 
-Parsl supports pinning each worker to difference accelerators using ``available_accelerators`` option of the :class:`~parsl.executors.HighThroughputExecutor`. -Provide either the number of executors (Parsl will assume they are named in integers starting from zero) or a list of the names of the accelerators available on the node. +Many modern clusters provide multiple accelerators per compute node, yet many applications are best suited to using a +single accelerator per task. Parsl supports pinning each worker to different accelerators using +the ``available_accelerators`` option of the :class:`~parsl.executors.HighThroughputExecutor`. Provide either the number of +accelerators (Parsl will assume they are named as integers starting from zero) or a list of the names of the accelerators +available on the node. Parsl will limit the number of workers it launches to the number of accelerators specified; +in other words, you cannot have more workers per node than there are accelerators. By default, Parsl will launch +as many workers as the accelerators specified via ``available_accelerators``. .. code-block:: python @@ -337,7 +323,6 @@ Provide either the number of executors (Parsl will assume they are named in inte worker_debug=True, available_accelerators=2, provider=LocalProvider( - channel=LocalChannel(), init_blocks=1, max_blocks=1, ), @@ -346,7 +331,39 @@ Provide either the number of executors (Parsl will assume they are named in inte strategy='none', ) +It is possible to bind multiple/specific accelerators to each worker by specifying a list of comma-separated strings, +each specifying a group of accelerators. In the context of binding to NVIDIA GPUs, this works by setting ``CUDA_VISIBLE_DEVICES`` +on each worker to a specific string in the list supplied to ``available_accelerators``. + +Here's an example: + +.. code-block:: python + + # The following config is trimmed for clarity + local_config = Config( + executors=[ + HighThroughputExecutor( + # Starts 2 workers per node, each bound to 2 GPUs + available_accelerators=["0,1", "2,3"], + + # Start a single worker bound to all 4 GPUs + # available_accelerators=["0,1,2,3"] + ) + ], + ) +GPU Oversubscription +"""""""""""""""""""" + +For hardware that uses NVIDIA devices, Parsl allows for the oversubscription of workers to GPUs. This is intended to +make use of NVIDIA's `Multi-Process Service (MPS) `_ available on many of their +GPUs, which allows users to run multiple concurrent processes on a single GPU. The user needs to set the +``worker_init`` commands to start MPS on every node in the block (this is machine dependent). The +``available_accelerators`` option should then be set to the total number of GPU partitions run on a single node in the +block. For example, for a node with 4 NVIDIA GPUs, to create 8 workers per GPU, set ``available_accelerators=32``. +GPUs will be assigned to workers in ascending order in contiguous blocks. In the example, workers 0-7 will be placed +on GPU 0, workers 8-15 on GPU 1, workers 16-23 on GPU 2, and workers 24-31 on GPU 3.
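A minimal sketch of the oversubscription setup described above is shown below; the provider, launcher, and MPS start-up command are illustrative assumptions (and machine dependent), not part of this patch:

.. code-block:: python

    from parsl.config import Config
    from parsl.executors import HighThroughputExecutor
    from parsl.launchers import SrunLauncher
    from parsl.providers import SlurmProvider

    # Sketch only: 4 physical GPUs x 8 workers per GPU = 32 GPU partitions per node.
    oversubscribed_config = Config(
        executors=[
            HighThroughputExecutor(
                available_accelerators=32,
                provider=SlurmProvider(
                    nodes_per_block=1,
                    launcher=SrunLauncher(),
                    # Starting the MPS control daemon is machine dependent;
                    # this is one common form on NVIDIA systems.
                    worker_init="nvidia-cuda-mps-control -d",
                ),
            )
        ],
    )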
+ Multi-Threaded Applications --------------------------- @@ -371,7 +388,6 @@ Select the best blocking strategy for processor's cache hierarchy (choose ``alte worker_debug=True, cpu_affinity='alternating', provider=LocalProvider( - channel=LocalChannel(), init_blocks=1, max_blocks=1, ), @@ -411,18 +427,12 @@ These include ``OMP_NUM_THREADS``, ``GOMP_COMP_AFFINITY``, and ``KMP_THREAD_AFFI Ad-Hoc Clusters --------------- -Any collection of compute nodes without a scheduler can be considered an -ad-hoc cluster. Often these machines have a shared file system such as NFS or Lustre. -In order to use these resources with Parsl, they need to set-up for password-less SSH access. - -To use these ssh-accessible collection of nodes as an ad-hoc cluster, we use -the `parsl.providers.AdHocProvider` with an `parsl.channels.SSHChannel` to each node. An example -configuration follows. +Parsl's support of ad-hoc clusters of compute nodes without a scheduler +is deprecated. -.. literalinclude:: ../../parsl/configs/ad_hoc.py - -.. note:: - Multiple blocks should not be assigned to each node when using the `parsl.executors.HighThroughputExecutor` +See +`issue #3515 `_ +for further discussion. Amazon Web Services ------------------- diff --git a/docs/userguide/examples/config.py b/docs/userguide/examples/config.py index 166faaf4ac..68057d2b01 100644 --- a/docs/userguide/examples/config.py +++ b/docs/userguide/examples/config.py @@ -1,4 +1,3 @@ -from parsl.channels import LocalChannel from parsl.config import Config from parsl.executors import HighThroughputExecutor from parsl.providers import LocalProvider @@ -8,9 +7,7 @@ HighThroughputExecutor( label="htex_local", cores_per_worker=1, - provider=LocalProvider( - channel=LocalChannel(), - ), + provider=LocalProvider(), ) ], ) diff --git a/docs/userguide/execution.rst b/docs/userguide/execution.rst index 4168367f9d..df17dc458f 100644 --- a/docs/userguide/execution.rst +++ b/docs/userguide/execution.rst @@ -47,8 +47,7 @@ Parsl currently supports the following providers: 7. `parsl.providers.AWSProvider`: This provider allows you to provision and manage cloud nodes from Amazon Web Services. 8. `parsl.providers.GoogleCloudProvider`: This provider allows you to provision and manage cloud nodes from Google Cloud. 9. `parsl.providers.KubernetesProvider`: This provider allows you to provision and manage containers on a Kubernetes cluster. -10. `parsl.providers.AdHocProvider`: This provider allows you manage execution over a collection of nodes to form an ad-hoc cluster. -11. `parsl.providers.LSFProvider`: This provider allows you to schedule resources via IBM's LSF scheduler. +10. `parsl.providers.LSFProvider`: This provider allows you to schedule resources via IBM's LSF scheduler. diff --git a/docs/userguide/mpi_apps.rst b/docs/userguide/mpi_apps.rst index a40c03e004..82123123b6 100644 --- a/docs/userguide/mpi_apps.rst +++ b/docs/userguide/mpi_apps.rst @@ -60,6 +60,13 @@ An example for ALCF's Polaris supercomputer that will run 3 MPI tasks of 2 nodes ) +.. warning:: + Please note that ``Provider`` options that specify per-task or per-node resources, for example, + ``SlurmProvider(cores_per_node=N, ...)`` should not be used with :class:`~parsl.executors.high_throughput.MPIExecutor`. + Parsl primarily uses a pilot job model and assumptions from that context do not translate to the MPI context. 
For + more info refer to : + `github issue #3006 `_ + Writing an MPI App ------------------ diff --git a/docs/userguide/plugins.rst b/docs/userguide/plugins.rst index 4ecff86cfe..c3c38dea63 100644 --- a/docs/userguide/plugins.rst +++ b/docs/userguide/plugins.rst @@ -16,8 +16,8 @@ executor to run code on the local submitting host, while another executor can run the same code on a large supercomputer. -Providers, Launchers and Channels ---------------------------------- +Providers and Launchers +----------------------- Some executors are based on blocks of workers (for example the `parsl.executors.HighThroughputExecutor`: the submit side requires a batch system (eg slurm, kubernetes) to start worker processes, which then @@ -34,10 +34,9 @@ add on any wrappers that are needed to launch the command (eg srun inside slurm). Providers and launchers are usually paired together for a particular system type. -A `Channel` allows the commands used to interact with an `ExecutionProvider` to be -executed on a remote system. The default channel executes commands on the -local system, but a few variants of an `parsl.channels.SSHChannel` are provided. - +Parsl also has a deprecated ``Channel`` abstraction. See +`issue 3515 `_ +for further discussion. File staging ------------ diff --git a/parsl/app/app.py b/parsl/app/app.py index 6097415c9e..8d0d829b33 100644 --- a/parsl/app/app.py +++ b/parsl/app/app.py @@ -66,8 +66,10 @@ def __init__(self, func: Callable, self.kwargs['walltime'] = params['walltime'].default if 'parsl_resource_specification' in params: self.kwargs['parsl_resource_specification'] = params['parsl_resource_specification'].default - self.outputs = params['outputs'].default if 'outputs' in params else [] - self.inputs = params['inputs'].default if 'inputs' in params else [] + if 'outputs' in params: + self.kwargs['outputs'] = params['outputs'].default + if 'inputs' in params: + self.kwargs['inputs'] = params['inputs'].default @abstractmethod def __call__(self, *args: Any, **kwargs: Any) -> AppFuture: diff --git a/parsl/app/bash.py b/parsl/app/bash.py index 4ab0add68b..36212c172f 100644 --- a/parsl/app/bash.py +++ b/parsl/app/bash.py @@ -1,5 +1,5 @@ import logging -from functools import partial, update_wrapper +from functools import partial from inspect import Parameter, signature from parsl.app.app import AppBase @@ -123,11 +123,10 @@ def __init__(self, func, data_flow_kernel=None, cache=False, executors='all', ig if sig.parameters[s].default is not Parameter.empty: self.kwargs[s] = sig.parameters[s].default - # update_wrapper allows remote_side_bash_executor to masquerade as self.func # partial is used to attach the first arg the "func" to the remote_side_bash_executor # this is done to avoid passing a function type in the args which parsl.serializer # doesn't support - remote_fn = partial(update_wrapper(remote_side_bash_executor, self.func), self.func) + remote_fn = partial(remote_side_bash_executor, self.func) remote_fn.__name__ = self.func.__name__ self.wrapped_remote_function = wrap_error(remote_fn) diff --git a/parsl/channels/__init__.py b/parsl/channels/__init__.py index 5a45d15278..c81f6a8bf1 100644 --- a/parsl/channels/__init__.py +++ b/parsl/channels/__init__.py @@ -1,7 +1,4 @@ from parsl.channels.base import Channel from parsl.channels.local.local import LocalChannel -from parsl.channels.oauth_ssh.oauth_ssh import OAuthSSHChannel -from parsl.channels.ssh.ssh import SSHChannel -from parsl.channels.ssh_il.ssh_il import SSHInteractiveLoginChannel -__all__ = ['Channel', 'SSHChannel', 
'LocalChannel', 'SSHInteractiveLoginChannel', 'OAuthSSHChannel'] +__all__ = ['Channel', 'LocalChannel'] diff --git a/parsl/channels/local/local.py b/parsl/channels/local/local.py index 537f64a0c3..b94629095e 100644 --- a/parsl/channels/local/local.py +++ b/parsl/channels/local/local.py @@ -55,6 +55,7 @@ def execute_wait(self, cmd, walltime=None, envs={}): current_env.update(envs) try: + logger.debug("Creating process with command '%s'", cmd) proc = subprocess.Popen( cmd, stdout=subprocess.PIPE, @@ -64,12 +65,16 @@ def execute_wait(self, cmd, walltime=None, envs={}): shell=True, preexec_fn=os.setpgrp ) + logger.debug("Created process with pid %s. Performing communicate", proc.pid) (stdout, stderr) = proc.communicate(timeout=walltime) retcode = proc.returncode + logger.debug("Process %s returned %s", proc.pid, proc.returncode) - except Exception as e: - logger.warning("Execution of command '{}' failed due to \n{}".format(cmd, e)) + except Exception: + logger.exception(f"Execution of command failed:\n{cmd}") raise + else: + logger.debug("Execution of command in process %s completed normally", proc.pid) return (retcode, stdout.decode("utf-8"), stderr.decode("utf-8")) diff --git a/parsl/channels/oauth_ssh/oauth_ssh.py b/parsl/channels/oauth_ssh/oauth_ssh.py index c9efa27767..1b690a4e3c 100644 --- a/parsl/channels/oauth_ssh/oauth_ssh.py +++ b/parsl/channels/oauth_ssh/oauth_ssh.py @@ -1,11 +1,15 @@ import logging import socket -import paramiko - -from parsl.channels.ssh.ssh import SSHChannel +from parsl.channels.ssh.ssh import DeprecatedSSHChannel from parsl.errors import OptionalModuleMissing +try: + import paramiko + _ssh_enabled = True +except (ImportError, NameError, FileNotFoundError): + _ssh_enabled = False + try: from oauth_ssh.oauth_ssh_token import find_access_token from oauth_ssh.ssh_service import SSHService @@ -17,7 +21,7 @@ logger = logging.getLogger(__name__) -class OAuthSSHChannel(SSHChannel): +class DeprecatedOAuthSSHChannel(DeprecatedSSHChannel): """SSH persistent channel. This enables remote execution on sites accessible via ssh. This channel uses Globus based OAuth tokens for authentication. """ @@ -38,6 +42,10 @@ def __init__(self, hostname, username=None, script_dir=None, envs=None, port=22) Raises: ''' + if not _ssh_enabled: + raise OptionalModuleMissing(['ssh'], + "OauthSSHChannel requires the ssh module and config.") + if not _oauth_ssh_enabled: raise OptionalModuleMissing(['oauth_ssh'], "OauthSSHChannel requires oauth_ssh module and config.") diff --git a/parsl/channels/ssh/ssh.py b/parsl/channels/ssh/ssh.py index 6b38ed68e6..c53a26b831 100644 --- a/parsl/channels/ssh/ssh.py +++ b/parsl/channels/ssh/ssh.py @@ -2,8 +2,6 @@ import logging import os -import paramiko - from parsl.channels.base import Channel from parsl.channels.errors import ( AuthException, @@ -13,18 +11,27 @@ FileCopyException, SSHException, ) +from parsl.errors import OptionalModuleMissing from parsl.utils import RepresentationMixin +try: + import paramiko + _ssh_enabled = True +except (ImportError, NameError, FileNotFoundError): + _ssh_enabled = False + + logger = logging.getLogger(__name__) -class NoAuthSSHClient(paramiko.SSHClient): - def _auth(self, username, *args): - self._transport.auth_none(username) - return +if _ssh_enabled: + class NoAuthSSHClient(paramiko.SSHClient): + def _auth(self, username, *args): + self._transport.auth_none(username) + return -class SSHChannel(Channel, RepresentationMixin): +class DeprecatedSSHChannel(Channel, RepresentationMixin): ''' SSH persistent channel. 
This enables remote execution on sites accessible via ssh. It is assumed that the user has setup host keys so as to ssh to the remote host. Which goes to say that the following @@ -53,6 +60,9 @@ def __init__(self, hostname, username=None, password=None, script_dir=None, envs Raises: ''' + if not _ssh_enabled: + raise OptionalModuleMissing(['ssh'], + "SSHChannel requires the ssh module and config.") self.hostname = hostname self.username = username @@ -227,8 +237,20 @@ def pull_file(self, remote_source, local_dir): def close(self) -> None: if self._is_connected(): + transport = self.ssh_client.get_transport() self.ssh_client.close() + # ssh_client.close calls transport.close, but transport.close does + # not always wait for the transport thread to be stopped. See impl + # of Transport.close in paramiko and issue + # https://github.com/paramiko/paramiko/issues/520 + logger.debug("Waiting for transport thread to stop") + transport.join(30) + if transport.is_alive(): + logger.warning("SSH transport thread did not shut down") + else: + logger.debug("SSH transport thread stopped") + def isdir(self, path): """Return true if the path refers to an existing directory. diff --git a/parsl/channels/ssh_il/ssh_il.py b/parsl/channels/ssh_il/ssh_il.py index 02e7a58cd4..67e5501a43 100644 --- a/parsl/channels/ssh_il/ssh_il.py +++ b/parsl/channels/ssh_il/ssh_il.py @@ -1,14 +1,20 @@ import getpass import logging -import paramiko +from parsl.channels.ssh.ssh import DeprecatedSSHChannel +from parsl.errors import OptionalModuleMissing + +try: + import paramiko + _ssh_enabled = True +except (ImportError, NameError, FileNotFoundError): + _ssh_enabled = False -from parsl.channels.ssh.ssh import SSHChannel logger = logging.getLogger(__name__) -class SSHInteractiveLoginChannel(SSHChannel): +class DeprecatedSSHInteractiveLoginChannel(DeprecatedSSHChannel): """SSH persistent channel. This enables remote execution on sites accessible via ssh. This channel supports interactive login and is appropriate when keys are not set up. @@ -30,6 +36,10 @@ def __init__(self, hostname, username=None, password=None, script_dir=None, envs Raises: ''' + if not _ssh_enabled: + raise OptionalModuleMissing(['ssh'], + "SSHInteractiveLoginChannel requires the ssh module and config.") + self.hostname = hostname self.username = username self.password = password diff --git a/parsl/config.py b/parsl/config.py index ecea149114..c3725eccf8 100644 --- a/parsl/config.py +++ b/parsl/config.py @@ -40,6 +40,15 @@ class Config(RepresentationMixin, UsageInformation): ``checkpoint_mode='periodic'``. dependency_resolver: plugin point for custom dependency resolvers. Default: only resolve Futures, using the `SHALLOW_DEPENDENCY_RESOLVER`. + exit_mode: str, optional + When Parsl is used as a context manager (using ``with parsl.load`` syntax) then this parameter + controls what will happen to running tasks and exceptions at exit. The options are: + + * ``cleanup``: cleanup the DFK on exit without waiting for any tasks + * ``skip``: skip all shutdown behaviour when exiting the context manager + * ``wait``: wait for all tasks to complete when exiting normally, but exit immediately when exiting due to an exception. + + Default is ``cleanup``. garbage_collect : bool. optional. Delete task records from DFK when tasks have completed. 
Default: True internal_tasks_max_threads : int, optional @@ -97,6 +106,7 @@ def __init__(self, Literal['manual']] = None, checkpoint_period: Optional[str] = None, dependency_resolver: Optional[DependencyResolver] = None, + exit_mode: Literal['cleanup', 'skip', 'wait'] = 'cleanup', garbage_collect: bool = True, internal_tasks_max_threads: int = 10, retries: int = 0, @@ -133,6 +143,7 @@ def __init__(self, checkpoint_period = "00:30:00" self.checkpoint_period = checkpoint_period self.dependency_resolver = dependency_resolver + self.exit_mode = exit_mode self.garbage_collect = garbage_collect self.internal_tasks_max_threads = internal_tasks_max_threads self.retries = retries diff --git a/parsl/configs/ASPIRE1.py b/parsl/configs/ASPIRE1.py index 1b502fadaf..7792f15dba 100644 --- a/parsl/configs/ASPIRE1.py +++ b/parsl/configs/ASPIRE1.py @@ -4,6 +4,7 @@ from parsl.launchers import MpiRunLauncher from parsl.monitoring.monitoring import MonitoringHub from parsl.providers import PBSProProvider +from parsl.usage_tracking.levels import LEVEL_1 config = Config( executors=[ @@ -39,5 +40,6 @@ strategy='simple', retries=3, app_cache=True, - checkpoint_mode='task_exit' + checkpoint_mode='task_exit', + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/Azure.py b/parsl/configs/Azure.py index 9d05be7940..2a27db3f1b 100644 --- a/parsl/configs/Azure.py +++ b/parsl/configs/Azure.py @@ -8,6 +8,7 @@ from parsl.data_provider.rsync import RSyncStaging from parsl.executors import HighThroughputExecutor from parsl.providers import AzureProvider +from parsl.usage_tracking.levels import LEVEL_1 vm_reference = { # All fields below are required @@ -33,5 +34,6 @@ FTPInTaskStaging(), RSyncStaging(getpass.getuser() + "@" + address_by_query())], ) - ] + ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/ad_hoc.py b/parsl/configs/ad_hoc.py deleted file mode 100644 index daee13ea00..0000000000 --- a/parsl/configs/ad_hoc.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Any, Dict - -from parsl.channels import SSHChannel -from parsl.config import Config -from parsl.executors import HighThroughputExecutor -from parsl.providers import AdHocProvider - -user_opts: Dict[str, Dict[str, Any]] -user_opts = {'adhoc': - {'username': 'YOUR_USERNAME', - 'script_dir': 'YOUR_SCRIPT_DIR', - 'remote_hostnames': ['REMOTE_HOST_URL_1', 'REMOTE_HOST_URL_2'] - } - } - - -config = Config( - executors=[ - HighThroughputExecutor( - label='remote_htex', - max_workers_per_node=2, - worker_logdir_root=user_opts['adhoc']['script_dir'], - provider=AdHocProvider( - # Command to be run before starting a worker, such as: - # 'module load Anaconda; source activate parsl_env'. - worker_init='', - channels=[SSHChannel(hostname=m, - username=user_opts['adhoc']['username'], - script_dir=user_opts['adhoc']['script_dir'], - ) for m in user_opts['adhoc']['remote_hostnames']] - ) - ) - ], - # AdHoc Clusters should not be setup with scaling strategy. - strategy='none', -) diff --git a/parsl/configs/bridges.py b/parsl/configs/bridges.py index 928cd70549..4cb0fba543 100644 --- a/parsl/configs/bridges.py +++ b/parsl/configs/bridges.py @@ -3,6 +3,7 @@ from parsl.executors import HighThroughputExecutor from parsl.launchers import SrunLauncher from parsl.providers import SlurmProvider +from parsl.usage_tracking.levels import LEVEL_1 """ This config assumes that it is used to launch parsl tasks from the login nodes of Bridges at PSC. Each job submitted to the scheduler will request 2 nodes for 10 minutes. 
@@ -34,5 +35,6 @@ cmd_timeout=120, ), ) - ] + ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/cc_in2p3.py b/parsl/configs/cc_in2p3.py index 4016977aed..631d76f9f5 100644 --- a/parsl/configs/cc_in2p3.py +++ b/parsl/configs/cc_in2p3.py @@ -2,6 +2,7 @@ from parsl.config import Config from parsl.executors import HighThroughputExecutor from parsl.providers import GridEngineProvider +from parsl.usage_tracking.levels import LEVEL_1 config = Config( executors=[ @@ -19,4 +20,5 @@ ), ) ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/ec2.py b/parsl/configs/ec2.py index efe2afcfe8..8e85252acc 100644 --- a/parsl/configs/ec2.py +++ b/parsl/configs/ec2.py @@ -1,6 +1,7 @@ from parsl.config import Config from parsl.executors import HighThroughputExecutor from parsl.providers import AWSProvider +from parsl.usage_tracking.levels import LEVEL_1 config = Config( executors=[ @@ -25,4 +26,5 @@ ), ) ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/expanse.py b/parsl/configs/expanse.py index e8f5db9cb7..35ef5e0fa2 100644 --- a/parsl/configs/expanse.py +++ b/parsl/configs/expanse.py @@ -2,6 +2,7 @@ from parsl.executors import HighThroughputExecutor from parsl.launchers import SrunLauncher from parsl.providers import SlurmProvider +from parsl.usage_tracking.levels import LEVEL_1 config = Config( executors=[ @@ -24,5 +25,6 @@ nodes_per_block=2, ), ) - ] + ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/frontera.py b/parsl/configs/frontera.py index 1aa4639bea..a7b6f27b6c 100644 --- a/parsl/configs/frontera.py +++ b/parsl/configs/frontera.py @@ -3,6 +3,7 @@ from parsl.executors import HighThroughputExecutor from parsl.launchers import SrunLauncher from parsl.providers import SlurmProvider +from parsl.usage_tracking.levels import LEVEL_1 """ This config assumes that it is used to launch parsl tasks from the login nodes of Frontera at TACC. Each job submitted to the scheduler will request 2 nodes for 10 minutes. @@ -32,4 +33,5 @@ ), ) ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/htex_local.py b/parsl/configs/htex_local.py index da34f59f81..721dea767e 100644 --- a/parsl/configs/htex_local.py +++ b/parsl/configs/htex_local.py @@ -2,6 +2,7 @@ from parsl.config import Config from parsl.executors import HighThroughputExecutor from parsl.providers import LocalProvider +from parsl.usage_tracking.levels import LEVEL_1 config = Config( executors=[ @@ -15,4 +16,5 @@ ), ) ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/illinoiscluster.py b/parsl/configs/illinoiscluster.py index 3f3585d3b6..216c910b56 100644 --- a/parsl/configs/illinoiscluster.py +++ b/parsl/configs/illinoiscluster.py @@ -2,6 +2,7 @@ from parsl.executors import HighThroughputExecutor from parsl.launchers import SrunLauncher from parsl.providers import SlurmProvider +from parsl.usage_tracking.levels import LEVEL_1 """ This config assumes that it is used to launch parsl tasks from the login nodes of the Campus Cluster at UIUC. Each job submitted to the scheduler will request 2 nodes for 10 minutes. 
@@ -25,4 +26,5 @@ ), ) ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/kubernetes.py b/parsl/configs/kubernetes.py index 829f3b81c3..5a4601862b 100644 --- a/parsl/configs/kubernetes.py +++ b/parsl/configs/kubernetes.py @@ -2,6 +2,7 @@ from parsl.config import Config from parsl.executors import HighThroughputExecutor from parsl.providers import KubernetesProvider +from parsl.usage_tracking.levels import LEVEL_1 config = Config( executors=[ @@ -36,5 +37,6 @@ max_blocks=10, ), ), - ] + ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/local_threads.py b/parsl/configs/local_threads.py index f02e1f1e15..6b6561ea62 100644 --- a/parsl/configs/local_threads.py +++ b/parsl/configs/local_threads.py @@ -1,4 +1,8 @@ from parsl.config import Config from parsl.executors.threads import ThreadPoolExecutor +from parsl.usage_tracking.levels import LEVEL_1 -config = Config(executors=[ThreadPoolExecutor()]) +config = Config( + executors=[ThreadPoolExecutor()], + usage_tracking=LEVEL_1, +) diff --git a/parsl/configs/midway.py b/parsl/configs/midway.py index 251eb419b1..960c406cfe 100644 --- a/parsl/configs/midway.py +++ b/parsl/configs/midway.py @@ -3,6 +3,7 @@ from parsl.executors import HighThroughputExecutor from parsl.launchers import SrunLauncher from parsl.providers import SlurmProvider +from parsl.usage_tracking.levels import LEVEL_1 config = Config( executors=[ @@ -28,4 +29,5 @@ ), ) ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/osg.py b/parsl/configs/osg.py index 016d40630d..bd0c04ad56 100644 --- a/parsl/configs/osg.py +++ b/parsl/configs/osg.py @@ -1,6 +1,7 @@ from parsl.config import Config from parsl.executors import HighThroughputExecutor from parsl.providers import CondorProvider +from parsl.usage_tracking.levels import LEVEL_1 config = Config( executors=[ @@ -26,5 +27,6 @@ worker_logdir_root='$OSG_WN_TMP', worker_ports=(31000, 31001) ) - ] + ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/polaris.py b/parsl/configs/polaris.py index 3c6b96959d..3d59991d96 100644 --- a/parsl/configs/polaris.py +++ b/parsl/configs/polaris.py @@ -3,6 +3,7 @@ from parsl.executors import HighThroughputExecutor from parsl.launchers import MpiExecLauncher from parsl.providers import PBSProProvider +from parsl.usage_tracking.levels import LEVEL_1 # There are three user parameters to change for the PBSProProvider: # YOUR_ACCOUNT: Account to charge usage @@ -34,5 +35,6 @@ cpus_per_node=64, ), ), - ] + ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/stampede2.py b/parsl/configs/stampede2.py index 0ffb0e3314..b8e2aca9b9 100644 --- a/parsl/configs/stampede2.py +++ b/parsl/configs/stampede2.py @@ -4,6 +4,7 @@ from parsl.executors import HighThroughputExecutor from parsl.launchers import SrunLauncher from parsl.providers import SlurmProvider +from parsl.usage_tracking.levels import LEVEL_1 config = Config( executors=[ @@ -34,4 +35,5 @@ ) ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/summit.py b/parsl/configs/summit.py index 2695f2da7f..11e68ca2c1 100644 --- a/parsl/configs/summit.py +++ b/parsl/configs/summit.py @@ -3,6 +3,7 @@ from parsl.executors import HighThroughputExecutor from parsl.launchers import JsrunLauncher from parsl.providers import LSFProvider +from parsl.usage_tracking.levels import LEVEL_1 config = Config( executors=[ @@ -26,4 +27,5 @@ ) ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/toss3_llnl.py b/parsl/configs/toss3_llnl.py index a7820b3ca4..5c6b1c71c5 100644 --- a/parsl/configs/toss3_llnl.py +++ b/parsl/configs/toss3_llnl.py @@ -2,6 
+2,7 @@ from parsl.executors import FluxExecutor from parsl.launchers import SrunLauncher from parsl.providers import SlurmProvider +from parsl.usage_tracking.levels import LEVEL_1 config = Config( executors=[ @@ -24,5 +25,6 @@ cmd_timeout=120, ), ) - ] + ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/vineex_local.py b/parsl/configs/vineex_local.py index c88d92213c..755f1d1cc4 100644 --- a/parsl/configs/vineex_local.py +++ b/parsl/configs/vineex_local.py @@ -2,6 +2,7 @@ from parsl.config import Config from parsl.executors.taskvine import TaskVineExecutor, TaskVineManagerConfig +from parsl.usage_tracking.levels import LEVEL_1 config = Config( executors=[ @@ -15,5 +16,6 @@ # To disable status reporting, comment out the project_name. manager_config=TaskVineManagerConfig(project_name="parsl-vine-" + str(uuid.uuid4())), ) - ] + ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/configs/wqex_local.py b/parsl/configs/wqex_local.py index 8a4d570883..fa583f381a 100644 --- a/parsl/configs/wqex_local.py +++ b/parsl/configs/wqex_local.py @@ -2,6 +2,7 @@ from parsl.config import Config from parsl.executors import WorkQueueExecutor +from parsl.usage_tracking.levels import LEVEL_1 config = Config( executors=[ @@ -21,5 +22,6 @@ # A shared filesystem is not needed when using Work Queue. shared_fs=False ) - ] + ], + usage_tracking=LEVEL_1, ) diff --git a/parsl/dataflow/dflow.py b/parsl/dataflow/dflow.py index dffa7e52fd..344173c4b1 100644 --- a/parsl/dataflow/dflow.py +++ b/parsl/dataflow/dflow.py @@ -1,6 +1,7 @@ from __future__ import annotations import atexit +import concurrent.futures as cf import datetime import inspect import logging @@ -112,14 +113,10 @@ def __init__(self, config: Config) -> None: self.monitoring: Optional[MonitoringHub] self.monitoring = config.monitoring - # hub address and port for interchange to connect - self.hub_address = None # type: Optional[str] - self.hub_zmq_port = None # type: Optional[int] if self.monitoring: if self.monitoring.logdir is None: self.monitoring.logdir = self.run_dir - self.hub_address = self.monitoring.hub_address - self.hub_zmq_port = self.monitoring.start(self.run_id, self.run_dir, self.config.run_dir) + self.monitoring.start(self.run_dir, self.config.run_dir) self.time_began = datetime.datetime.now() self.time_completed: Optional[datetime.datetime] = None @@ -209,6 +206,8 @@ def __init__(self, config: Config) -> None: self.tasks: Dict[int, TaskRecord] = {} self.submitter_lock = threading.Lock() + self.dependency_launch_pool = cf.ThreadPoolExecutor(max_workers=1, thread_name_prefix="Dependency-Launch") + self.dependency_resolver = self.config.dependency_resolver if self.config.dependency_resolver is not None \ else SHALLOW_DEPENDENCY_RESOLVER @@ -217,9 +216,24 @@ def __init__(self, config: Config) -> None: def __enter__(self): return self - def __exit__(self, exc_type, exc_value, traceback): - logger.debug("Exiting the context manager, calling cleanup for DFK") - self.cleanup() + def __exit__(self, exc_type, exc_value, traceback) -> None: + mode = self.config.exit_mode + logger.debug("Exiting context manager, with exit mode '%s'", mode) + if mode == "cleanup": + logger.info("Calling cleanup for DFK") + self.cleanup() + elif mode == "skip": + logger.info("Skipping all cleanup handling") + elif mode == "wait": + if exc_type is None: + logger.info("Waiting for all tasks to complete") + self.wait_for_current_tasks() + self.cleanup() + else: + logger.info("There was an exception - cleaning up without waiting for task completion") + self.cleanup() + 
else: + raise InternalConsistencyError(f"Exit case for {mode} should be unreachable, validated by typeguard on Config()") def _send_task_log_info(self, task_record: TaskRecord) -> None: if self.monitoring: @@ -611,9 +625,9 @@ def check_staging_inhibited(kwargs: Dict[str, Any]) -> bool: return kwargs.get('_parsl_staging_inhibit', False) def launch_if_ready(self, task_record: TaskRecord) -> None: - """ - launch_if_ready will launch the specified task, if it is ready - to run (for example, without dependencies, and in pending state). + """Schedules a task record for re-inspection to see if it is ready + for launch and for launch if it is ready. The call will return + immediately. This should be called by any piece of the DataFlowKernel that thinks a task may have become ready to run. @@ -622,13 +636,17 @@ def launch_if_ready(self, task_record: TaskRecord) -> None: ready to run - launch_if_ready will not incorrectly launch that task. - It is also not an error to call launch_if_ready on a task that has - already been launched - launch_if_ready will not re-launch that - task. - launch_if_ready is thread safe, so may be called from any thread or callback. """ + self.dependency_launch_pool.submit(self._launch_if_ready_async, task_record) + + @wrap_with_logs + def _launch_if_ready_async(self, task_record: TaskRecord) -> None: + """ + _launch_if_ready will launch the specified task, if it is ready + to run (for example, without dependencies, and in pending state). + """ exec_fu = None task_id = task_record['id'] @@ -1159,10 +1177,10 @@ def add_executors(self, executors: Sequence[ParslExecutor]) -> None: for executor in executors: executor.run_id = self.run_id executor.run_dir = self.run_dir - executor.hub_address = self.hub_address - executor.hub_zmq_port = self.hub_zmq_port if self.monitoring: - executor.monitoring_radio = self.monitoring.radio + executor.hub_address = self.monitoring.hub_address + executor.hub_zmq_port = self.monitoring.hub_zmq_port + executor.submit_monitoring_radio = self.monitoring.radio if hasattr(executor, 'provider'): if hasattr(executor.provider, 'script_dir'): executor.provider.script_dir = os.path.join(self.run_dir, 'submit_scripts') @@ -1255,6 +1273,23 @@ def cleanup(self) -> None: executor.shutdown() logger.info(f"Shut down executor {executor.label}") + if hasattr(executor, 'provider'): + if hasattr(executor.provider, 'script_dir'): + logger.info(f"Closing channel(s) for {executor.label}") + + if hasattr(executor.provider, 'channels'): + for channel in executor.provider.channels: + logger.info(f"Closing channel {channel}") + channel.close() + logger.info(f"Closed channel {channel}") + else: + assert hasattr(executor.provider, 'channel'), "If provider has no .channels, it must have .channel" + logger.info(f"Closing channel {executor.provider.channel}") + executor.provider.channel.close() + logger.info(f"Closed channel {executor.provider.channel}") + + logger.info(f"Closed executor channel(s) for {executor.label}") + logger.info("Terminated executors") self.time_completed = datetime.datetime.now() @@ -1271,6 +1306,10 @@ def cleanup(self) -> None: self.monitoring.close() logger.info("Terminated monitoring") + logger.info("Terminating dependency launch pool") + self.dependency_launch_pool.shutdown() + logger.info("Terminated dependency launch pool") + logger.info("Unregistering atexit hook") atexit.unregister(self.atexit_cleanup) logger.info("Unregistered atexit hook") @@ -1417,8 +1456,6 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str, 
Returns: - dict containing, hashed -> future mappings """ - self.memo_lookup_table = None - if checkpointDirs: return self._load_checkpoints(checkpointDirs) else: diff --git a/parsl/executors/base.py b/parsl/executors/base.py index b00aa55680..a112b9eb00 100644 --- a/parsl/executors/base.py +++ b/parsl/executors/base.py @@ -5,7 +5,7 @@ from typing_extensions import Literal, Self -from parsl.monitoring.radios import MonitoringRadio +from parsl.monitoring.radios import MonitoringRadioSender class ParslExecutor(metaclass=ABCMeta): @@ -52,13 +52,13 @@ def __init__( *, hub_address: Optional[str] = None, hub_zmq_port: Optional[int] = None, - monitoring_radio: Optional[MonitoringRadio] = None, + submit_monitoring_radio: Optional[MonitoringRadioSender] = None, run_dir: str = ".", run_id: Optional[str] = None, ): self.hub_address = hub_address self.hub_zmq_port = hub_zmq_port - self.monitoring_radio = monitoring_radio + self.submit_monitoring_radio = submit_monitoring_radio self.run_dir = os.path.abspath(run_dir) self.run_id = run_id @@ -147,11 +147,11 @@ def hub_zmq_port(self, value: Optional[int]) -> None: self._hub_zmq_port = value @property - def monitoring_radio(self) -> Optional[MonitoringRadio]: + def submit_monitoring_radio(self) -> Optional[MonitoringRadioSender]: """Local radio for sending monitoring messages """ - return self._monitoring_radio + return self._submit_monitoring_radio - @monitoring_radio.setter - def monitoring_radio(self, value: Optional[MonitoringRadio]) -> None: - self._monitoring_radio = value + @submit_monitoring_radio.setter + def submit_monitoring_radio(self, value: Optional[MonitoringRadioSender]) -> None: + self._submit_monitoring_radio = value diff --git a/parsl/executors/flux/executor.py b/parsl/executors/flux/executor.py index c4926abb68..f1b981f7e0 100644 --- a/parsl/executors/flux/executor.py +++ b/parsl/executors/flux/executor.py @@ -200,7 +200,6 @@ def __init__( raise EnvironmentError("Cannot find Flux installation in PATH") self.flux_path = os.path.abspath(flux_path) self._task_id_counter = itertools.count() - self._socket = zmq.Context().socket(zmq.REP) # Assumes a launch command cannot be None or empty self.launch_cmd = launch_cmd or self.DEFAULT_LAUNCH_CMD self._submission_queue: queue.Queue = queue.Queue() @@ -213,7 +212,6 @@ def __init__( args=( self._submission_queue, self._stop_event, - self._socket, self.working_dir, self.flux_executor_kwargs, self.provider, @@ -306,11 +304,13 @@ def _submit_wrapper( If an exception is thrown, error out all submitted tasks. 
""" - try: - _submit_flux_jobs(submission_queue, stop_event, *args, **kwargs) - except Exception as exc: - _error_out_jobs(submission_queue, stop_event, exc) - raise + with zmq.Context() as ctx: + with ctx.socket(zmq.REP) as socket: + try: + _submit_flux_jobs(submission_queue, stop_event, socket, *args, **kwargs) + except Exception as exc: + _error_out_jobs(submission_queue, stop_event, exc) + raise def _error_out_jobs( diff --git a/parsl/executors/flux/flux_instance_manager.py b/parsl/executors/flux/flux_instance_manager.py index 3d760bb5c8..e6111796b5 100644 --- a/parsl/executors/flux/flux_instance_manager.py +++ b/parsl/executors/flux/flux_instance_manager.py @@ -27,30 +27,29 @@ def main(): parser.add_argument("hostname", help="hostname of the parent executor's socket") parser.add_argument("port", help="Port of the parent executor's socket") args = parser.parse_args() - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.connect( - args.protocol + "://" + gethostbyname(args.hostname) + ":" + args.port - ) - # send the path to the ``flux.job`` package - socket.send(dirname(dirname(os.path.realpath(flux.__file__))).encode()) - logging.debug("Flux package path sent.") - # collect the encapsulating Flux instance's URI - local_uri = flux.Flux().attr_get("local-uri") - hostname = gethostname() - if args.hostname == hostname: - flux_uri = local_uri - else: - flux_uri = "ssh://" + gethostname() + local_uri.replace("local://", "") - logging.debug("Flux URI is %s", flux_uri) - response = socket.recv() # get acknowledgment - logging.debug("Received acknowledgment %s", response) - socket.send(flux_uri.encode()) # send URI - logging.debug("URI sent. Blocking for response...") - response = socket.recv() # wait for shutdown message - logging.debug("Response %s received, draining flux jobs...", response) - flux.Flux().rpc("job-manager.drain").get() - logging.debug("Flux jobs drained, exiting.") + with zmq.Context() as context, context.socket(zmq.REQ) as socket: + socket.connect( + args.protocol + "://" + gethostbyname(args.hostname) + ":" + args.port + ) + # send the path to the ``flux.job`` package + socket.send(dirname(dirname(os.path.realpath(flux.__file__))).encode()) + logging.debug("Flux package path sent.") + # collect the encapsulating Flux instance's URI + local_uri = flux.Flux().attr_get("local-uri") + hostname = gethostname() + if args.hostname == hostname: + flux_uri = local_uri + else: + flux_uri = "ssh://" + gethostname() + local_uri.replace("local://", "") + logging.debug("Flux URI is %s", flux_uri) + response = socket.recv() # get acknowledgment + logging.debug("Received acknowledgment %s", response) + socket.send(flux_uri.encode()) # send URI + logging.debug("URI sent. Blocking for response...") + response = socket.recv() # wait for shutdown message + logging.debug("Response %s received, draining flux jobs...", response) + flux.Flux().rpc("job-manager.drain").get() + logging.debug("Flux jobs drained, exiting.") if __name__ == "__main__": diff --git a/parsl/executors/high_throughput/errors.py b/parsl/executors/high_throughput/errors.py index 4db7907523..9916ec506f 100644 --- a/parsl/executors/high_throughput/errors.py +++ b/parsl/executors/high_throughput/errors.py @@ -1,3 +1,36 @@ +import time + + +class ManagerLost(Exception): + """ + Task lost due to manager loss. Manager is considered lost when multiple heartbeats + have been missed. 
+ """ + def __init__(self, manager_id: bytes, hostname: str) -> None: + self.manager_id = manager_id + self.tstamp = time.time() + self.hostname = hostname + + def __str__(self) -> str: + return ( + f"Task failure due to loss of manager {self.manager_id.decode()} on" + f" host {self.hostname}" + ) + + +class VersionMismatch(Exception): + """Manager and Interchange versions do not match""" + def __init__(self, interchange_version: str, manager_version: str): + self.interchange_version = interchange_version + self.manager_version = manager_version + + def __str__(self) -> str: + return ( + f"Manager version info {self.manager_version} does not match interchange" + f" version info {self.interchange_version}, causing a critical failure" + ) + + class WorkerLost(Exception): """Exception raised when a worker is lost """ diff --git a/parsl/executors/high_throughput/executor.py b/parsl/executors/high_throughput/executor.py index b5480e7937..c4097500f1 100644 --- a/parsl/executors/high_throughput/executor.py +++ b/parsl/executors/high_throughput/executor.py @@ -1,32 +1,33 @@ import logging import math import pickle +import subprocess import threading import typing import warnings from collections import defaultdict from concurrent.futures import Future from dataclasses import dataclass -from multiprocessing import Process from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import typeguard -import parsl.launchers from parsl import curvezmq from parsl.addresses import get_all_addresses from parsl.app.errors import RemoteExceptionWrapper from parsl.data_provider.staging import Staging from parsl.executors.errors import BadMessage, ScalingFailed -from parsl.executors.high_throughput import interchange, zmq_pipes +from parsl.executors.high_throughput import zmq_pipes from parsl.executors.high_throughput.errors import CommandClientTimeoutError +from parsl.executors.high_throughput.manager_selector import ( + ManagerSelector, + RandomManagerSelector, +) from parsl.executors.high_throughput.mpi_prefix_composer import ( - VALID_LAUNCHERS, - validate_resource_spec, + InvalidResourceSpecification, ) from parsl.executors.status_handling import BlockProviderExecutor from parsl.jobs.states import TERMINAL_STATES, JobState, JobStatus -from parsl.multiprocessing import ForkProcess from parsl.process_loggers import wrap_with_logs from parsl.providers import LocalProvider from parsl.providers.base import ExecutionProvider @@ -57,6 +58,8 @@ "--mpi-launcher={mpi_launcher} " "--available-accelerators {accelerators}") +DEFAULT_INTERCHANGE_LAUNCH_CMD = ["interchange.py"] + GENERAL_HTEX_PARAM_DOCS = """provider : :class:`~parsl.providers.base.ExecutionProvider` Provider to access computation resources. Can be one of :class:`~parsl.providers.aws.aws.EC2Provider`, :class:`~parsl.providers.cobalt.cobalt.Cobalt`, @@ -77,6 +80,10 @@ cores_per_worker, nodes_per_block, heartbeat_period ,heartbeat_threshold, logdir). For example: launch_cmd="process_worker_pool.py {debug} -c {cores_per_worker} --task_url={task_url} --result_url={result_url}" + interchange_launch_cmd : Sequence[str] + Custom sequence of command line tokens to launch the interchange process from the executor. If + undefined, the executor will use the default "interchange.py" command. + address : string An address to connect to the main Parsl process which is reachable from the network in which workers will be running. This field expects an IPv4 address (xxx.xxx.xxx.xxx). 
@@ -163,7 +170,8 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin, UsageIn | | | | batching | | | Parsl<---Fut-| | | load-balancing| result exception ^ | | | watchdogs | | | - | | | Q_mngmnt | | V V + | | | Result | | | | + | | | Queue | | V V | | | Thread<--|-incoming_q<---|--- +---------+ | | | | | | | | | | | | @@ -214,17 +222,6 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin, UsageIn Parsl will create names as integers starting with 0. default: empty list - - enable_mpi_mode: bool - If enabled, MPI launch prefixes will be composed for the batch scheduler based on - the nodes available in each batch job and the resource_specification dict passed - from the app. This is an experimental feature, please refer to the following doc section - before use: https://parsl.readthedocs.io/en/stable/userguide/mpi_apps.html - - mpi_launcher: str - This field is only used if enable_mpi_mode is set. Select one from the - list of supported MPI launchers = ("srun", "aprun", "mpiexec"). - default: "mpiexec" """ @typeguard.typechecked @@ -232,6 +229,7 @@ def __init__(self, label: str = 'HighThroughputExecutor', provider: ExecutionProvider = LocalProvider(), launch_cmd: Optional[str] = None, + interchange_launch_cmd: Optional[Sequence[str]] = None, address: Optional[str] = None, worker_ports: Optional[Tuple[int, int]] = None, worker_port_range: Optional[Tuple[int, int]] = (54000, 55000), @@ -252,8 +250,7 @@ def __init__(self, poll_period: int = 10, address_probe_timeout: Optional[int] = None, worker_logdir_root: Optional[str] = None, - enable_mpi_mode: bool = False, - mpi_launcher: str = "mpiexec", + manager_selector: ManagerSelector = RandomManagerSelector(), block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True, encrypted: bool = False): @@ -269,6 +266,7 @@ def __init__(self, self.prefetch_capacity = prefetch_capacity self.address = address self.address_probe_timeout = address_probe_timeout + self.manager_selector = manager_selector if self.address: self.all_addresses = address else: @@ -305,7 +303,7 @@ def __init__(self, self._task_counter = 0 self.worker_ports = worker_ports self.worker_port_range = worker_port_range - self.interchange_proc: Optional[Process] = None + self.interchange_proc: Optional[subprocess.Popen] = None self.interchange_port_range = interchange_port_range self.heartbeat_threshold = heartbeat_threshold self.heartbeat_period = heartbeat_period @@ -317,20 +315,17 @@ def __init__(self, self.encrypted = encrypted self.cert_dir = None - self.enable_mpi_mode = enable_mpi_mode - assert mpi_launcher in VALID_LAUNCHERS, \ - f"mpi_launcher must be set to one of {VALID_LAUNCHERS}" - if self.enable_mpi_mode: - assert isinstance(self.provider.launcher, parsl.launchers.SimpleLauncher), \ - "mpi_mode requires the provider to be configured to use a SimpleLauncher" - - self.mpi_launcher = mpi_launcher - if not launch_cmd: launch_cmd = DEFAULT_LAUNCH_CMD self.launch_cmd = launch_cmd + if not interchange_launch_cmd: + interchange_launch_cmd = DEFAULT_INTERCHANGE_LAUNCH_CMD + self.interchange_launch_cmd = interchange_launch_cmd + radio_mode = "htex" + enable_mpi_mode: bool = False + mpi_launcher: str = "mpiexec" def _warn_deprecated(self, old: str, new: str): warnings.warn( @@ -360,6 +355,18 @@ def worker_logdir(self): return "{}/{}".format(self.worker_logdir_root, self.label) return self.logdir + def validate_resource_spec(self, resource_specification: dict): + """HTEX does not support *any* 
resource_specification options and + will raise InvalidResourceSpecification if any are passed to it""" + if resource_specification: + raise InvalidResourceSpecification( + set(resource_specification.keys()), + ("HTEX does not support the supplied resource_specifications. " + "For MPI applications consider using the MPIExecutor. " + "For specifications for core count/memory/walltime, consider using WorkQueueExecutor. ") + ) + return + def initialize_scaling(self): """Compose the launch command and scale out the initial blocks. """ @@ -419,20 +426,19 @@ def start(self): "127.0.0.1", self.interchange_port_range, self.cert_dir ) - self._queue_management_thread = None - self._start_queue_management_thread() + self._result_queue_thread = None + self._start_result_queue_thread() self._start_local_interchange_process() - logger.debug("Created management thread: {}".format(self._queue_management_thread)) + logger.debug("Created result queue thread: %s", self._result_queue_thread) self.initialize_scaling() @wrap_with_logs - def _queue_management_worker(self): - """Listen to the queue for task status messages and handle them. + def _result_queue_worker(self): + """Listen to the queue for task result messages and handle them. - Depending on the message, tasks will be updated with results, exceptions, - or updates. It expects the following messages: + Depending on the message, tasks will be updated with results or exceptions. .. code:: python @@ -446,10 +452,8 @@ def _queue_management_worker(self): "task_id" : "exception" : serialized exception object, on failure } - - The `None` message is a die request. """ - logger.debug("Queue management worker starting") + logger.debug("Result queue worker starting") while not self.bad_state_is_set: try: @@ -465,108 +469,114 @@ def _queue_management_worker(self): else: - if msgs is None: - logger.debug("Got None, exiting") - return + for serialized_msg in msgs: + try: + msg = pickle.loads(serialized_msg) + except pickle.UnpicklingError: + raise BadMessage("Message received could not be unpickled") - else: - for serialized_msg in msgs: + if msg['type'] == 'heartbeat': + continue + elif msg['type'] == 'result': try: - msg = pickle.loads(serialized_msg) - except pickle.UnpicklingError: - raise BadMessage("Message received could not be unpickled") + tid = msg['task_id'] + except Exception: + raise BadMessage("Message received does not contain 'task_id' field") + + if tid == -1 and 'exception' in msg: + logger.warning("Executor shutting down due to exception from interchange") + exception = deserialize(msg['exception']) + self.set_bad_state_and_fail_all(exception) + break + + task_fut = self.tasks.pop(tid) - if msg['type'] == 'heartbeat': - continue - elif msg['type'] == 'result': + if 'result' in msg: + result = deserialize(msg['result']) + task_fut.set_result(result) + + elif 'exception' in msg: try: - tid = msg['task_id'] - except Exception: - raise BadMessage("Message received does not contain 'task_id' field") - - if tid == -1 and 'exception' in msg: - logger.warning("Executor shutting down due to exception from interchange") - exception = deserialize(msg['exception']) - self.set_bad_state_and_fail_all(exception) - break - - task_fut = self.tasks.pop(tid) - - if 'result' in msg: - result = deserialize(msg['result']) - task_fut.set_result(result) - - elif 'exception' in msg: - try: - s = deserialize(msg['exception']) - # s should be a RemoteExceptionWrapper... 
so we can reraise it - if isinstance(s, RemoteExceptionWrapper): - try: - s.reraise() - except Exception as e: - task_fut.set_exception(e) - elif isinstance(s, Exception): - task_fut.set_exception(s) - else: - raise ValueError("Unknown exception-like type received: {}".format(type(s))) - except Exception as e: - # TODO could be a proper wrapped exception? - task_fut.set_exception( - DeserializationError("Received exception, but handling also threw an exception: {}".format(e))) - else: - raise BadMessage("Message received is neither result or exception") + s = deserialize(msg['exception']) + # s should be a RemoteExceptionWrapper... so we can reraise it + if isinstance(s, RemoteExceptionWrapper): + try: + s.reraise() + except Exception as e: + task_fut.set_exception(e) + elif isinstance(s, Exception): + task_fut.set_exception(s) + else: + raise ValueError("Unknown exception-like type received: {}".format(type(s))) + except Exception as e: + # TODO could be a proper wrapped exception? + task_fut.set_exception( + DeserializationError("Received exception, but handling also threw an exception: {}".format(e))) else: - raise BadMessage("Message received with unknown type {}".format(msg['type'])) + raise BadMessage("Message received is neither result or exception") + else: + raise BadMessage("Message received with unknown type {}".format(msg['type'])) - logger.info("Queue management worker finished") + logger.info("Result queue worker finished") - def _start_local_interchange_process(self): + def _start_local_interchange_process(self) -> None: """ Starts the interchange process locally - Starts the interchange process locally and uses an internal command queue to + Starts the interchange process locally and uses the command queue to get the worker task and result ports that the interchange has bound to. """ - self.interchange_proc = ForkProcess(target=interchange.starter, - kwargs={"client_ports": (self.outgoing_q.port, - self.incoming_q.port, - self.command_client.port), - "interchange_address": self.address, - "worker_ports": self.worker_ports, - "worker_port_range": self.worker_port_range, - "hub_address": self.hub_address, - "hub_zmq_port": self.hub_zmq_port, - "logdir": self.logdir, - "heartbeat_threshold": self.heartbeat_threshold, - "poll_period": self.poll_period, - "logging_level": logging.DEBUG if self.worker_debug else logging.INFO, - "cert_dir": self.cert_dir, - }, - daemon=True, - name="HTEX-Interchange" - ) - self.interchange_proc.start() + interchange_config = {"client_address": "127.0.0.1", + "client_ports": (self.outgoing_q.port, + self.incoming_q.port, + self.command_client.port), + "interchange_address": self.address, + "worker_ports": self.worker_ports, + "worker_port_range": self.worker_port_range, + "hub_address": self.hub_address, + "hub_zmq_port": self.hub_zmq_port, + "logdir": self.logdir, + "heartbeat_threshold": self.heartbeat_threshold, + "poll_period": self.poll_period, + "logging_level": logging.DEBUG if self.worker_debug else logging.INFO, + "cert_dir": self.cert_dir, + "manager_selector": self.manager_selector, + "run_id": self.run_id, + } + + config_pickle = pickle.dumps(interchange_config) + + self.interchange_proc = subprocess.Popen(self.interchange_launch_cmd, stdin=subprocess.PIPE) + stdin = self.interchange_proc.stdin + assert stdin is not None, "Popen should have created an IO object (vs default None) because of PIPE mode" + + logger.debug("Popened interchange process. 
Writing config object") + stdin.write(config_pickle) + stdin.flush() + stdin.close() + logger.debug("Sent config object. Requesting worker ports") try: (self.worker_task_port, self.worker_result_port) = self.command_client.run("WORKER_PORTS", timeout_s=120) except CommandClientTimeoutError: - logger.error("Interchange has not completed initialization in 120s. Aborting") + logger.error("Interchange has not completed initialization. Aborting") raise Exception("Interchange failed to start") + logger.debug("Got worker ports") - def _start_queue_management_thread(self): - """Method to start the management thread as a daemon. + def _start_result_queue_thread(self): + """Method to start the result queue thread as a daemon. Checks if a thread already exists, then starts it. - Could be used later as a restart if the management thread dies. + Could be used later as a restart if the result queue thread dies. """ - if self._queue_management_thread is None: - logger.debug("Starting queue management thread") - self._queue_management_thread = threading.Thread(target=self._queue_management_worker, name="HTEX-Queue-Management-Thread") - self._queue_management_thread.daemon = True - self._queue_management_thread.start() - logger.debug("Started queue management thread") + if self._result_queue_thread is None: + logger.debug("Starting result queue thread") + self._result_queue_thread = threading.Thread(target=self._result_queue_worker, name="HTEX-Result-Queue-Thread") + self._result_queue_thread.daemon = True + self._result_queue_thread.start() + logger.debug("Started result queue thread") else: - logger.error("Management thread already exists, returning") + logger.error("Result queue thread already exists, returning") def hold_worker(self, worker_id: str) -> None: """Puts a worker on hold, preventing scheduling of additional tasks to it. @@ -640,7 +650,7 @@ def submit(self, func, resource_specification, *args, **kwargs): Future """ - validate_resource_spec(resource_specification, self.enable_mpi_mode) + self.validate_resource_spec(resource_specification) if self.bad_state_is_set: raise self.executor_exception @@ -809,12 +819,28 @@ def shutdown(self, timeout: float = 10.0): logger.info("Attempting HighThroughputExecutor shutdown") self.interchange_proc.terminate() - self.interchange_proc.join(timeout=timeout) - if self.interchange_proc.is_alive(): - logger.info("Unable to terminate Interchange process; sending SIGKILL") + try: + self.interchange_proc.wait(timeout=timeout) + except subprocess.TimeoutExpired: + logger.warning("Unable to terminate Interchange process; sending SIGKILL") self.interchange_proc.kill() - self.interchange_proc.close() + logger.info("Closing ZMQ pipes") + + # These pipes are used in a thread unsafe manner. If you have traced a + # problem to this block of code, you might consider what is happening + # with other threads that access these. + + # incoming_q is not closed here because it is used by the results queue + # worker which is not shut down at this point. 
+ + if hasattr(self, 'outgoing_q'): + logger.info("Closing outgoing_q") + self.outgoing_q.close() + + if hasattr(self, 'command_client'): + logger.info("Closing command client") + self.command_client.close() logger.info("Finished HighThroughputExecutor shutdown attempt") diff --git a/parsl/executors/high_throughput/interchange.py b/parsl/executors/high_throughput/interchange.py index 4b3bab3563..fa0969d398 100644 --- a/parsl/executors/high_throughput/interchange.py +++ b/parsl/executors/high_throughput/interchange.py @@ -6,7 +6,6 @@ import pickle import platform import queue -import random import signal import sys import threading @@ -17,8 +16,11 @@ from parsl import curvezmq from parsl.app.errors import RemoteExceptionWrapper +from parsl.executors.high_throughput.errors import ManagerLost, VersionMismatch from parsl.executors.high_throughput.manager_record import ManagerRecord +from parsl.executors.high_throughput.manager_selector import ManagerSelector from parsl.monitoring.message_type import MessageType +from parsl.monitoring.radios import MonitoringRadioSender, ZMQRadioSender from parsl.process_loggers import wrap_with_logs from parsl.serialize import serialize as serialize_object from parsl.utils import setproctitle @@ -31,32 +33,6 @@ logger = logging.getLogger(LOGGER_NAME) -class ManagerLost(Exception): - ''' Task lost due to manager loss. Manager is considered lost when multiple heartbeats - have been missed. - ''' - def __init__(self, manager_id: bytes, hostname: str) -> None: - self.manager_id = manager_id - self.tstamp = time.time() - self.hostname = hostname - - def __str__(self) -> str: - return "Task failure due to loss of manager {} on host {}".format(self.manager_id.decode(), self.hostname) - - -class VersionMismatch(Exception): - ''' Manager and Interchange versions do not match - ''' - def __init__(self, interchange_version: str, manager_version: str): - self.interchange_version = interchange_version - self.manager_version = manager_version - - def __str__(self) -> str: - return "Manager version info {} does not match interchange version info {}, causing a critical failure".format( - self.manager_version, - self.interchange_version) - - class Interchange: """ Interchange is a task orchestrator for distributed systems. @@ -65,18 +41,21 @@ class Interchange: 3. Detect workers that have failed using heartbeats """ def __init__(self, - client_address: str = "127.0.0.1", - interchange_address: Optional[str] = None, - client_ports: Tuple[int, int, int] = (50055, 50056, 50057), - worker_ports: Optional[Tuple[int, int]] = None, - worker_port_range: Tuple[int, int] = (54000, 55000), - hub_address: Optional[str] = None, - hub_zmq_port: Optional[int] = None, - heartbeat_threshold: int = 60, - logdir: str = ".", - logging_level: int = logging.INFO, - poll_period: int = 10, - cert_dir: Optional[str] = None, + *, + client_address: str, + interchange_address: Optional[str], + client_ports: Tuple[int, int, int], + worker_ports: Optional[Tuple[int, int]], + worker_port_range: Tuple[int, int], + hub_address: Optional[str], + hub_zmq_port: Optional[int], + heartbeat_threshold: int, + logdir: str, + logging_level: int, + poll_period: int, + cert_dir: Optional[str], + manager_selector: ManagerSelector, + run_id: str, ) -> None: """ Parameters @@ -92,34 +71,34 @@ def __init__(self, The ports at which the client can be reached worker_ports : tuple(int, int) - The specific two ports at which workers will connect to the Interchange. 
Default: None + The specific two ports at which workers will connect to the Interchange. worker_port_range : tuple(int, int) The interchange picks ports at random from the range which will be used by workers. - This is overridden when the worker_ports option is set. Default: (54000, 55000) + This is overridden when the worker_ports option is set. hub_address : str The IP address at which the interchange can send info about managers to when monitoring is enabled. - Default: None (meaning monitoring disabled) + When None, monitoring is disabled. hub_zmq_port : str The port at which the interchange can send info about managers to when monitoring is enabled. - Default: None (meaning monitoring disabled) + When None, monitoring is disabled. heartbeat_threshold : int Number of seconds since the last heartbeat after which worker is considered lost. logdir : str - Parsl log directory paths. Logs and temp files go here. Default: '.' + Parsl log directory paths. Logs and temp files go here. logging_level : int - Logging level as defined in the logging module. Default: logging.INFO + Logging level as defined in the logging module. poll_period : int - The main thread polling period, in milliseconds. Default: 10ms + The main thread polling period, in milliseconds. cert_dir : str | None - Path to the certificate directory. Default: None + Path to the certificate directory. """ self.cert_dir = cert_dir self.logdir = logdir @@ -147,6 +126,8 @@ def __init__(self, self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2])) logger.info("Connected to client") + self.run_id = run_id + self.hub_address = hub_address self.hub_zmq_port = hub_zmq_port @@ -184,6 +165,8 @@ def __init__(self, self.heartbeat_threshold = heartbeat_threshold + self.manager_selector = manager_selector + self.current_platform = {'parsl_v': PARSL_VERSION, 'python_v': "{}.{}.{}".format(sys.version_info.major, sys.version_info.minor, @@ -240,27 +223,16 @@ def task_puller(self) -> NoReturn: task_counter += 1 logger.debug(f"Fetched {task_counter} tasks so far") - def _create_monitoring_channel(self) -> Optional[zmq.Socket]: - if self.hub_address and self.hub_zmq_port: - logger.info("Connecting to MonitoringHub") - # This is a one-off because monitoring is unencrypted - hub_channel = zmq.Context().socket(zmq.DEALER) - hub_channel.set_hwm(0) - hub_channel.connect("tcp://{}:{}".format(self.hub_address, self.hub_zmq_port)) - logger.info("Connected to MonitoringHub") - return hub_channel - else: - return None - - def _send_monitoring_info(self, hub_channel: Optional[zmq.Socket], manager: ManagerRecord) -> None: - if hub_channel: + def _send_monitoring_info(self, monitoring_radio: Optional[MonitoringRadioSender], manager: ManagerRecord) -> None: + if monitoring_radio: logger.info("Sending message {} to MonitoringHub".format(manager)) d: Dict = cast(Dict, manager.copy()) d['timestamp'] = datetime.datetime.now() d['last_heartbeat'] = datetime.datetime.fromtimestamp(d['last_heartbeat']) + d['run_id'] = self.run_id - hub_channel.send_pyobj((MessageType.NODE_INFO, d)) + monitoring_radio.send((MessageType.NODE_INFO, d)) @wrap_with_logs(target="interchange") def _command_server(self) -> NoReturn: @@ -268,8 +240,11 @@ def _command_server(self) -> NoReturn: """ logger.debug("Command Server Starting") - # Need to create a new ZMQ socket for command server thread - hub_channel = self._create_monitoring_channel() + if self.hub_address is not None and self.hub_zmq_port is not None: + logger.debug("Creating monitoring radio to %s:%s", 
self.hub_address, self.hub_zmq_port) + monitoring_radio = ZMQRadioSender(self.hub_address, self.hub_zmq_port) + else: + monitoring_radio = None reply: Any # the type of reply depends on the command_req received (aka this needs dependent types...) @@ -319,7 +294,7 @@ def _command_server(self) -> NoReturn: if manager_id in self._ready_managers: m = self._ready_managers[manager_id] m['active'] = False - self._send_monitoring_info(hub_channel, m) + self._send_monitoring_info(monitoring_radio, m) else: logger.warning("Worker to hold was not in ready managers list") @@ -354,9 +329,14 @@ def start(self) -> None: # parent-process-inheritance problems. signal.signal(signal.SIGTERM, signal.SIG_DFL) - logger.info("Incoming ports bound") + logger.info("Starting main interchange method") - hub_channel = self._create_monitoring_channel() + if self.hub_address is not None and self.hub_zmq_port is not None: + logger.debug("Creating monitoring radio to %s:%s", self.hub_address, self.hub_zmq_port) + monitoring_radio = ZMQRadioSender(self.hub_address, self.hub_zmq_port) + logger.debug("Created monitoring radio") + else: + monitoring_radio = None poll_period = self.poll_period @@ -387,10 +367,10 @@ def start(self) -> None: while not kill_event.is_set(): self.socks = dict(poller.poll(timeout=poll_period)) - self.process_task_outgoing_incoming(interesting_managers, hub_channel, kill_event) - self.process_results_incoming(interesting_managers, hub_channel) - self.expire_bad_managers(interesting_managers, hub_channel) - self.expire_drained_managers(interesting_managers, hub_channel) + self.process_task_outgoing_incoming(interesting_managers, monitoring_radio, kill_event) + self.process_results_incoming(interesting_managers, monitoring_radio) + self.expire_bad_managers(interesting_managers, monitoring_radio) + self.expire_drained_managers(interesting_managers, monitoring_radio) self.process_tasks_to_send(interesting_managers) self.zmq_context.destroy() @@ -401,7 +381,7 @@ def start(self) -> None: def process_task_outgoing_incoming( self, interesting_managers: Set[bytes], - hub_channel: Optional[zmq.Socket], + monitoring_radio: Optional[MonitoringRadioSender], kill_event: threading.Event ) -> None: """Process one message from manager on the task_outgoing channel. @@ -434,6 +414,7 @@ def process_task_outgoing_incoming( self._ready_managers[manager_id] = {'last_heartbeat': time.time(), 'idle_since': time.time(), 'block_id': None, + 'start_time': msg['start_time'], 'max_capacity': 0, 'worker_count': 0, 'active': True, @@ -454,7 +435,7 @@ def process_task_outgoing_incoming( m.update(msg) # type: ignore[typeddict-item] logger.info("Registration info for manager {!r}: {}".format(manager_id, msg)) - self._send_monitoring_info(hub_channel, m) + self._send_monitoring_info(monitoring_radio, m) if (msg['python_v'].rsplit(".", 1)[0] != self.current_platform['python_v'].rsplit(".", 1)[0] or msg['parsl_v'] != self.current_platform['parsl_v']): @@ -485,7 +466,7 @@ def process_task_outgoing_incoming( logger.error(f"Unexpected message type received from manager: {msg['type']}") logger.debug("leaving task_outgoing section") - def expire_drained_managers(self, interesting_managers: Set[bytes], hub_channel: Optional[zmq.Socket]) -> None: + def expire_drained_managers(self, interesting_managers: Set[bytes], monitoring_radio: Optional[MonitoringRadioSender]) -> None: for manager_id in list(interesting_managers): # is it always true that a draining manager will be in interesting managers? 
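As an aside (not part of the patch), here is a small illustrative sketch of the pattern introduced in the hunks above, where the interchange builds a MonitoringRadioSender only when monitoring is configured and then uses a single send() call in place of the old raw hub_channel socket; the standalone helper names below are hypothetical:

    from typing import Optional

    from parsl.monitoring.message_type import MessageType
    from parsl.monitoring.radios import MonitoringRadioSender, ZMQRadioSender

    def make_monitoring_radio(hub_address: Optional[str],
                              hub_zmq_port: Optional[int]) -> Optional[MonitoringRadioSender]:
        # Hypothetical helper mirroring the interchange logic: a radio is only
        # created when both the hub address and port are configured.
        if hub_address is not None and hub_zmq_port is not None:
            return ZMQRadioSender(hub_address, hub_zmq_port)
        return None

    def report_node_info(monitoring_radio: Optional[MonitoringRadioSender], record: dict) -> None:
        # Mirrors _send_monitoring_info: a no-op when monitoring is disabled,
        # otherwise one send() call replaces the old hub_channel.send_pyobj().
        if monitoring_radio:
            monitoring_radio.send((MessageType.NODE_INFO, record))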
@@ -498,7 +479,7 @@ def expire_drained_managers(self, interesting_managers: Set[bytes], hub_channel: self._ready_managers.pop(manager_id) m['active'] = False - self._send_monitoring_info(hub_channel, m) + self._send_monitoring_info(monitoring_radio, m) def process_tasks_to_send(self, interesting_managers: Set[bytes]) -> None: # Check if there are tasks that could be sent to managers @@ -508,8 +489,7 @@ def process_tasks_to_send(self, interesting_managers: Set[bytes]) -> None: interesting=len(interesting_managers))) if interesting_managers and not self.pending_task_queue.empty(): - shuffled_managers = list(interesting_managers) - random.shuffle(shuffled_managers) + shuffled_managers = self.manager_selector.sort_managers(self._ready_managers, interesting_managers) while shuffled_managers and not self.pending_task_queue.empty(): # cf. the if statement above... manager_id = shuffled_managers.pop() @@ -542,7 +522,7 @@ def process_tasks_to_send(self, interesting_managers: Set[bytes]) -> None: else: logger.debug("either no interesting managers or no tasks, so skipping manager pass") - def process_results_incoming(self, interesting_managers: Set[bytes], hub_channel: Optional[zmq.Socket]) -> None: + def process_results_incoming(self, interesting_managers: Set[bytes], monitoring_radio: Optional[MonitoringRadioSender]) -> None: # Receive any results and forward to client if self.results_incoming in self.socks and self.socks[self.results_incoming] == zmq.POLLIN: logger.debug("entering results_incoming section") @@ -562,11 +542,11 @@ def process_results_incoming(self, interesting_managers: Set[bytes], hub_channel elif r['type'] == 'monitoring': # the monitoring code makes the assumption that no # monitoring messages will be received if monitoring - # is not configured, and that hub_channel will only + # is not configured, and that monitoring_radio will only # be None when monitoring is not configurated. - assert hub_channel is not None + assert monitoring_radio is not None - hub_channel.send_pyobj(r['payload']) + monitoring_radio.send(r['payload']) elif r['type'] == 'heartbeat': logger.debug(f"Manager {manager_id!r} sent heartbeat via results connection") b_messages.append((p_message, r)) @@ -610,7 +590,7 @@ def process_results_incoming(self, interesting_managers: Set[bytes], hub_channel interesting_managers.add(manager_id) logger.debug("leaving results_incoming section") - def expire_bad_managers(self, interesting_managers: Set[bytes], hub_channel: Optional[zmq.Socket]) -> None: + def expire_bad_managers(self, interesting_managers: Set[bytes], monitoring_radio: Optional[MonitoringRadioSender]) -> None: bad_managers = [(manager_id, m) for (manager_id, m) in self._ready_managers.items() if time.time() - m['last_heartbeat'] > self.heartbeat_threshold] for (manager_id, m) in bad_managers: @@ -618,7 +598,7 @@ def expire_bad_managers(self, interesting_managers: Set[bytes], hub_channel: Opt logger.warning(f"Too many heartbeats missed for manager {manager_id!r} - removing manager") if m['active']: m['active'] = False - self._send_monitoring_info(hub_channel, m) + self._send_monitoring_info(monitoring_radio, m) logger.warning(f"Cancelling htex tasks {m['tasks']} on removed manager") for tid in m['tasks']: @@ -671,13 +651,10 @@ def start_file_logger(filename: str, level: int = logging.DEBUG, format_string: logger.addHandler(handler) -@wrap_with_logs(target="interchange") -def starter(*args: Any, **kwargs: Any) -> None: - """Start the interchange process - - The executor is expected to call this function. 
The args, kwargs match that of the Interchange.__init__ - """ +if __name__ == "__main__": setproctitle("parsl: HTEX interchange") - # logger = multiprocessing.get_logger() - ic = Interchange(*args, **kwargs) + + config = pickle.load(sys.stdin.buffer) + + ic = Interchange(**config) ic.start() diff --git a/parsl/executors/high_throughput/manager_record.py b/parsl/executors/high_throughput/manager_record.py index 7e58b53954..a48c18cbd9 100644 --- a/parsl/executors/high_throughput/manager_record.py +++ b/parsl/executors/high_throughput/manager_record.py @@ -6,6 +6,7 @@ class ManagerRecord(TypedDict, total=False): block_id: Optional[str] + start_time: float tasks: List[Any] worker_count: int max_capacity: int diff --git a/parsl/executors/high_throughput/manager_selector.py b/parsl/executors/high_throughput/manager_selector.py new file mode 100644 index 0000000000..0ede28ee7d --- /dev/null +++ b/parsl/executors/high_throughput/manager_selector.py @@ -0,0 +1,25 @@ +import random +from abc import ABCMeta, abstractmethod +from typing import Dict, List, Set + +from parsl.executors.high_throughput.manager_record import ManagerRecord + + +class ManagerSelector(metaclass=ABCMeta): + + @abstractmethod + def sort_managers(self, ready_managers: Dict[bytes, ManagerRecord], manager_list: Set[bytes]) -> List[bytes]: + """ Sort a given list of managers. + + Any operations pertaining to the sorting and rearrangement of the + interesting_managers Set should be performed here. + """ + pass + + +class RandomManagerSelector(ManagerSelector): + + def sort_managers(self, ready_managers: Dict[bytes, ManagerRecord], manager_list: Set[bytes]) -> List[bytes]: + c_manager_list = list(manager_list) + random.shuffle(c_manager_list) + return c_manager_list diff --git a/parsl/executors/high_throughput/mpi_executor.py b/parsl/executors/high_throughput/mpi_executor.py index 69071557c8..04b8cf5197 100644 --- a/parsl/executors/high_throughput/mpi_executor.py +++ b/parsl/executors/high_throughput/mpi_executor.py @@ -8,8 +8,13 @@ GENERAL_HTEX_PARAM_DOCS, HighThroughputExecutor, ) +from parsl.executors.high_throughput.mpi_prefix_composer import ( + VALID_LAUNCHERS, + validate_resource_spec, +) from parsl.executors.status_handling import BlockProviderExecutor from parsl.jobs.states import JobStatus +from parsl.launchers import SimpleLauncher from parsl.providers import LocalProvider from parsl.providers.base import ExecutionProvider @@ -30,6 +35,11 @@ class MPIExecutor(HighThroughputExecutor): max_workers_per_block: int Maximum number of MPI applications to run at once per block + mpi_launcher: str + Select one from the list of supported MPI launchers: + ("srun", "aprun", "mpiexec"). 
+ default: "mpiexec" + {GENERAL_HTEX_PARAM_DOCS} """ @@ -38,6 +48,7 @@ def __init__(self, label: str = 'MPIExecutor', provider: ExecutionProvider = LocalProvider(), launch_cmd: Optional[str] = None, + interchange_launch_cmd: Optional[str] = None, address: Optional[str] = None, worker_ports: Optional[Tuple[int, int]] = None, worker_port_range: Optional[Tuple[int, int]] = (54000, 55000), @@ -59,13 +70,13 @@ def __init__(self, super().__init__( # Hard-coded settings cores_per_worker=1e-9, # Ensures there will be at least an absurd number of workers - enable_mpi_mode=True, max_workers_per_node=max_workers_per_block, # Everything else label=label, provider=provider, launch_cmd=launch_cmd, + interchange_launch_cmd=interchange_launch_cmd, address=address, worker_ports=worker_ports, worker_port_range=worker_port_range, @@ -80,9 +91,21 @@ def __init__(self, poll_period=poll_period, address_probe_timeout=address_probe_timeout, worker_logdir_root=worker_logdir_root, - mpi_launcher=mpi_launcher, block_error_handler=block_error_handler, encrypted=encrypted ) + self.enable_mpi_mode = True + self.mpi_launcher = mpi_launcher self.max_workers_per_block = max_workers_per_block + + if not isinstance(self.provider.launcher, SimpleLauncher): + raise TypeError("mpi_mode requires the provider to be configured to use a SimpleLauncher") + + if mpi_launcher not in VALID_LAUNCHERS: + raise ValueError(f"mpi_launcher set to:{mpi_launcher} must be set to one of {VALID_LAUNCHERS}") + + self.mpi_launcher = mpi_launcher + + def validate_resource_spec(self, resource_specification: dict): + return validate_resource_spec(resource_specification) diff --git a/parsl/executors/high_throughput/mpi_prefix_composer.py b/parsl/executors/high_throughput/mpi_prefix_composer.py index 78c5d8b867..0125d9a532 100644 --- a/parsl/executors/high_throughput/mpi_prefix_composer.py +++ b/parsl/executors/high_throughput/mpi_prefix_composer.py @@ -21,14 +21,15 @@ def __str__(self): class InvalidResourceSpecification(Exception): """Exception raised when Invalid input is supplied via resource specification""" - def __init__(self, invalid_keys: Set[str]): + def __init__(self, invalid_keys: Set[str], message: str = ''): self.invalid_keys = invalid_keys + self.message = message def __str__(self): - return f"Invalid resource specification options supplied: {self.invalid_keys}" + return f"Invalid resource specification options supplied: {self.invalid_keys} {self.message}" -def validate_resource_spec(resource_spec: Dict[str, str], is_mpi_enabled: bool): +def validate_resource_spec(resource_spec: Dict[str, str]): """Basic validation of keys in the resource_spec Raises: InvalidResourceSpecification if the resource_spec @@ -38,7 +39,7 @@ def validate_resource_spec(resource_spec: Dict[str, str], is_mpi_enabled: bool): # empty resource_spec when mpi_mode is set causes parsl to hang # ref issue #3427 - if is_mpi_enabled and len(user_keys) == 0: + if len(user_keys) == 0: raise MissingResourceSpecification('MPI mode requires optional parsl_resource_specification keyword argument to be configured') legal_keys = set(("ranks_per_node", diff --git a/parsl/executors/high_throughput/process_worker_pool.py b/parsl/executors/high_throughput/process_worker_pool.py index 5a3b383dad..59efe501f1 100755 --- a/parsl/executors/high_throughput/process_worker_pool.py +++ b/parsl/executors/high_throughput/process_worker_pool.py @@ -9,6 +9,7 @@ import pickle import platform import queue +import subprocess import sys import threading import time @@ -183,6 +184,7 @@ def 
__init__(self, *, self.uid = uid self.block_id = block_id + self.start_time = time.time() self.enable_mpi_mode = enable_mpi_mode self.mpi_launcher = mpi_launcher @@ -262,6 +264,7 @@ def create_reg_message(self): 'worker_count': self.worker_count, 'uid': self.uid, 'block_id': self.block_id, + 'start_time': self.start_time, 'prefetch_capacity': self.prefetch_capacity, 'max_capacity': self.worker_count + self.prefetch_capacity, 'os': platform.system(), @@ -733,7 +736,26 @@ def worker( # If desired, pin to accelerator if accelerator is not None: - os.environ["CUDA_VISIBLE_DEVICES"] = accelerator + + # If CUDA devices are present, find the total number of devices to allow for MPS + # See: https://developer.nvidia.com/system-management-interface + nvidia_smi_cmd = "nvidia-smi -L > /dev/null && nvidia-smi -L | wc -l" + nvidia_smi_ret = subprocess.run(nvidia_smi_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if nvidia_smi_ret.returncode == 0: + num_cuda_devices = int(nvidia_smi_ret.stdout.split()[0]) + else: + num_cuda_devices = None + + try: + if num_cuda_devices is not None: + procs_per_cuda_device = pool_size // num_cuda_devices + partitioned_accelerator = str(int(accelerator) // procs_per_cuda_device) # multiple workers will share a GPU + os.environ["CUDA_VISIBLE_DEVICES"] = partitioned_accelerator + logger.info(f'Pinned worker to partitioned cuda device: {partitioned_accelerator}') + else: + os.environ["CUDA_VISIBLE_DEVICES"] = accelerator + except (TypeError, ValueError, ZeroDivisionError): + os.environ["CUDA_VISIBLE_DEVICES"] = accelerator os.environ["ROCR_VISIBLE_DEVICES"] = accelerator os.environ["ZE_AFFINITY_MASK"] = accelerator os.environ["ZE_ENABLE_PCI_ID_DEVICE_ORDER"] = '1' diff --git a/parsl/executors/radical/executor.py b/parsl/executors/radical/executor.py index c7ea1a8dd6..93b4b38bbd 100644 --- a/parsl/executors/radical/executor.py +++ b/parsl/executors/radical/executor.py @@ -9,7 +9,7 @@ import time from concurrent.futures import Future from functools import partial -from pathlib import Path, PosixPath +from pathlib import PosixPath from typing import Dict, Optional import requests @@ -24,7 +24,7 @@ from parsl.serialize.errors import DeserializationError, SerializationError from parsl.utils import RepresentationMixin -from .rpex_resources import ResourceConfig +from .rpex_resources import CLIENT, MPI, ResourceConfig try: import radical.pilot as rp @@ -59,7 +59,7 @@ class RadicalPilotExecutor(ParslExecutor, RepresentationMixin): ``rp.PilotManager`` and ``rp.TaskManager``. 2. "translate": Unwrap, identify, and parse Parsl ``apps`` into ``rp.TaskDescription``. 3. "submit": Submit Parsl apps to ``rp.TaskManager``. - 4. "shut_down": Shut down the RADICAL-Pilot runtime and all associated components. + 4. "shutdown": Shut down the RADICAL-Pilot runtime and all associated components. Here is a diagram @@ -138,19 +138,26 @@ def __init__(self, self.future_tasks: Dict[str, Future] = {} if rpex_cfg: - self.rpex_cfg = rpex_cfg + self.rpex_cfg = rpex_cfg.get_config() elif not rpex_cfg and 'local' in resource: - self.rpex_cfg = ResourceConfig() + self.rpex_cfg = ResourceConfig().get_config() else: - raise ValueError('Resource config file must be ' - 'specified for a non-local execution') + raise ValueError('Resource config must be ' + 'specified for non-local resources') def task_state_cb(self, task, state): """ Update the state of Parsl Future apps Based on RP task state callbacks. 
""" - if not task.uid.startswith('master'): + # check the Master/Worker state + if task.mode in [rp.RAPTOR_MASTER, rp.RAPTOR_WORKER]: + if state == rp.FAILED: + exception = RuntimeError(f'{task.uid} failed with internal error: {task.stderr}') + self._fail_all_tasks(exception) + + # check all other tasks state + else: parsl_task = self.future_tasks[task.uid] if state == rp.DONE: @@ -186,6 +193,23 @@ def task_state_cb(self, task, state): else: parsl_task.set_exception('Task failed for an unknown reason') + def _fail_all_tasks(self, exception): + """ + Fail all outstanding tasks with the given exception. + + This method iterates through all outstanding tasks in the + `_future_tasks` dictionary, which have not yet completed, + and sets the provided exception as their result, indicating + a failure. + + Parameters: + - exception: The exception to be set as the result for all + outstanding tasks. + """ + for fut_task in self.future_tasks.values(): + if not fut_task.done(): + fut_task.set_exception(exception) + def start(self): """Create the Pilot component and pass it. """ @@ -202,63 +226,62 @@ def start(self): 'resource': self.resource} if not self.resource or 'local' in self.resource: - # move the agent sandbox to the working dir mainly - # for debugging purposes. This will allow parsl - # to include the agent sandbox with the ci artifacts. - if os.environ.get("LOCAL_SANDBOX"): - pd_init['sandbox'] = self.run_dir - os.environ["RADICAL_LOG_LVL"] = "DEBUG" - - logger.info("RPEX will be running in the local mode") + os.environ["RADICAL_LOG_LVL"] = "DEBUG" + logger.info("RPEX will be running in local mode") pd = rp.PilotDescription(pd_init) pd.verify() - self.rpex_cfg = self.rpex_cfg._get_cfg_file(path=self.run_dir) - cfg = ru.Config(cfg=ru.read_json(self.rpex_cfg)) + # start RP's main components TMGR, PMGR and Pilot + self.tmgr = rp.TaskManager(session=self.session) + self.pmgr = rp.PilotManager(session=self.session) + self.pilot = self.pmgr.submit_pilots(pd) - self.master = cfg.master_descr - self.n_masters = cfg.n_masters + if not self.pilot.description.get('cores') or not self.pilot.description.get('nodes'): + logger.warning('no "cores/nodes" per pilot were set, using default resources') + + self.tmgr.add_pilots(self.pilot) + self.tmgr.register_callback(self.task_state_cb) - tds = list() - master_path = '{0}/rpex_master.py'.format(PWD) worker_path = '{0}/rpex_worker.py'.format(PWD) - for i in range(self.n_masters): - td = rp.TaskDescription(self.master) - td.mode = rp.RAPTOR_MASTER - td.uid = ru.generate_id('master.%(item_counter)06d', ru.ID_CUSTOM, + self.masters = [] + + logger.info(f'Starting {self.rpex_cfg.n_masters} masters and {self.rpex_cfg.n_workers} workers for each master') + + # create N masters + for _ in range(self.rpex_cfg.n_masters): + md = rp.TaskDescription(self.rpex_cfg.master_descr) + md.uid = ru.generate_id('rpex.master.%(item_counter)06d', ru.ID_CUSTOM, ns=self.session.uid) - td.ranks = 1 - td.cores_per_rank = 1 - td.arguments = [self.rpex_cfg, i] - td.input_staging = self._stage_files([File(master_path), - File(worker_path), - File(self.rpex_cfg)], mode='in') - tds.append(td) - self.pmgr = rp.PilotManager(session=self.session) - self.tmgr = rp.TaskManager(session=self.session) + # submit the master to the TMGR + master = self.tmgr.submit_raptors(md)[0] + self.masters.append(master) - # submit pilot(s) - pilot = self.pmgr.submit_pilots(pd) - if not pilot.description.get('cores'): - logger.warning('no "cores" per pilot was set, using default resources 
{0}'.format(pilot.resources)) + workers = [] + # create N workers for each master and submit them to the TMGR + for _ in range(self.rpex_cfg.n_workers): + wd = rp.TaskDescription(self.rpex_cfg.worker_descr) + wd.uid = ru.generate_id('rpex.worker.%(item_counter)06d', ru.ID_CUSTOM, + ns=self.session.uid) + wd.raptor_id = master.uid + wd.input_staging = self._stage_files([File(worker_path)], mode='in') + workers.append(wd) - self.tmgr.submit_tasks(tds) + self.tmgr.submit_workers(workers) + + self.select_master = self._cyclic_master_selector() # prepare or use the current env for the agent/pilot side environment - if cfg.pilot_env_mode != 'client': - logger.info("creating {0} environment for the executor".format(cfg.pilot_env.name)) - pilot.prepare_env(env_name=cfg.pilot_env.name, - env_spec=cfg.pilot_env.as_dict()) + if self.rpex_cfg.pilot_env_mode != CLIENT: + logger.info("creating {0} environment for the executor".format(self.rpex_cfg.pilot_env.name)) + self.pilot.prepare_env(env_name=self.rpex_cfg.pilot_env.name, + env_spec=self.rpex_cfg.pilot_env.as_dict()) else: client_env = sys.prefix logger.info("reusing ({0}) environment for the executor".format(client_env)) - self.tmgr.add_pilots(pilot) - self.tmgr.register_callback(self.task_state_cb) - # create a bulking thread to run the actual task submission # to RP in bulks if self.bulk_mode: @@ -272,8 +295,21 @@ def start(self): self._bulk_thread.daemon = True self._bulk_thread.start() + logger.info('bulk mode is on, submitting tasks in bulks') + return True + def _cyclic_master_selector(self): + """ + Balance tasks submission across N masters and N workers + """ + current_master = 0 + masters_uids = [m.uid for m in self.masters] + + while True: + yield masters_uids[current_master] + current_master = (current_master + 1) % len(self.masters) + def unwrap(self, func, args): """ Unwrap a Parsl app and its args for further processing. @@ -364,22 +400,25 @@ def task_translate(self, tid, func, parsl_resource_specification, args, kwargs): # This is the default mode where the bash_app will be executed as # as a single core process by RP. For cores > 1 the user must use - # above or use MPI functions if their code is Python. + # task.mode=rp.TASK_EXECUTABLE (above) or use MPI functions if their + # code is Python. else: task.mode = rp.TASK_PROC - task.raptor_id = 'master.%06d' % (tid % self.n_masters) + task.raptor_id = next(self.select_master) task.executable = self._pack_and_apply_message(func, args, kwargs) elif PYTHON in task_type or not task_type: task.mode = rp.TASK_FUNCTION - task.raptor_id = 'master.%06d' % (tid % self.n_masters) + task.raptor_id = next(self.select_master) if kwargs.get('walltime'): func = timeout(func, kwargs['walltime']) - # we process MPI function differently - if 'comm' in kwargs: + # Check how to serialize the function object + if MPI in self.rpex_cfg.worker_type.lower(): + task.use_mpi = True task.function = rp.PythonTask(func, *args, **kwargs) else: + task.use_mpi = False task.function = self._pack_and_apply_message(func, args, kwargs) task.input_staging = self._stage_files(kwargs.get("inputs", []), @@ -394,7 +433,7 @@ def task_translate(self, tid, func, parsl_resource_specification, args, kwargs): try: task.verify() except ru.typeddict.TDKeyError as e: - raise Exception(f'{e}. Please check Radical.Pilot TaskDescription documentation') + raise Exception(f'{e}. 
Please check: https://radicalpilot.readthedocs.io/en/stable/ documentation') return task @@ -413,7 +452,11 @@ def _pack_and_apply_message(self, func, args, kwargs): def _unpack_and_set_parsl_exception(self, parsl_task, exception): try: - s = rp.utils.deserialize_bson(exception) + try: + s = rp.utils.deserialize_bson(exception) + except Exception: + s = exception + if isinstance(s, RemoteExceptionWrapper): try: s.reraise() @@ -421,6 +464,8 @@ def _unpack_and_set_parsl_exception(self, parsl_task, exception): parsl_task.set_exception(e) elif isinstance(s, Exception): parsl_task.set_exception(s) + elif isinstance(s, str): + parsl_task.set_exception(eval(s)) else: raise ValueError("Unknown exception-like type received: {}".format(type(s))) except Exception as e: @@ -440,16 +485,10 @@ def _set_stdout_stderr(self, task, kwargs): elif isinstance(k_val, PosixPath): k_val = k_val.__str__() - # if the stderr/out has no path - # then we consider it local and - # we just set the path to the cwd - if '/' not in k_val: - k_val = CWD + '/' + k_val - - # finally set the stderr/out to - # the desired name by the user + # set the stderr/out to the desired + # name by the user setattr(task, k, k_val) - task.sandbox = Path(k_val).parent.__str__() + task.sandbox = CWD def _stage_files(self, files, mode): """ @@ -477,7 +516,7 @@ def _stage_files(self, files, mode): # this indicates that the user # did not provided a specific # output file and RP will stage out - # the task.output from pilot://task_folder + # the task.stdout from pilot://task_folder # to the CWD or file.url if '/' not in file.url: f = {'source': file.filename, @@ -548,7 +587,8 @@ def submit(self, func, resource_specification, *args, **kwargs): def shutdown(self, hub=True, targets='all', block=False): """Shutdown the executor, including all RADICAL-Pilot components.""" - logger.info("RadicalPilotExecutor shutdown") + logger.info("RadicalPilotExecutor is terminating...") self.session.close(download=True) + logger.info("RadicalPilotExecutor is terminated.") return True diff --git a/parsl/executors/radical/rpex_master.py b/parsl/executors/radical/rpex_master.py deleted file mode 100755 index 6d3627e46f..0000000000 --- a/parsl/executors/radical/rpex_master.py +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env python3 - -import sys - -import radical.pilot as rp -import radical.utils as ru - -# ------------------------------------------------------------------------------ -# -if __name__ == '__main__': - - # The purpose of this master is to (a) spawn a set or workers - # within the same allocation, (b) to distribute work items to - # those workers, and (c) to collect the responses again. - cfg_fname = str(sys.argv[1]) - cfg = ru.Config(cfg=ru.read_json(cfg_fname)) - cfg.rank = int(sys.argv[2]) - - worker_descr = cfg.worker_descr - n_workers = cfg.n_workers - gpus_per_node = cfg.gpus_per_node - cores_per_node = cfg.cores_per_node - nodes_per_worker = cfg.nodes_per_worker - - # create a master class instance - this will establish communication - # to the pilot agent - master = rp.raptor.Master(cfg) - - # insert `n` worker into the agent. The agent will schedule (place) - # those workers and execute them. 
- worker_descr['ranks'] = nodes_per_worker * cores_per_node - worker_descr['gpus_per_rank'] = nodes_per_worker * gpus_per_node - worker_ids = master.submit_workers( - [rp.TaskDescription(worker_descr) for _ in range(n_workers)]) - - # wait for all workers - master.wait_workers() - master.start() - master.join() - -# ------------------------------------------------------------------------------ diff --git a/parsl/executors/radical/rpex_resources.py b/parsl/executors/radical/rpex_resources.py index f4daf9aa19..c337ee33b1 100644 --- a/parsl/executors/radical/rpex_resources.py +++ b/parsl/executors/radical/rpex_resources.py @@ -5,6 +5,7 @@ _setup_paths: List[str] = [] try: import radical.pilot as rp + import radical.utils as ru except ImportError: pass @@ -103,7 +104,7 @@ class ResourceConfig: python_v: str = f'{sys.version_info[0]}.{sys.version_info[1]}' worker_type: str = DEFAULT_WORKER - def _get_cfg_file(cls, path=None): + def get_config(cls, path=None): # Default ENV mode for RP is to reuse # the client side. If this is not the case, @@ -121,6 +122,7 @@ def _get_cfg_file(cls, path=None): cfg = { 'n_masters': cls.masters, 'n_workers': cls.workers, + 'worker_type': cls.worker_type, 'gpus_per_node': cls.worker_gpus_per_node, 'cores_per_node': cls.worker_cores_per_node, 'cores_per_master': cls.cores_per_master, @@ -138,9 +140,10 @@ def _get_cfg_file(cls, path=None): 'pilot_env_mode': cls.pilot_env_mode, 'master_descr': { + "ranks": 1, + "cores_per_rank": 1, "mode": rp.RAPTOR_MASTER, "named_env": cls.pilot_env_name, - "executable": "python3 rpex_master.py", }, 'worker_descr': { @@ -149,12 +152,16 @@ def _get_cfg_file(cls, path=None): "raptor_file": "./rpex_worker.py", "raptor_class": cls.worker_type if cls.worker_type.lower() != MPI else MPI_WORKER, + "ranks": cls.nodes_per_worker * cls.worker_cores_per_node, + "gpus_per_rank": cls.nodes_per_worker * cls.worker_gpus_per_node, }} - # Convert the class instance to a cfg file. - config_path = 'rpex.cfg' + # Convert the class instance to a Json file or a Config dict. 
if path: + config_path = 'rpex.cfg' config_path = path + '/' + config_path - with open(config_path, 'w') as f: - json.dump(cfg, f, indent=4) - return config_path + with open(config_path, 'w') as f: + json.dump(cfg, f, indent=4) + else: + config_obj = ru.Config(from_dict=cfg) + return config_obj diff --git a/parsl/executors/status_handling.py b/parsl/executors/status_handling.py index 4d29439670..34db2300f6 100644 --- a/parsl/executors/status_handling.py +++ b/parsl/executors/status_handling.py @@ -12,7 +12,7 @@ from parsl.executors.base import ParslExecutor from parsl.executors.errors import BadStateException, ScalingFailed from parsl.jobs.error_handlers import noop_error_handler, simple_error_handler -from parsl.jobs.states import JobState, JobStatus +from parsl.jobs.states import TERMINAL_STATES, JobState, JobStatus from parsl.monitoring.message_type import MessageType from parsl.providers.base import ExecutionProvider from parsl.utils import AtomicIDCounter @@ -59,20 +59,28 @@ def __init__(self, *, else: self.block_error_handler = block_error_handler - # errors can happen during the submit call to the provider; this is used - # to keep track of such errors so that they can be handled in one place - # together with errors reported by status() - self._simulated_status: Dict[str, JobStatus] = {} self._executor_bad_state = threading.Event() self._executor_exception: Optional[Exception] = None self._block_id_counter = AtomicIDCounter() self._tasks = {} # type: Dict[object, Future] + + self._last_poll_time = 0.0 + + # these four structures track, in loosely coordinated fashion, the + # existence of blocks and jobs and how to map between their + # identifiers. self.blocks_to_job_id = {} # type: Dict[str, str] self.job_ids_to_block = {} # type: Dict[str, str] - self._last_poll_time = 0.0 + # errors can happen during the submit call to the provider; this is used + # to keep track of such errors so that they can be handled in one place + # together with errors reported by status() + self._simulated_status: Dict[str, JobStatus] = {} + + # this stores an approximation (sometimes delayed) of the latest status + # of pending, active and recently terminated blocks self._status = {} # type: Dict[str, JobStatus] def _make_status_dict(self, block_ids: List[str], status_list: List[JobStatus]) -> Dict[str, JobStatus]: @@ -113,20 +121,6 @@ def outstanding(self) -> int: raise NotImplementedError("Classes inheriting from BlockProviderExecutor must implement " "outstanding()") - def status(self) -> Dict[str, JobStatus]: - """Return the status of all jobs/blocks currently known to this executor. - - :return: a dictionary mapping block ids (in string) to job status - """ - if self._provider: - block_ids, job_ids = self._get_block_and_job_ids() - status = self._make_status_dict(block_ids, self._provider.status(job_ids)) - else: - status = {} - status.update(self._simulated_status) - - return status - def set_bad_state_and_fail_all(self, exception: Exception): """Allows external error handlers to mark this executor as irrecoverably bad and cause all tasks submitted to it now and in the future to fail. 
The executor is responsible @@ -173,41 +167,82 @@ def tasks(self) -> Dict[object, Future]: def provider(self): return self._provider - def _filter_scale_in_ids(self, to_kill, killed): + def _filter_scale_in_ids(self, to_kill: Sequence[Any], killed: Sequence[bool]) -> Sequence[Any]: """ Filter out job id's that were not killed """ assert len(to_kill) == len(killed) + + if False in killed: + killed_job_ids = [jid for jid, k in zip(to_kill, killed) if k] + not_killed_job_ids = [jid for jid, k in zip(to_kill, killed) if not k] + logger.warning("Some jobs were not killed successfully: " + f"killed jobs: {killed_job_ids}, " + f"not-killed jobs: {not_killed_job_ids}") + # Filters first iterable by bool values in second return list(compress(to_kill, killed)) - def scale_out(self, blocks: int = 1) -> List[str]: + def scale_out_facade(self, n: int) -> List[str]: """Scales out the number of blocks by "blocks" """ if not self.provider: raise ScalingFailed(self, "No execution provider available") block_ids = [] - logger.info(f"Scaling out by {blocks} blocks") - for _ in range(blocks): + monitoring_status_changes = {} + logger.info(f"Scaling out by {n} blocks") + for _ in range(n): block_id = str(self._block_id_counter.get_id()) logger.info(f"Allocated block ID {block_id}") try: job_id = self._launch_block(block_id) + + pending_status = JobStatus(JobState.PENDING) + self.blocks_to_job_id[block_id] = job_id self.job_ids_to_block[job_id] = block_id + self._status[block_id] = pending_status + + monitoring_status_changes[block_id] = pending_status block_ids.append(block_id) + except Exception as ex: - self._simulated_status[block_id] = JobStatus(JobState.FAILED, "Failed to start block {}: {}".format(block_id, ex)) + failed_status = JobStatus(JobState.FAILED, "Failed to start block {}: {}".format(block_id, ex)) + self._simulated_status[block_id] = failed_status + self._status[block_id] = failed_status + + self.send_monitoring_info(monitoring_status_changes) return block_ids - @abstractmethod def scale_in(self, blocks: int) -> List[str]: """Scale in method. Cause the executor to reduce the number of blocks by count. + The default implementation will kill blocks without regard to their + status or whether they are executing tasks. Executors with more + nuanced scaling strategies might overload this method to work with + that strategy - see the HighThroughputExecutor for an example of that. + :return: A list of block ids corresponding to the blocks that were removed. 
""" - pass + + active_blocks = [block_id for block_id, status in self._status.items() + if status.state not in TERMINAL_STATES] + + block_ids_to_kill = active_blocks[:blocks] + + job_ids_to_kill = [self.blocks_to_job_id[block] for block in block_ids_to_kill] + + # Cancel the blocks provisioned + if self.provider: + logger.info(f"Scaling in jobs: {job_ids_to_kill}") + r = self.provider.cancel(job_ids_to_kill) + job_ids = self._filter_scale_in_ids(job_ids_to_kill, r) + block_ids_killed = [self.job_ids_to_block[job_id] for job_id in job_ids] + return block_ids_killed + else: + logger.error("No execution provider available to scale in") + return [] def _launch_block(self, block_id: str) -> Any: launch_cmd = self._get_launch_command(block_id) @@ -241,10 +276,10 @@ def workers_per_node(self) -> Union[int, float]: def send_monitoring_info(self, status: Dict) -> None: # Send monitoring info for HTEX when monitoring enabled - if self.monitoring_radio: + if self.submit_monitoring_radio: msg = self.create_monitoring_info(status) - logger.debug("Sending message {} to hub from job status poller".format(msg)) - self.monitoring_radio.send((MessageType.BLOCK_INFO, msg)) + logger.debug("Sending block monitoring message: %r", msg) + self.submit_monitoring_radio.send((MessageType.BLOCK_INFO, msg)) def create_monitoring_info(self, status: Dict[str, JobStatus]) -> Sequence[object]: """Create a monitoring message for each block based on the poll status. @@ -276,6 +311,20 @@ def poll_facade(self) -> None: if delta_status: self.send_monitoring_info(delta_status) + def status(self) -> Dict[str, JobStatus]: + """Return the status of all jobs/blocks currently known to this executor. + + :return: a dictionary mapping block ids (in string) to job status + """ + if self._provider: + block_ids, job_ids = self._get_block_and_job_ids() + status = self._make_status_dict(block_ids, self._provider.status(job_ids)) + else: + status = {} + status.update(self._simulated_status) + + return status + @property def status_facade(self) -> Dict[str, JobStatus]: """Return the status of all jobs/blocks of the executor of this poller. @@ -302,13 +351,3 @@ def scale_in_facade(self, n: int, max_idletime: Optional[float] = None) -> List[ del self._status[block_id] self.send_monitoring_info(new_status) return block_ids - - def scale_out_facade(self, n: int) -> List[str]: - block_ids = self.scale_out(n) - if block_ids is not None: - new_status = {} - for block_id in block_ids: - new_status[block_id] = JobStatus(JobState.PENDING) - self.send_monitoring_info(new_status) - self._status.update(new_status) - return block_ids diff --git a/parsl/executors/taskvine/executor.py b/parsl/executors/taskvine/executor.py index 6cfedf92bb..2e1efb211f 100644 --- a/parsl/executors/taskvine/executor.py +++ b/parsl/executors/taskvine/executor.py @@ -573,24 +573,6 @@ def outstanding(self) -> int: def workers_per_node(self) -> Union[int, float]: return 1 - def scale_in(self, count: int) -> List[str]: - """Scale in method. 
Cancel a given number of blocks - """ - # Obtain list of blocks to kill - to_kill = list(self.blocks_to_job_id.keys())[:count] - kill_ids = [self.blocks_to_job_id[block] for block in to_kill] - - # Cancel the blocks provisioned - if self.provider: - logger.info(f"Scaling in jobs: {kill_ids}") - r = self.provider.cancel(kill_ids) - job_ids = self._filter_scale_in_ids(kill_ids, r) - block_ids_killed = [self.job_ids_to_block[jid] for jid in job_ids] - return block_ids_killed - else: - logger.error("No execution provider available to scale") - return [] - def shutdown(self, *args, **kwargs): """Shutdown the executor. Sets flag to cancel the submit process and collector thread, which shuts down the TaskVine system submission. @@ -607,11 +589,13 @@ def shutdown(self, *args, **kwargs): # Join all processes before exiting logger.debug("Joining on submit process") self._submit_process.join() + self._submit_process.close() logger.debug("Joining on collector thread") self._collector_thread.join() if self.worker_launch_method == 'factory': logger.debug("Joining on factory process") self._factory_process.join() + self._factory_process.close() # Shutdown multiprocessing queues self._ready_task_queue.close() diff --git a/parsl/executors/workqueue/executor.py b/parsl/executors/workqueue/executor.py index 0b931bbc31..ae39f8c118 100644 --- a/parsl/executors/workqueue/executor.py +++ b/parsl/executors/workqueue/executor.py @@ -215,6 +215,13 @@ class WorkQueueExecutor(BlockProviderExecutor, putils.RepresentationMixin): This requires a version of Work Queue / cctools after commit 874df524516441da531b694afc9d591e8b134b73 (release 7.5.0 is too early). Default is False. + + scaling_cores_per_worker: int + When using Parsl scaling, this specifies the number of cores that a + worker is expected to have available for computation. Default 1. This + parameter can be ignored when using a fixed number of blocks, or when + using one task per worker (by omitting a ``cores`` resource + specification for each task). """ radio_mode = "filesystem" @@ -244,12 +251,14 @@ def __init__(self, full_debug: bool = True, worker_executable: str = 'work_queue_worker', function_dir: Optional[str] = None, - coprocess: bool = False): + coprocess: bool = False, + scaling_cores_per_worker: int = 1): BlockProviderExecutor.__init__(self, provider=provider, block_error_handler=True) if not _work_queue_enabled: raise OptionalModuleMissing(['work_queue'], "WorkQueueExecutor requires the work_queue module.") + self.scaling_cores_per_worker = scaling_cores_per_worker self.label = label self.task_queue = multiprocessing.Queue() # type: multiprocessing.Queue self.collector_queue = multiprocessing.Queue() # type: multiprocessing.Queue @@ -469,6 +478,8 @@ def submit(self, func, resource_specification, *args, **kwargs): # Create a Future object and have it be mapped from the task ID in the tasks dictionary fu = Future() fu.parsl_executor_task_id = executor_task_id + assert isinstance(resource_specification, dict) + fu.resource_specification = resource_specification logger.debug("Getting tasks_lock to set WQ-level task entry") with self.tasks_lock: logger.debug("Got tasks_lock to set WQ-level task entry") @@ -654,38 +665,29 @@ def initialize_scaling(self): @property def outstanding(self) -> int: - """Count the number of outstanding tasks. This is inefficiently + """Count the number of outstanding slots required. This is inefficiently implemented and probably could be replaced with a counter. 
""" + logger.debug("Calculating outstanding task slot load") outstanding = 0 + tasks = 0 # only for log message... with self.tasks_lock: for fut in self.tasks.values(): if not fut.done(): - outstanding += 1 - logger.debug(f"Counted {outstanding} outstanding tasks") + # if a task does not specify a core count, Work Queue will allocate an entire + # worker node to that task. That's approximated here by saying that it uses + # scaling_cores_per_worker. + resource_spec = getattr(fut, 'resource_specification', {}) + cores = resource_spec.get('cores', self.scaling_cores_per_worker) + + outstanding += cores + tasks += 1 + logger.debug(f"Counted {tasks} outstanding tasks with {outstanding} outstanding slots") return outstanding @property def workers_per_node(self) -> Union[int, float]: - return 1 - - def scale_in(self, count: int) -> List[str]: - """Scale in method. - """ - # Obtain list of blocks to kill - to_kill = list(self.blocks_to_job_id.keys())[:count] - kill_ids = [self.blocks_to_job_id[block] for block in to_kill] - - # Cancel the blocks provisioned - if self.provider: - logger.info(f"Scaling in jobs: {kill_ids}") - r = self.provider.cancel(kill_ids) - job_ids = self._filter_scale_in_ids(kill_ids, r) - block_ids_killed = [self.job_ids_to_block[jid] for jid in job_ids] - return block_ids_killed - else: - logger.error("No execution provider available to scale in") - return [] + return self.scaling_cores_per_worker def shutdown(self, *args, **kwargs): """Shutdown the executor. Sets flag to cancel the submit process and @@ -702,6 +704,8 @@ def shutdown(self, *args, **kwargs): logger.debug("Joining on submit process") self.submit_process.join() + self.submit_process.close() + logger.debug("Joining on collector thread") self.collector_thread.join() diff --git a/parsl/monitoring/db_manager.py b/parsl/monitoring/db_manager.py index 9f19cd9f4d..053c98d598 100644 --- a/parsl/monitoring/db_manager.py +++ b/parsl/monitoring/db_manager.py @@ -1,11 +1,14 @@ import datetime import logging +import multiprocessing.queues as mpq import os import queue import threading import time from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, cast +import typeguard + from parsl.dataflow.states import States from parsl.errors import OptionalModuleMissing from parsl.log_utils import set_file_logger @@ -283,7 +286,7 @@ def __init__(self, ): self.workflow_end = False - self.workflow_start_message = None # type: Optional[MonitoringMessage] + self.workflow_start_message: Optional[MonitoringMessage] = None self.logdir = logdir os.makedirs(self.logdir, exist_ok=True) @@ -299,21 +302,21 @@ def __init__(self, self.batching_interval = batching_interval self.batching_threshold = batching_threshold - self.pending_priority_queue = queue.Queue() # type: queue.Queue[TaggedMonitoringMessage] - self.pending_node_queue = queue.Queue() # type: queue.Queue[MonitoringMessage] - self.pending_block_queue = queue.Queue() # type: queue.Queue[MonitoringMessage] - self.pending_resource_queue = queue.Queue() # type: queue.Queue[MonitoringMessage] + self.pending_priority_queue: queue.Queue[TaggedMonitoringMessage] = queue.Queue() + self.pending_node_queue: queue.Queue[MonitoringMessage] = queue.Queue() + self.pending_block_queue: queue.Queue[MonitoringMessage] = queue.Queue() + self.pending_resource_queue: queue.Queue[MonitoringMessage] = queue.Queue() def start(self, - priority_queue: "queue.Queue[TaggedMonitoringMessage]", - node_queue: "queue.Queue[MonitoringMessage]", - block_queue: "queue.Queue[MonitoringMessage]", - 
resource_queue: "queue.Queue[MonitoringMessage]") -> None: + priority_queue: mpq.Queue, + node_queue: mpq.Queue, + block_queue: mpq.Queue, + resource_queue: mpq.Queue) -> None: self._kill_event = threading.Event() self._priority_queue_pull_thread = threading.Thread(target=self._migrate_logs_to_internal, args=( - priority_queue, 'priority', self._kill_event,), + priority_queue, self._kill_event,), name="Monitoring-migrate-priority", daemon=True, ) @@ -321,7 +324,7 @@ def start(self, self._node_queue_pull_thread = threading.Thread(target=self._migrate_logs_to_internal, args=( - node_queue, 'node', self._kill_event,), + node_queue, self._kill_event,), name="Monitoring-migrate-node", daemon=True, ) @@ -329,7 +332,7 @@ def start(self, self._block_queue_pull_thread = threading.Thread(target=self._migrate_logs_to_internal, args=( - block_queue, 'block', self._kill_event,), + block_queue, self._kill_event,), name="Monitoring-migrate-block", daemon=True, ) @@ -337,7 +340,7 @@ def start(self, self._resource_queue_pull_thread = threading.Thread(target=self._migrate_logs_to_internal, args=( - resource_queue, 'resource', self._kill_event,), + resource_queue, self._kill_event,), name="Monitoring-migrate-resource", daemon=True, ) @@ -351,18 +354,18 @@ def start(self, If that happens, the message will be added to deferred_resource_messages and processed later. """ - inserted_tasks = set() # type: Set[object] + inserted_tasks: Set[object] = set() """ like inserted_tasks but for task,try tuples """ - inserted_tries = set() # type: Set[Any] + inserted_tries: Set[Any] = set() # for any task ID, we can defer exactly one message, which is the # assumed-to-be-unique first message (with first message flag set). # The code prior to this patch will discard previous message in # the case of multiple messages to defer. 
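The per-queue pull threads set up above all follow the same drain-until-stopped pattern, sharing one kill event. A simplified standalone sketch of that pattern (not the actual DatabaseManager code, which dispatches into its internal queues rather than printing):

    import queue
    import threading

    def drain(q: "queue.Queue[str]", kill_event: threading.Event) -> None:
        # Keep pulling until a stop has been requested AND the queue is empty,
        # mirroring the loop condition used by _migrate_logs_to_internal.
        while not kill_event.is_set() or q.qsize() != 0:
            try:
                item = q.get(timeout=0.1)
            except queue.Empty:
                continue
            print("dispatching", item)  # stand-in for _dispatch_to_internal

    kill_event = threading.Event()
    q: "queue.Queue[str]" = queue.Queue()
    thread = threading.Thread(target=drain, args=(q, kill_event),
                              name="drain-example", daemon=True)
    thread.start()
    q.put("example message")
    kill_event.set()
    thread.join()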
- deferred_resource_messages = {} # type: MonitoringMessage + deferred_resource_messages: MonitoringMessage = {} exception_happened = False @@ -505,7 +508,7 @@ def start(self, "Got {} messages from block queue".format(len(block_info_messages))) # block_info_messages is possibly a nested list of dict (at different polling times) # Each dict refers to the info of a job/block at one polling time - block_messages_to_insert = [] # type: List[Any] + block_messages_to_insert: List[Any] = [] for block_msg in block_info_messages: block_messages_to_insert.extend(block_msg) self._insert(table=BLOCK, messages=block_messages_to_insert) @@ -574,43 +577,26 @@ def start(self, raise RuntimeError("An exception happened sometime during database processing and should have been logged in database_manager.log") @wrap_with_logs(target="database_manager") - def _migrate_logs_to_internal(self, logs_queue: queue.Queue, queue_tag: str, kill_event: threading.Event) -> None: - logger.info("Starting processing for queue {}".format(queue_tag)) + def _migrate_logs_to_internal(self, logs_queue: queue.Queue, kill_event: threading.Event) -> None: + logger.info("Starting _migrate_logs_to_internal") while not kill_event.is_set() or logs_queue.qsize() != 0: - logger.debug("""Checking STOP conditions for {} threads: {}, {}""" - .format(queue_tag, kill_event.is_set(), logs_queue.qsize() != 0)) + logger.debug("Checking STOP conditions: kill event: %s, queue has entries: %s", + kill_event.is_set(), logs_queue.qsize() != 0) try: x, addr = logs_queue.get(timeout=0.1) except queue.Empty: continue else: - if queue_tag == 'priority' and x == 'STOP': + if x == 'STOP': self.close() - elif queue_tag == 'priority': # implicitly not 'STOP' - assert isinstance(x, tuple) - assert len(x) == 2 - assert x[0] in [MessageType.WORKFLOW_INFO, MessageType.TASK_INFO], \ - "_migrate_logs_to_internal can only migrate WORKFLOW_,TASK_INFO message from priority queue, got x[0] == {}".format(x[0]) - self._dispatch_to_internal(x) - elif queue_tag == 'resource': - assert isinstance(x, tuple), "_migrate_logs_to_internal was expecting a tuple, got {}".format(x) - assert x[0] == MessageType.RESOURCE_INFO, ( - "_migrate_logs_to_internal can only migrate RESOURCE_INFO message from resource queue, " - "got tag {}, message {}".format(x[0], x) - ) - self._dispatch_to_internal(x) - elif queue_tag == 'node': - assert len(x) == 2, "expected message tuple to have exactly two elements" - assert x[0] == MessageType.NODE_INFO, "_migrate_logs_to_internal can only migrate NODE_INFO messages from node queue" - - self._dispatch_to_internal(x) - elif queue_tag == "block": - self._dispatch_to_internal(x) else: - logger.error(f"Discarding because unknown queue tag '{queue_tag}', message: {x}") + self._dispatch_to_internal(x) def _dispatch_to_internal(self, x: Tuple) -> None: + assert isinstance(x, tuple) + assert len(x) == 2, "expected message tuple to have exactly two elements" + if x[0] in [MessageType.WORKFLOW_INFO, MessageType.TASK_INFO]: self.pending_priority_queue.put(cast(Any, x)) elif x[0] == MessageType.RESOURCE_INFO: @@ -686,7 +672,7 @@ def _insert(self, table: str, messages: List[MonitoringMessage]) -> None: logger.exception("Rollback failed") def _get_messages_in_batch(self, msg_queue: "queue.Queue[X]") -> List[X]: - messages = [] # type: List[X] + messages: List[X] = [] start = time.time() while True: if time.time() - start >= self.batching_interval or len(messages) >= self.batching_threshold: @@ -719,11 +705,12 @@ def close(self) -> None: 
@wrap_with_logs(target="database_manager") -def dbm_starter(exception_q: "queue.Queue[Tuple[str, str]]", - priority_msgs: "queue.Queue[TaggedMonitoringMessage]", - node_msgs: "queue.Queue[MonitoringMessage]", - block_msgs: "queue.Queue[MonitoringMessage]", - resource_msgs: "queue.Queue[MonitoringMessage]", +@typeguard.typechecked +def dbm_starter(exception_q: mpq.Queue, + priority_msgs: mpq.Queue, + node_msgs: mpq.Queue, + block_msgs: mpq.Queue, + resource_msgs: mpq.Queue, db_url: str, logdir: str, logging_level: int) -> None: diff --git a/parsl/monitoring/errors.py b/parsl/monitoring/errors.py new file mode 100644 index 0000000000..f41225ff44 --- /dev/null +++ b/parsl/monitoring/errors.py @@ -0,0 +1,6 @@ +from parsl.errors import ParslError + + +class MonitoringHubStartError(ParslError): + def __str__(self) -> str: + return "Hub failed to start" diff --git a/parsl/monitoring/monitoring.py b/parsl/monitoring/monitoring.py index 8e4770a32a..a76e2cf487 100644 --- a/parsl/monitoring/monitoring.py +++ b/parsl/monitoring/monitoring.py @@ -12,8 +12,9 @@ import typeguard from parsl.log_utils import set_file_logger +from parsl.monitoring.errors import MonitoringHubStartError from parsl.monitoring.message_type import MessageType -from parsl.monitoring.radios import MultiprocessingQueueRadio +from parsl.monitoring.radios import MultiprocessingQueueRadioSender from parsl.monitoring.router import router_starter from parsl.monitoring.types import AddressedMonitoringMessage from parsl.multiprocessing import ForkProcess, SizedQueue @@ -105,7 +106,7 @@ def __init__(self, self.resource_monitoring_enabled = resource_monitoring_enabled self.resource_monitoring_interval = resource_monitoring_interval - def start(self, run_id: str, dfk_run_dir: str, config_run_dir: Union[str, os.PathLike]) -> int: + def start(self, dfk_run_dir: str, config_run_dir: Union[str, os.PathLike]) -> None: logger.debug("Starting MonitoringHub") @@ -153,14 +154,18 @@ def start(self, run_id: str, dfk_run_dir: str, config_run_dir: Union[str, os.Pat self.router_exit_event = Event() self.router_proc = ForkProcess(target=router_starter, - args=(comm_q, self.exception_q, self.priority_msgs, self.node_msgs, - self.block_msgs, self.resource_msgs, self.router_exit_event), - kwargs={"hub_address": self.hub_address, + kwargs={"comm_q": comm_q, + "exception_q": self.exception_q, + "priority_msgs": self.priority_msgs, + "node_msgs": self.node_msgs, + "block_msgs": self.block_msgs, + "resource_msgs": self.resource_msgs, + "exit_event": self.router_exit_event, + "hub_address": self.hub_address, "udp_port": self.hub_port, "zmq_port_range": self.hub_port_range, "logdir": self.logdir, "logging_level": logging.DEBUG if self.monitoring_debug else logging.INFO, - "run_id": run_id }, name="Monitoring-Router-Process", daemon=True, @@ -187,7 +192,7 @@ def start(self, run_id: str, dfk_run_dir: str, config_run_dir: Union[str, os.Pat self.filesystem_proc.start() logger.info(f"Started filesystem radio receiver process {self.filesystem_proc.pid}") - self.radio = MultiprocessingQueueRadio(self.block_msgs) + self.radio = MultiprocessingQueueRadioSender(self.block_msgs) try: comm_q_result = comm_q.get(block=True, timeout=120) @@ -195,7 +200,7 @@ def start(self, run_id: str, dfk_run_dir: str, config_run_dir: Union[str, os.Pat comm_q.join_thread() except queue.Empty: logger.error("Hub has not completed initialization in 120s. 
Aborting") - raise Exception("Hub failed to start") + raise MonitoringHubStartError() if isinstance(comm_q_result, str): logger.error(f"MonitoringRouter sent an error message: {comm_q_result}") @@ -207,7 +212,7 @@ def start(self, run_id: str, dfk_run_dir: str, config_run_dir: Union[str, os.Pat logger.info("Monitoring Hub initialized") - return zmq_port + self.hub_zmq_port = zmq_port # TODO: tighten the Any message format def send(self, mtype: MessageType, message: Any) -> None: diff --git a/parsl/monitoring/radios.py b/parsl/monitoring/radios.py index 070869bdba..37bef0b06a 100644 --- a/parsl/monitoring/radios.py +++ b/parsl/monitoring/radios.py @@ -7,6 +7,8 @@ from multiprocessing.queues import Queue from typing import Optional +import zmq + from parsl.serialize import serialize _db_manager_excepts: Optional[Exception] @@ -15,14 +17,14 @@ logger = logging.getLogger(__name__) -class MonitoringRadio(metaclass=ABCMeta): +class MonitoringRadioSender(metaclass=ABCMeta): @abstractmethod def send(self, message: object) -> None: pass -class FilesystemRadio(MonitoringRadio): - """A MonitoringRadio that sends messages over a shared filesystem. +class FilesystemRadioSender(MonitoringRadioSender): + """A MonitoringRadioSender that sends messages over a shared filesystem. The messsage directory structure is based on maildir, https://en.wikipedia.org/wiki/Maildir @@ -36,7 +38,7 @@ class FilesystemRadio(MonitoringRadio): This avoids a race condition of reading partially written messages. This radio is likely to give higher shared filesystem load compared to - the UDPRadio, but should be much more reliable. + the UDP radio, but should be much more reliable. """ def __init__(self, *, monitoring_url: str, source_id: int, timeout: int = 10, run_dir: str): @@ -66,7 +68,7 @@ def send(self, message: object) -> None: os.rename(tmp_filename, new_filename) -class HTEXRadio(MonitoringRadio): +class HTEXRadioSender(MonitoringRadioSender): def __init__(self, monitoring_url: str, source_id: int, timeout: int = 10): """ @@ -120,7 +122,7 @@ def send(self, message: object) -> None: return -class UDPRadio(MonitoringRadio): +class UDPRadioSender(MonitoringRadioSender): def __init__(self, monitoring_url: str, source_id: int, timeout: int = 10): """ @@ -174,7 +176,7 @@ def send(self, message: object) -> None: return -class MultiprocessingQueueRadio(MonitoringRadio): +class MultiprocessingQueueRadioSender(MonitoringRadioSender): """A monitoring radio which connects over a multiprocessing Queue. This radio is intended to be used on the submit side, where components in the submit process, or processes launched by multiprocessing, will have @@ -186,3 +188,17 @@ def __init__(self, queue: Queue) -> None: def send(self, message: object) -> None: self.queue.put((message, 0)) + + +class ZMQRadioSender(MonitoringRadioSender): + """A monitoring radio which connects over ZMQ. This radio is not + thread-safe, because its use of ZMQ is not thread-safe. 
+ """ + + def __init__(self, hub_address: str, hub_zmq_port: int) -> None: + self._hub_channel = zmq.Context().socket(zmq.DEALER) + self._hub_channel.set_hwm(0) + self._hub_channel.connect(f"tcp://{hub_address}:{hub_zmq_port}") + + def send(self, message: object) -> None: + self._hub_channel.send_pyobj(message) diff --git a/parsl/monitoring/remote.py b/parsl/monitoring/remote.py index 98168aa858..d374338dee 100644 --- a/parsl/monitoring/remote.py +++ b/parsl/monitoring/remote.py @@ -8,10 +8,10 @@ from parsl.monitoring.message_type import MessageType from parsl.monitoring.radios import ( - FilesystemRadio, - HTEXRadio, - MonitoringRadio, - UDPRadio, + FilesystemRadioSender, + HTEXRadioSender, + MonitoringRadioSender, + UDPRadioSender, ) from parsl.multiprocessing import ForkProcess from parsl.process_loggers import wrap_with_logs @@ -100,17 +100,17 @@ def wrapped(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: return (wrapped, args, new_kwargs) -def get_radio(radio_mode: str, monitoring_hub_url: str, task_id: int, run_dir: str) -> MonitoringRadio: - radio: MonitoringRadio +def get_radio(radio_mode: str, monitoring_hub_url: str, task_id: int, run_dir: str) -> MonitoringRadioSender: + radio: MonitoringRadioSender if radio_mode == "udp": - radio = UDPRadio(monitoring_hub_url, - source_id=task_id) + radio = UDPRadioSender(monitoring_hub_url, + source_id=task_id) elif radio_mode == "htex": - radio = HTEXRadio(monitoring_hub_url, - source_id=task_id) + radio = HTEXRadioSender(monitoring_hub_url, + source_id=task_id) elif radio_mode == "filesystem": - radio = FilesystemRadio(monitoring_url=monitoring_hub_url, - source_id=task_id, run_dir=run_dir) + radio = FilesystemRadioSender(monitoring_url=monitoring_hub_url, + source_id=task_id, run_dir=run_dir) else: raise RuntimeError(f"Unknown radio mode: {radio_mode}") return radio @@ -199,10 +199,10 @@ def monitor(pid: int, pm = psutil.Process(pid) - children_user_time = {} # type: Dict[int, float] - children_system_time = {} # type: Dict[int, float] - children_num_ctx_switches_voluntary = {} # type: Dict[int, float] - children_num_ctx_switches_involuntary = {} # type: Dict[int, float] + children_user_time: Dict[int, float] = {} + children_system_time: Dict[int, float] = {} + children_num_ctx_switches_voluntary: Dict[int, float] = {} + children_num_ctx_switches_involuntary: Dict[int, float] = {} def accumulate_and_prepare() -> Dict[str, Any]: d = {"psutil_process_" + str(k): v for k, v in pm.as_dict().items() if k in simple} diff --git a/parsl/monitoring/router.py b/parsl/monitoring/router.py index 70b4862295..343410e3a4 100644 --- a/parsl/monitoring/router.py +++ b/parsl/monitoring/router.py @@ -1,14 +1,16 @@ from __future__ import annotations import logging +import multiprocessing.queues as mpq import os import pickle -import queue import socket +import threading import time from multiprocessing.synchronize import Event -from typing import Optional, Tuple, Union +from typing import Optional, Tuple +import typeguard import zmq from parsl.log_utils import set_file_logger @@ -30,9 +32,13 @@ def __init__(self, monitoring_hub_address: str = "127.0.0.1", logdir: str = ".", - run_id: str, logging_level: int = logging.INFO, - atexit_timeout: int = 3 # in seconds + atexit_timeout: int = 3, # in seconds + priority_msgs: mpq.Queue, + node_msgs: mpq.Queue, + block_msgs: mpq.Queue, + resource_msgs: mpq.Queue, + exit_event: Event, ): """ Initializes a monitoring configuration class. 
@@ -51,7 +57,11 @@ def __init__(self, Logging level as defined in the logging module. Default: logging.INFO atexit_timeout : float, optional The amount of time in seconds to terminate the hub without receiving any messages, after the last dfk workflow message is received. + *_msgs : Queue + Four multiprocessing queues to receive messages, routed by type tag, and sometimes modified according to type tag. + exit_event : Event + An event that the main Parsl process will set to signal that the monitoring router should shut down. """ os.makedirs(logdir, exist_ok=True) self.logger = set_file_logger("{}/monitoring_router.log".format(logdir), @@ -61,7 +71,6 @@ def __init__(self, self.hub_address = hub_address self.atexit_timeout = atexit_timeout - self.run_id = run_id self.loop_freq = 10.0 # milliseconds @@ -93,22 +102,60 @@ def __init__(self, min_port=zmq_port_range[0], max_port=zmq_port_range[1]) - def start(self, - priority_msgs: "queue.Queue[AddressedMonitoringMessage]", - node_msgs: "queue.Queue[AddressedMonitoringMessage]", - block_msgs: "queue.Queue[AddressedMonitoringMessage]", - resource_msgs: "queue.Queue[AddressedMonitoringMessage]", - exit_event: Event) -> None: + self.priority_msgs = priority_msgs + self.node_msgs = node_msgs + self.block_msgs = block_msgs + self.resource_msgs = resource_msgs + self.exit_event = exit_event + + @wrap_with_logs(target="monitoring_router") + def start(self) -> None: + self.logger.info("Starting UDP listener thread") + udp_radio_receiver_thread = threading.Thread(target=self.start_udp_listener, daemon=True) + udp_radio_receiver_thread.start() + + self.logger.info("Starting ZMQ listener thread") + zmq_radio_receiver_thread = threading.Thread(target=self.start_zmq_listener, daemon=True) + zmq_radio_receiver_thread.start() + + self.logger.info("Joining on ZMQ listener thread") + zmq_radio_receiver_thread.join() + self.logger.info("Joining on UDP listener thread") + udp_radio_receiver_thread.join() + self.logger.info("Joined on both ZMQ and UDP listener threads") + + @wrap_with_logs(target="monitoring_router") + def start_udp_listener(self) -> None: try: - while not exit_event.is_set(): + while not self.exit_event.is_set(): try: data, addr = self.udp_sock.recvfrom(2048) resource_msg = pickle.loads(data) self.logger.debug("Got UDP Message from {}: {}".format(addr, resource_msg)) - resource_msgs.put((resource_msg, addr)) + self.resource_msgs.put((resource_msg, addr)) except socket.timeout: pass + self.logger.info("UDP listener draining") + last_msg_received_time = time.time() + while time.time() - last_msg_received_time < self.atexit_timeout: + try: + data, addr = self.udp_sock.recvfrom(2048) + msg = pickle.loads(data) + self.logger.debug("Got UDP Message from {}: {}".format(addr, msg)) + self.resource_msgs.put((msg, addr)) + last_msg_received_time = time.time() + except socket.timeout: + pass + + self.logger.info("UDP listener finishing normally") + finally: + self.logger.info("UDP listener finished") + + @wrap_with_logs(target="monitoring_router") + def start_zmq_listener(self) -> None: + try: + while not self.exit_event.is_set(): try: dfk_loop_start = time.time() while time.time() - dfk_loop_start < 1.0: # TODO make configurable @@ -124,16 +171,15 @@ def start(self, msg_0 = (msg, 0) if msg[0] == MessageType.NODE_INFO: - msg[1]['run_id'] = self.run_id - node_msgs.put(msg_0) + self.node_msgs.put(msg_0) elif msg[0] == MessageType.RESOURCE_INFO: - resource_msgs.put(msg_0) + self.resource_msgs.put(msg_0) elif msg[0] == MessageType.BLOCK_INFO: - 
block_msgs.put(msg_0) + self.block_msgs.put(msg_0) elif msg[0] == MessageType.TASK_INFO: - priority_msgs.put(msg_0) + self.priority_msgs.put(msg_0) elif msg[0] == MessageType.WORKFLOW_INFO: - priority_msgs.put(msg_0) + self.priority_msgs.put(msg_0) else: # There is a type: ignore here because if msg[0] # is of the correct type, this code is unreachable, @@ -151,30 +197,20 @@ def start(self, # thing to do. self.logger.warning("Failure processing a ZMQ message", exc_info=True) - self.logger.info("Monitoring router draining") - last_msg_received_time = time.time() - while time.time() - last_msg_received_time < self.atexit_timeout: - try: - data, addr = self.udp_sock.recvfrom(2048) - msg = pickle.loads(data) - self.logger.debug("Got UDP Message from {}: {}".format(addr, msg)) - resource_msgs.put((msg, addr)) - last_msg_received_time = time.time() - except socket.timeout: - pass - - self.logger.info("Monitoring router finishing normally") + self.logger.info("ZMQ listener finishing normally") finally: - self.logger.info("Monitoring router finished") + self.logger.info("ZMQ listener finished") @wrap_with_logs -def router_starter(comm_q: "queue.Queue[Union[Tuple[int, int], str]]", - exception_q: "queue.Queue[Tuple[str, str]]", - priority_msgs: "queue.Queue[AddressedMonitoringMessage]", - node_msgs: "queue.Queue[AddressedMonitoringMessage]", - block_msgs: "queue.Queue[AddressedMonitoringMessage]", - resource_msgs: "queue.Queue[AddressedMonitoringMessage]", +@typeguard.typechecked +def router_starter(*, + comm_q: mpq.Queue, + exception_q: mpq.Queue, + priority_msgs: mpq.Queue, + node_msgs: mpq.Queue, + block_msgs: mpq.Queue, + resource_msgs: mpq.Queue, exit_event: Event, hub_address: str, @@ -182,8 +218,7 @@ def router_starter(comm_q: "queue.Queue[Union[Tuple[int, int], str]]", zmq_port_range: Tuple[int, int], logdir: str, - logging_level: int, - run_id: str) -> None: + logging_level: int) -> None: setproctitle("parsl: monitoring router") try: router = MonitoringRouter(hub_address=hub_address, @@ -191,7 +226,11 @@ def router_starter(comm_q: "queue.Queue[Union[Tuple[int, int], str]]", zmq_port_range=zmq_port_range, logdir=logdir, logging_level=logging_level, - run_id=run_id) + priority_msgs=priority_msgs, + node_msgs=node_msgs, + block_msgs=block_msgs, + resource_msgs=resource_msgs, + exit_event=exit_event) except Exception as e: logger.error("MonitoringRouter construction failed.", exc_info=True) comm_q.put(f"Monitoring router construction failed: {e}") @@ -200,7 +239,7 @@ def router_starter(comm_q: "queue.Queue[Union[Tuple[int, int], str]]", router.logger.info("Starting MonitoringRouter in router_starter") try: - router.start(priority_msgs, node_msgs, block_msgs, resource_msgs, exit_event) + router.start() except Exception as e: router.logger.exception("router.start exception") exception_q.put(('Hub', str(e))) diff --git a/parsl/providers/__init__.py b/parsl/providers/__init__.py index 475737f1f9..150f425f3d 100644 --- a/parsl/providers/__init__.py +++ b/parsl/providers/__init__.py @@ -1,6 +1,3 @@ -# Workstation Provider -from parsl.providers.ad_hoc.ad_hoc import AdHocProvider - # Cloud Providers from parsl.providers.aws.aws import AWSProvider from parsl.providers.azure.azure import AzureProvider @@ -24,7 +21,6 @@ 'SlurmProvider', 'TorqueProvider', 'LSFProvider', - 'AdHocProvider', 'PBSProProvider', 'AWSProvider', 'GoogleCloudProvider', diff --git a/parsl/providers/ad_hoc/ad_hoc.py b/parsl/providers/ad_hoc/ad_hoc.py index 207dd55738..9059648101 100644 --- a/parsl/providers/ad_hoc/ad_hoc.py +++ 
b/parsl/providers/ad_hoc/ad_hoc.py @@ -12,8 +12,12 @@ logger = logging.getLogger(__name__) -class AdHocProvider(ExecutionProvider, RepresentationMixin): - """ Ad-hoc execution provider +class DeprecatedAdHocProvider(ExecutionProvider, RepresentationMixin): + """ Deprecated ad-hoc execution provider + + The (former) AdHocProvider is deprecated. See + `issue #3515 `_ + for further discussion. This provider is used to provision execution resources over one or more ad hoc nodes that are each accessible over a Channel (say, ssh) but otherwise lack a cluster scheduler. diff --git a/parsl/providers/kubernetes/kube.py b/parsl/providers/kubernetes/kube.py index c5256a47f3..40b5b430a5 100644 --- a/parsl/providers/kubernetes/kube.py +++ b/parsl/providers/kubernetes/kube.py @@ -168,10 +168,9 @@ def submit(self, cmd_string, tasks_per_node, job_name="parsl"): - tasks_per_node (int) : command invocations to be launched per node Kwargs: - - job_name (String): Name for job, must be unique + - job_name (String): Name for job Returns: - - None: At capacity, cannot provision more - job_id: (string) Identifier for the job """ @@ -187,7 +186,7 @@ def submit(self, cmd_string, tasks_per_node, job_name="parsl"): formatted_cmd = template_string.format(command=cmd_string, worker_init=self.worker_init) - logger.debug("Pod name :{}".format(pod_name)) + logger.debug("Pod name: %s", pod_name) self._create_pod(image=self.image, pod_name=pod_name, job_name=job_name, @@ -243,13 +242,13 @@ def _status(self): for jid in to_poll_job_ids: phase = None try: - pod_status = self.kube_client.read_namespaced_pod_status(name=jid, namespace=self.namespace) + pod = self.kube_client.read_namespaced_pod(name=jid, namespace=self.namespace) except Exception: logger.exception("Failed to poll pod {} status, most likely because pod was terminated".format(jid)) if self.resources[jid]['status'] is JobStatus(JobState.RUNNING): phase = 'Unknown' else: - phase = pod_status.status.phase + phase = pod.status.phase if phase: status = translate_table.get(phase, JobState.UNKNOWN) logger.debug("Updating pod {} with status {} to parsl status {}".format(jid, @@ -286,7 +285,7 @@ def _create_pod(self, # Create the environment variables and command to initiate IPP environment_vars = client.V1EnvVar(name="TEST", value="SOME DATA") - launch_args = ["-c", "{0};".format(cmd_string)] + launch_args = ["-c", "{0}".format(cmd_string)] volume_mounts = [] # Create mount paths for the volumes diff --git a/parsl/tests/configs/ad_hoc_cluster_htex.py b/parsl/tests/configs/ad_hoc_cluster_htex.py deleted file mode 100644 index 0949b82392..0000000000 --- a/parsl/tests/configs/ad_hoc_cluster_htex.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Any, Dict - -from parsl.channels import SSHChannel -from parsl.config import Config -from parsl.executors import HighThroughputExecutor -from parsl.providers import AdHocProvider - -user_opts = {'adhoc': - {'username': 'YOUR_USERNAME', - 'script_dir': 'YOUR_SCRIPT_DIR', - 'remote_hostnames': ['REMOTE_HOST_URL_1', 'REMOTE_HOST_URL_2'] - } - } # type: Dict[str, Dict[str, Any]] - -config = Config( - executors=[ - HighThroughputExecutor( - label='remote_htex', - max_workers_per_node=2, - worker_logdir_root=user_opts['adhoc']['script_dir'], - encrypted=True, - provider=AdHocProvider( - # Command to be run before starting a worker, such as: - # 'module load Anaconda; source activate parsl_env'. 
- worker_init='', - channels=[SSHChannel(hostname=m, - username=user_opts['adhoc']['username'], - script_dir=user_opts['adhoc']['script_dir'], - ) for m in user_opts['adhoc']['remote_hostnames']] - ) - ) - ], - # AdHoc Clusters should not be setup with scaling strategy. - strategy='none', -) diff --git a/parsl/tests/configs/flux_local.py b/parsl/tests/configs/flux_local.py new file mode 100644 index 0000000000..203dd590c0 --- /dev/null +++ b/parsl/tests/configs/flux_local.py @@ -0,0 +1,11 @@ +from parsl.config import Config +from parsl.executors import FluxExecutor + + +def fresh_config(): + return Config( + executors=[FluxExecutor()], + ) + + +config = fresh_config() diff --git a/parsl/tests/configs/htex_ad_hoc_cluster.py b/parsl/tests/configs/htex_ad_hoc_cluster.py deleted file mode 100644 index db24b42ab2..0000000000 --- a/parsl/tests/configs/htex_ad_hoc_cluster.py +++ /dev/null @@ -1,26 +0,0 @@ -from parsl.channels import SSHChannel -from parsl.config import Config -from parsl.executors import HighThroughputExecutor -from parsl.providers import AdHocProvider -from parsl.tests.configs.user_opts import user_opts - -config = Config( - executors=[ - HighThroughputExecutor( - label='remote_htex', - cores_per_worker=1, - worker_debug=False, - address=user_opts['public_ip'], - encrypted=True, - provider=AdHocProvider( - move_files=False, - parallelism=1, - worker_init=user_opts['adhoc']['worker_init'], - channels=[SSHChannel(hostname=m, - username=user_opts['adhoc']['username'], - script_dir=user_opts['adhoc']['script_dir'], - ) for m in user_opts['adhoc']['remote_hostnames']] - ) - ) - ], -) diff --git a/parsl/tests/configs/local_adhoc.py b/parsl/tests/configs/local_adhoc.py index 25b1f38d61..9b1f951842 100644 --- a/parsl/tests/configs/local_adhoc.py +++ b/parsl/tests/configs/local_adhoc.py @@ -1,7 +1,7 @@ from parsl.channels import LocalChannel from parsl.config import Config from parsl.executors import HighThroughputExecutor -from parsl.providers import AdHocProvider +from parsl.providers.ad_hoc.ad_hoc import DeprecatedAdHocProvider def fresh_config(): @@ -10,7 +10,7 @@ def fresh_config(): HighThroughputExecutor( label='AdHoc', encrypted=True, - provider=AdHocProvider( + provider=DeprecatedAdHocProvider( channels=[LocalChannel(), LocalChannel()] ) ) diff --git a/parsl/tests/configs/swan_htex.py b/parsl/tests/configs/swan_htex.py deleted file mode 100644 index 3b1b6785ab..0000000000 --- a/parsl/tests/configs/swan_htex.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -================== Block -| ++++++++++++++ | Node -| | | | -| | Task | | . . . 
-| | | | -| ++++++++++++++ | -================== -""" -from parsl.channels import SSHChannel -from parsl.config import Config -from parsl.executors import HighThroughputExecutor -from parsl.launchers import AprunLauncher -from parsl.providers import TorqueProvider - -# If you are a developer running tests, make sure to update parsl/tests/configs/user_opts.py -# If you are a user copying-and-pasting this as an example, make sure to either -# 1) create a local `user_opts.py`, or -# 2) delete the user_opts import below and replace all appearances of `user_opts` with the literal value -# (i.e., user_opts['swan']['username'] -> 'your_username') -from .user_opts import user_opts - -config = Config( - executors=[ - HighThroughputExecutor( - label='swan_htex', - encrypted=True, - provider=TorqueProvider( - channel=SSHChannel( - hostname='swan.cray.com', - username=user_opts['swan']['username'], - script_dir=user_opts['swan']['script_dir'], - ), - nodes_per_block=1, - init_blocks=1, - max_blocks=1, - launcher=AprunLauncher(), - scheduler_options=user_opts['swan']['scheduler_options'], - worker_init=user_opts['swan']['worker_init'], - ), - ) - ] -) diff --git a/parsl/tests/conftest.py b/parsl/tests/conftest.py index 80b9e000cd..638088c44c 100644 --- a/parsl/tests/conftest.py +++ b/parsl/tests/conftest.py @@ -151,6 +151,10 @@ def pytest_configure(config): 'markers', 'multiple_cores_required: Marks tests that require multiple cores, such as htex affinity' ) + config.addinivalue_line( + 'markers', + 'unix_filesystem_permissions_required: Marks tests that require unix-level filesystem permission enforcement' + ) config.addinivalue_line( 'markers', 'issue3328: Marks tests broken by issue #3328' diff --git a/parsl/tests/integration/test_channels/test_scp_1.py b/parsl/tests/integration/test_channels/test_scp_1.py deleted file mode 100644 index c11df3c663..0000000000 --- a/parsl/tests/integration/test_channels/test_scp_1.py +++ /dev/null @@ -1,45 +0,0 @@ -import os - -from parsl.channels.ssh.ssh import SSHChannel as SSH - - -def connect_and_list(hostname, username): - out = '' - conn = SSH(hostname, username=username) - conn.push_file(os.path.abspath('remote_run.sh'), '/home/davidk/') - # ec, out, err = conn.execute_wait("ls /tmp/remote_run.sh; bash /tmp/remote_run.sh") - conn.close() - return out - - -script = '''#!/bin/bash -echo "Hostname: $HOSTNAME" -echo "Cpu info -----" -cat /proc/cpuinfo -echo "Done----------" -''' - - -def test_connect_1(): - with open('remote_run.sh', 'w') as f: - f.write(script) - - sites = { - 'midway': { - 'url': 'midway.rcc.uchicago.edu', - 'uname': 'yadunand' - }, - 'swift': { - 'url': 'swift.rcc.uchicago.edu', - 'uname': 'yadunand' - } - } - - for site in sites.values(): - out = connect_and_list(site['url'], site['uname']) - print("Sitename :{0} hostname:{1}".format(site['url'], out)) - - -if __name__ == "__main__": - - test_connect_1() diff --git a/parsl/tests/integration/test_channels/test_ssh_1.py b/parsl/tests/integration/test_channels/test_ssh_1.py deleted file mode 100644 index 61ab3f2705..0000000000 --- a/parsl/tests/integration/test_channels/test_ssh_1.py +++ /dev/null @@ -1,40 +0,0 @@ -from parsl.channels.ssh.ssh import SSHChannel as SSH - - -def connect_and_list(hostname, username): - conn = SSH(hostname, username=username) - ec, out, err = conn.execute_wait("echo $HOSTNAME") - conn.close() - return out - - -def test_midway(): - ''' Test ssh channels to midway - ''' - url = 'midway.rcc.uchicago.edu' - uname = 'yadunand' - out = connect_and_list(url, uname) - 
print("Sitename :{0} hostname:{1}".format(url, out)) - - -def test_beagle(): - ''' Test ssh channels to beagle - ''' - url = 'login04.beagle.ci.uchicago.edu' - uname = 'yadunandb' - out = connect_and_list(url, uname) - print("Sitename :{0} hostname:{1}".format(url, out)) - - -def test_osg(): - ''' Test ssh connectivity to osg - ''' - url = 'login.osgconnect.net' - uname = 'yadunand' - out = connect_and_list(url, uname) - print("Sitename :{0} hostname:{1}".format(url, out)) - - -if __name__ == "__main__": - - pass diff --git a/parsl/tests/integration/test_channels/test_ssh_errors.py b/parsl/tests/integration/test_channels/test_ssh_errors.py deleted file mode 100644 index 7483e30a5c..0000000000 --- a/parsl/tests/integration/test_channels/test_ssh_errors.py +++ /dev/null @@ -1,46 +0,0 @@ -from parsl.channels.errors import BadHostKeyException, SSHException -from parsl.channels.ssh.ssh import SSHChannel as SSH - - -def connect_and_list(hostname, username): - conn = SSH(hostname, username=username) - ec, out, err = conn.execute_wait("echo $HOSTNAME") - conn.close() - return out - - -def test_error_1(): - try: - connect_and_list("bad.url.gov", "ubuntu") - except Exception as e: - assert type(e) is SSHException, "Expected SSException, got: {0}".format(e) - - -def test_error_2(): - try: - connect_and_list("swift.rcc.uchicago.edu", "mango") - except SSHException: - print("Caught the right exception") - else: - raise Exception("Expected SSException, got: {0}".format(e)) - - -def test_error_3(): - ''' This should work - ''' - try: - connect_and_list("edison.nersc.gov", "yadunand") - except BadHostKeyException as e: - print("Caught exception BadHostKeyException: ", e) - else: - assert False, "Expected SSException, got: {0}".format(e) - - -if __name__ == "__main__": - - tests = [test_error_1, test_error_2, test_error_3] - - for test in tests: - print("---------Running : {0}---------------".format(test)) - test() - print("----------------------DONE--------------------------") diff --git a/parsl/tests/integration/test_channels/test_ssh_file_transport.py b/parsl/tests/integration/test_channels/test_ssh_file_transport.py deleted file mode 100644 index 61672c3ff5..0000000000 --- a/parsl/tests/integration/test_channels/test_ssh_file_transport.py +++ /dev/null @@ -1,41 +0,0 @@ -import parsl -from parsl.channels.ssh.ssh import SSHChannel as SSH - - -def connect_and_list(hostname, username): - conn = SSH(hostname, username=username) - ec, out, err = conn.execute_wait("echo $HOSTNAME") - conn.close() - return out - - -def test_push(conn, fname="test001.txt"): - - with open(fname, 'w') as f: - f.write("Hello from parsl.ssh testing\n") - - conn.push_file(fname, "/tmp") - ec, out, err = conn.execute_wait("ls /tmp/{0}".format(fname)) - print(ec, out, err) - - -def test_pull(conn, fname="test001.txt"): - - local = "foo" - conn.pull_file("/tmp/{0}".format(fname), local) - - with open("{0}/{1}".format(local, fname), 'r') as f: - print(f.readlines()) - - -if __name__ == "__main__": - - parsl.set_stream_logger() - - # This is for testing - conn = SSH("midway.rcc.uchicago.edu", username="yadunand") - - test_push(conn) - test_pull(conn) - - conn.close() diff --git a/parsl/tests/integration/test_channels/test_ssh_interactive.py b/parsl/tests/integration/test_channels/test_ssh_interactive.py deleted file mode 100644 index c6f9b9dea9..0000000000 --- a/parsl/tests/integration/test_channels/test_ssh_interactive.py +++ /dev/null @@ -1,24 +0,0 @@ -import parsl -from parsl.channels.ssh_il.ssh_il import SSHInteractiveLoginChannel as 
SSH - - -def connect_and_list(hostname, username): - conn = SSH(hostname, username=username) - ec, out, err = conn.execute_wait("echo $HOSTNAME") - conn.close() - return out - - -def test_cooley(): - ''' Test ssh channels to midway - ''' - url = 'cooley.alcf.anl.gov' - uname = 'yadunand' - out = connect_and_list(url, uname) - print("Sitename :{0} hostname:{1}".format(url, out)) - return - - -if __name__ == "__main__": - parsl.set_stream_logger() - test_cooley() diff --git a/parsl/tests/manual_tests/test_ad_hoc_htex.py b/parsl/tests/manual_tests/test_ad_hoc_htex.py deleted file mode 100644 index dfa34ec0d1..0000000000 --- a/parsl/tests/manual_tests/test_ad_hoc_htex.py +++ /dev/null @@ -1,49 +0,0 @@ -import parsl -from parsl import python_app - -parsl.set_stream_logger() - -from parsl.channels import SSHChannel -from parsl.config import Config -from parsl.executors import HighThroughputExecutor -from parsl.providers import AdHocProvider - -remotes = ['midway2-login2.rcc.uchicago.edu', 'midway2-login1.rcc.uchicago.edu'] - -config = Config( - executors=[ - HighThroughputExecutor( - label='AdHoc', - max_workers_per_node=2, - worker_logdir_root="/scratch/midway2/yadunand/parsl_scripts", - encrypted=True, - provider=AdHocProvider( - worker_init="source /scratch/midway2/yadunand/parsl_env_setup.sh", - channels=[SSHChannel(hostname=m, - username="yadunand", - script_dir="/scratch/midway2/yadunand/parsl_cluster") - for m in remotes] - ) - ) - ] -) - - -@python_app -def platform(sleep=2, stdout=None): - import platform - import time - time.sleep(sleep) - return platform.uname() - - -def test_raw_provider(): - - parsl.load(config) - - x = [platform() for i in range(10)] - print([i.result() for i in x]) - - -if __name__ == "__main__": - test_raw_provider() diff --git a/parsl/tests/manual_tests/test_oauth_ssh.py b/parsl/tests/manual_tests/test_oauth_ssh.py deleted file mode 100644 index 3d464bcc0e..0000000000 --- a/parsl/tests/manual_tests/test_oauth_ssh.py +++ /dev/null @@ -1,13 +0,0 @@ -from parsl.channels import OAuthSSHChannel - - -def test_channel(): - channel = OAuthSSHChannel(hostname='ssh.demo.globus.org', username='yadunand') - x, stdout, stderr = channel.execute_wait('ls') - print(x, stdout, stderr) - assert x == 0, "Expected exit code 0, got {}".format(x) - - -if __name__ == '__main__': - - test_channel() diff --git a/parsl/tests/test_bash_apps/test_inputs_default.py b/parsl/tests/test_bash_apps/test_inputs_default.py new file mode 100644 index 0000000000..9b6d7a18a2 --- /dev/null +++ b/parsl/tests/test_bash_apps/test_inputs_default.py @@ -0,0 +1,25 @@ +import pytest + +from parsl import AUTO_LOGNAME, Config, bash_app, python_app +from parsl.executors import ThreadPoolExecutor + + +def local_config(): + return Config(executors=[ThreadPoolExecutor()]) + + +@pytest.mark.local +def test_default_inputs(): + @python_app + def identity(inp): + return inp + + @bash_app + def sum_inputs(inputs=[identity(1), identity(2)], stdout=AUTO_LOGNAME): + calc = sum(inputs) + return f"echo {calc}" + + fut = sum_inputs() + fut.result() + with open(fut.stdout, 'r') as f: + assert int(f.read()) == 3 diff --git a/parsl/tests/test_bash_apps/test_memoize_ignore_args.py b/parsl/tests/test_bash_apps/test_memoize_ignore_args.py index 0439bfb163..ee3917e561 100644 --- a/parsl/tests/test_bash_apps/test_memoize_ignore_args.py +++ b/parsl/tests/test_bash_apps/test_memoize_ignore_args.py @@ -1,7 +1,5 @@ import os -import pytest - import parsl from parsl.app.app import bash_app @@ -23,24 +21,18 @@ def 
no_checkpoint_stdout_app_ignore_args(stdout=None): return "echo X" -def test_memo_stdout(): +def test_memo_stdout(tmpd_cwd): + path_x = tmpd_cwd / "test.memo.stdout.x" # this should run and create a file named after path_x - path_x = "test.memo.stdout.x" - if os.path.exists(path_x): - os.remove(path_x) - - no_checkpoint_stdout_app_ignore_args(stdout=path_x).result() - assert os.path.exists(path_x) - - # this should be memoized, so not create benc.test.y - path_y = "test.memo.stdout.y" + no_checkpoint_stdout_app_ignore_args(stdout=str(path_x)).result() + assert path_x.exists() - if os.path.exists(path_y): - os.remove(path_y) + # this should be memoized, so should not get created + path_y = tmpd_cwd / "test.memo.stdout.y" no_checkpoint_stdout_app_ignore_args(stdout=path_y).result() - assert not os.path.exists(path_y) + assert not path_y.exists(), "For memoization, expected NO file written" # this should also be memoized, so not create an arbitrary name z_fut = no_checkpoint_stdout_app_ignore_args(stdout=parsl.AUTO_LOGNAME) diff --git a/parsl/tests/test_bash_apps/test_memoize_ignore_args_regr.py b/parsl/tests/test_bash_apps/test_memoize_ignore_args_regr.py index 3c9b51e980..8f03c055a1 100644 --- a/parsl/tests/test_bash_apps/test_memoize_ignore_args_regr.py +++ b/parsl/tests/test_bash_apps/test_memoize_ignore_args_regr.py @@ -1,5 +1,4 @@ import copy -import os from typing import List import pytest @@ -30,21 +29,17 @@ def no_checkpoint_stdout_app(stdout=None): return "echo X" -def test_memo_stdout(): - +def test_memo_stdout(tmpd_cwd): assert const_list_x == const_list_x_arg - path_x = "test.memo.stdout.x" - if os.path.exists(path_x): - os.remove(path_x) + path_x = tmpd_cwd / "test.memo.stdout.x" # this should run and create a file named after path_x - no_checkpoint_stdout_app(stdout=path_x).result() - assert os.path.exists(path_x) + no_checkpoint_stdout_app(stdout=str(path_x)).result() + path_x.unlink(missing_ok=False) - os.remove(path_x) - no_checkpoint_stdout_app(stdout=path_x).result() - assert not os.path.exists(path_x) + no_checkpoint_stdout_app(stdout=str(path_x)).result() + assert not path_x.exists(), "For memoization, expected NO file written" # this should also be memoized, so not create an arbitrary name z_fut = no_checkpoint_stdout_app(stdout=parsl.AUTO_LOGNAME) diff --git a/parsl/tests/test_bash_apps/test_stdout.py b/parsl/tests/test_bash_apps/test_stdout.py index b1efadd445..eba6a7b80d 100644 --- a/parsl/tests/test_bash_apps/test_stdout.py +++ b/parsl/tests/test_bash_apps/test_stdout.py @@ -16,7 +16,6 @@ def echo_to_streams(msg, stderr=None, stdout=None): whitelist = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', '*threads*') speclist = ( - '/bad/dir/t.out', ['t3.out', 'w'], ('t4.out', None), (42, 'w'), @@ -26,7 +25,6 @@ def echo_to_streams(msg, stderr=None, stdout=None): ) testids = [ - 'nonexistent_dir', 'list_not_tuple', 'null_mode', 'not_a_string', @@ -55,6 +53,26 @@ def test_bad_stdout_specs(spec): @pytest.mark.issue3328 +@pytest.mark.unix_filesystem_permissions_required +def test_bad_stdout_file(): + """Testing bad stderr file""" + + o = "/bad/dir/t2.out" + + fn = echo_to_streams("Hello world", stdout=o, stderr='t.err') + + try: + fn.result() + except perror.BadStdStreamFile: + pass + else: + assert False, "Did not raise expected exception BadStdStreamFile" + + return + + +@pytest.mark.issue3328 +@pytest.mark.unix_filesystem_permissions_required def test_bad_stderr_file(): """Testing bad stderr file""" diff --git 
a/parsl/tests/test_channels/test_dfk_close.py b/parsl/tests/test_channels/test_dfk_close.py new file mode 100644 index 0000000000..05b2e9395f --- /dev/null +++ b/parsl/tests/test_channels/test_dfk_close.py @@ -0,0 +1,26 @@ +from unittest.mock import Mock + +import pytest + +import parsl +from parsl.channels.base import Channel +from parsl.executors import HighThroughputExecutor +from parsl.providers import LocalProvider + + +@pytest.mark.local +def test_dfk_close(): + + mock_channel = Mock(spec=Channel) + + # block settings all 0 because the mock channel won't be able to + # do anything to make a block exist + p = LocalProvider(channel=mock_channel, init_blocks=0, min_blocks=0, max_blocks=0) + + e = HighThroughputExecutor(provider=p) + + c = parsl.Config(executors=[e]) + with parsl.load(c): + pass + + assert mock_channel.close.called diff --git a/parsl/tests/test_error_handling/test_retries.py b/parsl/tests/test_error_handling/test_retries.py index c069ee7ba7..06ae81702e 100644 --- a/parsl/tests/test_error_handling/test_retries.py +++ b/parsl/tests/test_error_handling/test_retries.py @@ -1,9 +1,7 @@ -import argparse import os import pytest -import parsl from parsl import bash_app, python_app from parsl.tests.configs.local_threads import fresh_config @@ -68,8 +66,6 @@ def test_fail_nowait(numtasks=10): assert isinstance( e, TypeError), "Expected a TypeError, got {}".format(e) - print("Done") - @pytest.mark.local def test_fail_delayed(numtasks=10): @@ -94,19 +90,12 @@ def test_fail_delayed(numtasks=10): assert isinstance( e, TypeError), "Expected a TypeError, got {}".format(e) - print("Done") - @pytest.mark.local -def test_retry(): +def test_retry(tmpd_cwd): """Test retries via app that succeeds on the Nth retry. """ - fname = "retry.out" - try: - os.remove(fname) - except OSError: - pass - fu = succeed_on_retry(fname) - - fu.result() + fpath = tmpd_cwd / "retry.out" + sout = str(tmpd_cwd / "stdout") + succeed_on_retry(str(fpath), stdout=sout).result() diff --git a/parsl/tests/test_htex/test_disconnected_blocks_failing_provider.py b/parsl/tests/test_htex/test_disconnected_blocks_failing_provider.py new file mode 100644 index 0000000000..b2fa507aca --- /dev/null +++ b/parsl/tests/test_htex/test_disconnected_blocks_failing_provider.py @@ -0,0 +1,71 @@ +import logging + +import pytest + +import parsl +from parsl import Config +from parsl.executors import HighThroughputExecutor +from parsl.executors.errors import BadStateException +from parsl.jobs.states import JobState, JobStatus +from parsl.providers import LocalProvider + + +class FailingProvider(LocalProvider): + def submit(*args, **kwargs): + raise RuntimeError("Deliberate failure of provider.submit") + + +def local_config(): + """Config to simulate failing blocks without connecting""" + return Config( + executors=[ + HighThroughputExecutor( + label="HTEX", + heartbeat_period=1, + heartbeat_threshold=2, + poll_period=100, + max_workers_per_node=1, + provider=FailingProvider( + init_blocks=0, + max_blocks=2, + min_blocks=0, + ), + ) + ], + max_idletime=0.5, + strategy='htex_auto_scale', + strategy_period=0.1 + # this strategy period needs to be a few times smaller than the + # status_polling_interval of FailingProvider, which is 5s at + # time of writing + ) + + +@parsl.python_app +def double(x): + return x * 2 + + +@pytest.mark.local +def test_disconnected_blocks(): + """Test reporting of blocks that fail to connect from HTEX""" + dfk = parsl.dfk() + executor = dfk.executors["HTEX"] + + connected_blocks = executor.connected_blocks() + 
assert not connected_blocks, "Expected 0 blocks" + + future = double(5) + with pytest.raises(BadStateException): + future.result() + + assert isinstance(future.exception(), BadStateException) + + status_dict = executor.status() + assert len(status_dict) == 1, "Expected exactly 1 block" + for status in status_dict.values(): + assert isinstance(status, JobStatus) + assert status.state == JobState.MISSING + + connected_blocks = executor.connected_blocks() + assert connected_blocks == [], "Expected exactly 0 connected blocks" diff --git a/parsl/tests/test_htex/test_htex.py b/parsl/tests/test_htex/test_htex.py index ca95773e1b..810236c1b4 100644 --- a/parsl/tests/test_htex/test_htex.py +++ b/parsl/tests/test_htex/test_htex.py @@ -1,11 +1,12 @@ +import logging import pathlib -import warnings +from subprocess import Popen, TimeoutExpired +from typing import Optional, Sequence from unittest import mock import pytest from parsl import HighThroughputExecutor, curvezmq -from parsl.multiprocessing import ForkProcess _MOCK_BASE = "parsl.executors.high_throughput.executor" @@ -71,42 +72,58 @@ def test_htex_start_encrypted( @pytest.mark.local @pytest.mark.parametrize("started", (True, False)) @pytest.mark.parametrize("timeout_expires", (True, False)) -@mock.patch(f"{_MOCK_BASE}.logger") def test_htex_shutdown( - mock_logger: mock.MagicMock, started: bool, timeout_expires: bool, htex: HighThroughputExecutor, + caplog ): - mock_ix_proc = mock.Mock(spec=ForkProcess) + mock_ix_proc = mock.Mock(spec=Popen) if started: htex.interchange_proc = mock_ix_proc - mock_ix_proc.is_alive.return_value = True + + # This will, in the absence of any exit trigger, block forever if + # no timeout is given and if the interchange does not terminate. + # Raise an exception to report that, rather than actually block, + # and hope that nothing is catching that exception. 
+ + # this function implements the behaviour if the interchange has + # not received a termination call + def proc_wait_alive(timeout): + if timeout: + raise TimeoutExpired(cmd="mock-interchange", timeout=timeout) + else: + raise RuntimeError("This wait call would hang forever") + + def proc_wait_terminated(timeout): + return 0 + + mock_ix_proc.wait.side_effect = proc_wait_alive if not timeout_expires: # Simulate termination of the Interchange process def kill_interchange(*args, **kwargs): - mock_ix_proc.is_alive.return_value = False + mock_ix_proc.wait.side_effect = proc_wait_terminated mock_ix_proc.terminate.side_effect = kill_interchange - htex.shutdown() + with caplog.at_level(logging.INFO): + htex.shutdown() - mock_logs = mock_logger.info.call_args_list if started: assert mock_ix_proc.terminate.called - assert mock_ix_proc.join.called - assert {"timeout": 10} == mock_ix_proc.join.call_args[1] + assert mock_ix_proc.wait.called + assert {"timeout": 10} == mock_ix_proc.wait.call_args[1] if timeout_expires: - assert "Unable to terminate Interchange" in mock_logs[1][0][0] + assert "Unable to terminate Interchange" in caplog.text assert mock_ix_proc.kill.called - assert "Attempting" in mock_logs[0][0][0] - assert "Finished" in mock_logs[-1][0][0] + assert "Attempting HighThroughputExecutor shutdown" in caplog.text + assert "Finished HighThroughputExecutor shutdown" in caplog.text else: assert not mock_ix_proc.terminate.called - assert not mock_ix_proc.join.called - assert "has not started" in mock_logs[0][0][0] + assert not mock_ix_proc.wait.called + assert "HighThroughputExecutor has not started" in caplog.text @pytest.mark.local @@ -119,3 +136,25 @@ def test_max_workers_per_node(): # Ensure max_workers_per_node takes precedence assert htex.max_workers_per_node == htex.max_workers == 1 + + +@pytest.mark.local +@pytest.mark.parametrize("cmd", (None, "custom-launch-cmd")) +def test_htex_worker_pool_launch_cmd(cmd: Optional[str]): + if cmd: + htex = HighThroughputExecutor(launch_cmd=cmd) + assert htex.launch_cmd == cmd + else: + htex = HighThroughputExecutor() + assert htex.launch_cmd.startswith("process_worker_pool.py") + + +@pytest.mark.local +@pytest.mark.parametrize("cmd", (None, ["custom", "launch", "cmd"])) +def test_htex_interchange_launch_cmd(cmd: Optional[Sequence[str]]): + if cmd: + htex = HighThroughputExecutor(interchange_launch_cmd=cmd) + assert htex.interchange_launch_cmd == cmd + else: + htex = HighThroughputExecutor() + assert htex.interchange_launch_cmd == ["interchange.py"] diff --git a/parsl/tests/test_htex/test_resource_spec_validation.py b/parsl/tests/test_htex/test_resource_spec_validation.py new file mode 100644 index 0000000000..ac0c580c20 --- /dev/null +++ b/parsl/tests/test_htex/test_resource_spec_validation.py @@ -0,0 +1,40 @@ +import queue +from unittest import mock + +import pytest + +from parsl.executors import HighThroughputExecutor +from parsl.executors.high_throughput.mpi_prefix_composer import ( + InvalidResourceSpecification, +) + + +def double(x): + return x * 2 + + +@pytest.mark.local +def test_submit_calls_validate(): + + htex = HighThroughputExecutor() + htex.outgoing_q = mock.Mock(spec=queue.Queue) + htex.validate_resource_spec = mock.Mock(spec=htex.validate_resource_spec) + + res_spec = {} + htex.submit(double, res_spec, (5,), {}) + htex.validate_resource_spec.assert_called() + + +@pytest.mark.local +def test_resource_spec_validation(): + htex = HighThroughputExecutor() + ret_val = htex.validate_resource_spec({}) + assert ret_val is None + + 
+@pytest.mark.local +def test_resource_spec_validation_bad_keys(): + htex = HighThroughputExecutor() + + with pytest.raises(InvalidResourceSpecification): + htex.validate_resource_spec({"num_nodes": 2}) diff --git a/parsl/tests/test_htex/test_zmq_binding.py b/parsl/tests/test_htex/test_zmq_binding.py index eaf2e9731b..e21c065d0d 100644 --- a/parsl/tests/test_htex/test_zmq_binding.py +++ b/parsl/tests/test_htex/test_zmq_binding.py @@ -1,3 +1,4 @@ +import logging import pathlib from typing import Optional from unittest import mock @@ -8,6 +9,24 @@ from parsl import curvezmq from parsl.executors.high_throughput.interchange import Interchange +from parsl.executors.high_throughput.manager_selector import RandomManagerSelector + + +def make_interchange(*, interchange_address: Optional[str], cert_dir: Optional[str]) -> Interchange: + return Interchange(interchange_address=interchange_address, + cert_dir=cert_dir, + client_address="127.0.0.1", + client_ports=(50055, 50056, 50057), + worker_ports=None, + worker_port_range=(54000, 55000), + hub_address=None, + hub_zmq_port=None, + heartbeat_threshold=60, + logdir=".", + logging_level=logging.INFO, + manager_selector=RandomManagerSelector(), + poll_period=10, + run_id="test_run_id") @pytest.fixture @@ -31,7 +50,7 @@ def test_interchange_curvezmq_sockets( mock_socket: mock.MagicMock, cert_dir: Optional[str], encrypted: bool ): address = "127.0.0.1" - ix = Interchange(interchange_address=address, cert_dir=cert_dir) + ix = make_interchange(interchange_address=address, cert_dir=cert_dir) assert isinstance(ix.zmq_context, curvezmq.ServerContext) assert ix.zmq_context.encrypted is encrypted assert mock_socket.call_count == 5 @@ -40,7 +59,7 @@ def test_interchange_curvezmq_sockets( @pytest.mark.local @pytest.mark.parametrize("encrypted", (True, False), indirect=True) def test_interchange_binding_no_address(cert_dir: Optional[str]): - ix = Interchange(cert_dir=cert_dir) + ix = make_interchange(interchange_address=None, cert_dir=cert_dir) assert ix.interchange_address == "*" @@ -49,7 +68,7 @@ def test_interchange_binding_no_address(cert_dir: Optional[str]): def test_interchange_binding_with_address(cert_dir: Optional[str]): # Using loopback address address = "127.0.0.1" - ix = Interchange(interchange_address=address, cert_dir=cert_dir) + ix = make_interchange(interchange_address=address, cert_dir=cert_dir) assert ix.interchange_address == address @@ -60,7 +79,7 @@ def test_interchange_binding_with_non_ipv4_address(cert_dir: Optional[str]): # Confirm that a ipv4 address is required address = "localhost" with pytest.raises(zmq.error.ZMQError): - Interchange(interchange_address=address, cert_dir=cert_dir) + make_interchange(interchange_address=address, cert_dir=cert_dir) @pytest.mark.local @@ -69,7 +88,7 @@ def test_interchange_binding_bad_address(cert_dir: Optional[str]): """Confirm that we raise a ZMQError when a bad address is supplied""" address = "550.0.0.0" with pytest.raises(zmq.error.ZMQError): - Interchange(interchange_address=address, cert_dir=cert_dir) + make_interchange(interchange_address=address, cert_dir=cert_dir) @pytest.mark.local @@ -77,7 +96,7 @@ def test_interchange_binding_bad_address(cert_dir: Optional[str]): def test_limited_interface_binding(cert_dir: Optional[str]): """When address is specified the worker_port would be bound to it rather than to 0.0.0.0""" address = "127.0.0.1" - ix = Interchange(interchange_address=address, cert_dir=cert_dir) + ix = make_interchange(interchange_address=address, cert_dir=cert_dir) ix.worker_result_port 
proc = psutil.Process() conns = proc.connections(kind="tcp") diff --git a/parsl/tests/test_monitoring/test_basic.py b/parsl/tests/test_monitoring/test_basic.py index c900670ec8..1c792a9d82 100644 --- a/parsl/tests/test_monitoring/test_basic.py +++ b/parsl/tests/test_monitoring/test_basic.py @@ -25,10 +25,23 @@ def this_app(): # a configuration that is suitably configured for monitoring. def htex_config(): + """This config will use htex's default htex-specific monitoring radio mode""" from parsl.tests.configs.htex_local_alternate import fresh_config return fresh_config() +def htex_udp_config(): + """This config will force UDP""" + from parsl.tests.configs.htex_local_alternate import fresh_config + c = fresh_config() + assert len(c.executors) == 1 + + assert c.executors[0].radio_mode == "htex", "precondition: htex has a radio mode attribute, configured for htex radio" + c.executors[0].radio_mode = "udp" + + return c + + def workqueue_config(): from parsl.tests.configs.workqueue_ex import fresh_config c = fresh_config() @@ -48,7 +61,7 @@ def taskvine_config(): @pytest.mark.local -@pytest.mark.parametrize("fresh_config", [htex_config, workqueue_config, taskvine_config]) +@pytest.mark.parametrize("fresh_config", [htex_config, htex_udp_config, workqueue_config, taskvine_config]) def test_row_counts(tmpd_cwd, fresh_config): # this is imported here rather than at module level because # it isn't available in a plain parsl install, so this module diff --git a/parsl/tests/test_monitoring/test_fuzz_zmq.py b/parsl/tests/test_monitoring/test_fuzz_zmq.py index 36f048efb3..3f50385564 100644 --- a/parsl/tests/test_monitoring/test_fuzz_zmq.py +++ b/parsl/tests/test_monitoring/test_fuzz_zmq.py @@ -44,8 +44,8 @@ def test_row_counts(): # the latter is what i'm most suspicious of in my present investigation # dig out the interchange port... 
- hub_address = parsl.dfk().hub_address - hub_zmq_port = parsl.dfk().hub_zmq_port + hub_address = parsl.dfk().monitoring.hub_address + hub_zmq_port = parsl.dfk().monitoring.hub_zmq_port # this will send a string to a new socket connection with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: diff --git a/parsl/tests/test_mpi_apps/test_bad_mpi_config.py b/parsl/tests/test_mpi_apps/test_bad_mpi_config.py index 336bf87703..ebeb64622d 100644 --- a/parsl/tests/test_mpi_apps/test_bad_mpi_config.py +++ b/parsl/tests/test_mpi_apps/test_bad_mpi_config.py @@ -1,33 +1,48 @@ import pytest from parsl import Config -from parsl.executors import HighThroughputExecutor +from parsl.executors import MPIExecutor from parsl.launchers import AprunLauncher, SimpleLauncher, SrunLauncher from parsl.providers import SlurmProvider @pytest.mark.local -def test_bad_launcher_with_mpi_mode(): - """AssertionError if a launcher other than SimpleLauncher is supplied""" +def test_bad_launcher(): + """TypeError if a launcher other than SimpleLauncher is supplied""" for launcher in [SrunLauncher(), AprunLauncher()]: - with pytest.raises(AssertionError): + with pytest.raises(TypeError): Config(executors=[ - HighThroughputExecutor( - enable_mpi_mode=True, + MPIExecutor( provider=SlurmProvider(launcher=launcher), ) ]) @pytest.mark.local -def test_correct_launcher_with_mpi_mode(): +def test_bad_mpi_launcher(): + """ValueError if an unsupported mpi_launcher is specified""" + + with pytest.raises(ValueError): + Config(executors=[ + MPIExecutor( + mpi_launcher="bad_launcher", + provider=SlurmProvider(launcher=SimpleLauncher()), + ) + ]) + + +@pytest.mark.local +@pytest.mark.parametrize( + "mpi_launcher", + ["srun", "aprun", "mpiexec"] +) +def test_correct_launcher_with_mpi_mode(mpi_launcher: str): """Confirm that SimpleLauncher works with mpi_mode""" - config = Config(executors=[ - HighThroughputExecutor( - enable_mpi_mode=True, - provider=SlurmProvider(launcher=SimpleLauncher()), - ) - ]) - assert isinstance(config.executors[0].provider.launcher, SimpleLauncher) + executor = MPIExecutor( + mpi_launcher=mpi_launcher, + provider=SlurmProvider(launcher=SimpleLauncher()), + ) + + assert isinstance(executor.provider.launcher, SimpleLauncher) diff --git a/parsl/tests/test_mpi_apps/test_mpi_mode_disabled.py b/parsl/tests/test_mpi_apps/test_mpi_mode_disabled.py deleted file mode 100644 index e1e5c70883..0000000000 --- a/parsl/tests/test_mpi_apps/test_mpi_mode_disabled.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Dict - -import pytest - -import parsl -from parsl import python_app -from parsl.tests.configs.htex_local import fresh_config - -EXECUTOR_LABEL = "MPI_TEST" - - -def local_config(): - config = fresh_config() - config.executors[0].label = EXECUTOR_LABEL - config.executors[0].max_workers_per_node = 1 - config.executors[0].enable_mpi_mode = False - return config - - -@python_app -def get_env_vars(parsl_resource_specification: Dict = {}) -> Dict: - import os - - parsl_vars = {} - for key in os.environ: - if key.startswith("PARSL_"): - parsl_vars[key] = os.environ[key] - return parsl_vars - - -@pytest.mark.local -def test_only_resource_specs_set(): - """Confirm that resource_spec env vars are set while launch prefixes are not - when enable_mpi_mode = False""" - resource_spec = { - "num_nodes": 4, - "ranks_per_node": 2, - } - - future = get_env_vars(parsl_resource_specification=resource_spec) - - result = future.result() - assert isinstance(result, Dict) - assert "PARSL_DEFAULT_PREFIX" not in result - assert 
"PARSL_SRUN_PREFIX" not in result - assert result["PARSL_NUM_NODES"] == str(resource_spec["num_nodes"]) - assert result["PARSL_RANKS_PER_NODE"] == str(resource_spec["ranks_per_node"]) diff --git a/parsl/tests/test_mpi_apps/test_mpi_mode_enabled.py b/parsl/tests/test_mpi_apps/test_mpi_mode_enabled.py index 6743d40eba..aff2501674 100644 --- a/parsl/tests/test_mpi_apps/test_mpi_mode_enabled.py +++ b/parsl/tests/test_mpi_apps/test_mpi_mode_enabled.py @@ -6,26 +6,34 @@ import pytest import parsl -from parsl import bash_app, python_app +from parsl import Config, bash_app, python_app +from parsl.executors import MPIExecutor from parsl.executors.high_throughput.mpi_prefix_composer import ( MissingResourceSpecification, ) -from parsl.tests.configs.htex_local import fresh_config +from parsl.launchers import SimpleLauncher +from parsl.providers import LocalProvider EXECUTOR_LABEL = "MPI_TEST" def local_setup(): - config = fresh_config() - config.executors[0].label = EXECUTOR_LABEL - config.executors[0].max_workers_per_node = 2 - config.executors[0].enable_mpi_mode = True - config.executors[0].mpi_launcher = "mpiexec" cwd = os.path.abspath(os.path.dirname(__file__)) pbs_nodefile = os.path.join(cwd, "mocks", "pbs_nodefile") - config.executors[0].provider.worker_init = f"export PBS_NODEFILE={pbs_nodefile}" + config = Config( + executors=[ + MPIExecutor( + label=EXECUTOR_LABEL, + max_workers_per_block=2, + mpi_launcher="mpiexec", + provider=LocalProvider( + worker_init=f"export PBS_NODEFILE={pbs_nodefile}", + launcher=SimpleLauncher() + ) + ) + ]) parsl.load(config) diff --git a/parsl/tests/test_mpi_apps/test_mpiex.py b/parsl/tests/test_mpi_apps/test_mpiex.py index 1b3e86e0b8..2e8a38bc68 100644 --- a/parsl/tests/test_mpi_apps/test_mpiex.py +++ b/parsl/tests/test_mpi_apps/test_mpiex.py @@ -4,7 +4,6 @@ import pytest -import parsl from parsl import Config, HighThroughputExecutor from parsl.executors.high_throughput.mpi_executor import MPIExecutor from parsl.launchers import SimpleLauncher @@ -42,9 +41,9 @@ def test_docstring(): def test_init(): """Ensure all relevant kwargs are copied over from HTEx""" - new_kwargs = {'max_workers_per_block'} - excluded_kwargs = {'available_accelerators', 'enable_mpi_mode', 'cores_per_worker', 'max_workers_per_node', - 'mem_per_worker', 'cpu_affinity', 'max_workers'} + new_kwargs = {'max_workers_per_block', 'mpi_launcher'} + excluded_kwargs = {'available_accelerators', 'cores_per_worker', 'max_workers_per_node', + 'mem_per_worker', 'cpu_affinity', 'max_workers', 'manager_selector'} # Get the kwargs from both HTEx and MPIEx htex_kwargs = set(signature(HighThroughputExecutor.__init__).parameters) diff --git a/parsl/tests/test_mpi_apps/test_resource_spec.py b/parsl/tests/test_mpi_apps/test_resource_spec.py index 99d0187ccd..f180c67d52 100644 --- a/parsl/tests/test_mpi_apps/test_resource_spec.py +++ b/parsl/tests/test_mpi_apps/test_resource_spec.py @@ -1,18 +1,20 @@ import contextlib import logging import os +import queue import typing import unittest from typing import Dict +from unittest import mock import pytest -import parsl from parsl.app.app import python_app +from parsl.executors.high_throughput.executor import HighThroughputExecutor +from parsl.executors.high_throughput.mpi_executor import MPIExecutor from parsl.executors.high_throughput.mpi_prefix_composer import ( InvalidResourceSpecification, MissingResourceSpecification, - validate_resource_spec, ) from parsl.executors.high_throughput.mpi_resource_management import ( get_nodes_in_batchjob, @@ -20,6 +22,8 @@ 
get_slurm_hosts_list, identify_scheduler, ) +from parsl.launchers import SimpleLauncher +from parsl.providers import LocalProvider from parsl.tests.configs.htex_local import fresh_config EXECUTOR_LABEL = "MPI_TEST" @@ -48,23 +52,6 @@ def get_env_vars(parsl_resource_specification: Dict = {}) -> Dict: return parsl_vars -@pytest.mark.local -def test_resource_spec_env_vars(): - resource_spec = { - "num_nodes": 4, - "ranks_per_node": 2, - } - - assert double(5).result() == 10 - - future = get_env_vars(parsl_resource_specification=resource_spec) - - result = future.result() - assert isinstance(result, Dict) - assert result["PARSL_NUM_NODES"] == str(resource_spec["num_nodes"]) - assert result["PARSL_RANKS_PER_NODE"] == str(resource_spec["ranks_per_node"]) - - @pytest.mark.local @unittest.mock.patch("subprocess.check_output", return_value=b"c203-031\nc203-032\n") def test_slurm_mocked_mpi_fetch(subprocess_check): @@ -83,16 +70,6 @@ def add_to_path(path: os.PathLike) -> typing.Generator[None, None, None]: os.environ["PATH"] = old_path -@pytest.mark.local -@pytest.mark.skip -def test_slurm_mpi_fetch(): - logging.warning(f"Current pwd : {os.path.dirname(__file__)}") - with add_to_path(os.path.dirname(__file__)): - logging.warning(f"PATH: {os.environ['PATH']}") - nodeinfo = get_slurm_hosts_list() - logging.warning(f"Got : {nodeinfo}") - - @contextlib.contextmanager def mock_pbs_nodefile(nodefile: str = "pbs_nodefile") -> typing.Generator[None, None, None]: cwd = os.path.abspath(os.path.dirname(__file__)) @@ -122,22 +99,43 @@ def test_top_level(): @pytest.mark.local @pytest.mark.parametrize( - "resource_spec, is_mpi_enabled, exception", + "resource_spec, exception", ( - ({"num_nodes": 2, "ranks_per_node": 1}, False, None), - ({"launcher_options": "--debug_foo"}, False, None), - ({"num_nodes": 2, "BAD_OPT": 1}, False, InvalidResourceSpecification), - ({}, False, None), - ({"num_nodes": 2, "ranks_per_node": 1}, True, None), - ({"launcher_options": "--debug_foo"}, True, None), - ({"num_nodes": 2, "BAD_OPT": 1}, True, InvalidResourceSpecification), - ({}, True, MissingResourceSpecification), + + ({"num_nodes": 2, "ranks_per_node": 1}, None), + ({"launcher_options": "--debug_foo"}, None), + ({"num_nodes": 2, "BAD_OPT": 1}, InvalidResourceSpecification), + ({}, MissingResourceSpecification), ) ) -def test_resource_spec(resource_spec: Dict, is_mpi_enabled: bool, exception): +def test_mpi_resource_spec(resource_spec: Dict, exception): + """Test validation of resource_specification in MPIExecutor""" + + mpi_ex = MPIExecutor(provider=LocalProvider(launcher=SimpleLauncher())) + mpi_ex.outgoing_q = mock.Mock(spec=queue.Queue) + if exception: with pytest.raises(exception): - validate_resource_spec(resource_spec, is_mpi_enabled) + mpi_ex.validate_resource_spec(resource_spec) else: - result = validate_resource_spec(resource_spec, is_mpi_enabled) + result = mpi_ex.validate_resource_spec(resource_spec) assert result is None + + +@pytest.mark.local +@pytest.mark.parametrize( + "resource_spec", + ( + {"num_nodes": 2, "ranks_per_node": 1}, + {"launcher_options": "--debug_foo"}, + {"BAD_OPT": 1}, + ) +) +def test_mpi_resource_spec_passed_to_htex(resource_spec: dict): + """HTEX should reject every resource_spec""" + + htex = HighThroughputExecutor() + htex.outgoing_q = mock.Mock(spec=queue.Queue) + + with pytest.raises(InvalidResourceSpecification): + htex.validate_resource_spec(resource_spec) diff --git a/parsl/tests/test_providers/test_local_provider.py b/parsl/tests/test_providers/test_local_provider.py index 
29907ec47d..497c13370d 100644 --- a/parsl/tests/test_providers/test_local_provider.py +++ b/parsl/tests/test_providers/test_local_provider.py @@ -11,7 +11,8 @@ import pytest -from parsl.channels import LocalChannel, SSHChannel +from parsl.channels import LocalChannel +from parsl.channels.ssh.ssh import DeprecatedSSHChannel from parsl.jobs.states import JobState from parsl.launchers import SingleNodeLauncher from parsl.providers import LocalProvider @@ -92,19 +93,24 @@ def test_ssh_channel(): # already exist, so create it here. pathlib.Path('{}/known.hosts'.format(config_dir)).touch(mode=0o600) script_dir = tempfile.mkdtemp() - p = LocalProvider(channel=SSHChannel('127.0.0.1', port=server_port, - script_dir=remote_script_dir, - host_keys_filename='{}/known.hosts'.format(config_dir), - key_filename=priv_key), - launcher=SingleNodeLauncher(debug=False)) - p.script_dir = script_dir - _run_tests(p) + channel = DeprecatedSSHChannel('127.0.0.1', port=server_port, + script_dir=remote_script_dir, + host_keys_filename='{}/known.hosts'.format(config_dir), + key_filename=priv_key) + try: + p = LocalProvider(channel=channel, + launcher=SingleNodeLauncher(debug=False)) + p.script_dir = script_dir + _run_tests(p) + finally: + channel.close() finally: _stop_sshd(sshd_thread) def _stop_sshd(sshd_thread): sshd_thread.stop() + sshd_thread.join() class SSHDThread(threading.Thread): diff --git a/parsl/tests/test_python_apps/test_context_manager.py b/parsl/tests/test_python_apps/test_context_manager.py index a314c0d362..6d3b020b16 100644 --- a/parsl/tests/test_python_apps/test_context_manager.py +++ b/parsl/tests/test_python_apps/test_context_manager.py @@ -1,7 +1,11 @@ +from concurrent.futures import Future +from threading import Event + import pytest import parsl -from parsl.dataflow.dflow import DataFlowKernel +from parsl.config import Config +from parsl.dataflow.dflow import DataFlowKernel, DataFlowKernelLoader from parsl.errors import NoDataFlowKernelError from parsl.tests.configs.local_threads import fresh_config @@ -16,6 +20,16 @@ def foo(x, stdout='foo.stdout'): return f"echo {x + 1}" +@parsl.python_app +def wait_for_event(ev: Event): + ev.wait() + + +@parsl.python_app +def raise_app(): + raise RuntimeError("raise_app deliberate failure") + + @pytest.mark.local def test_within_context_manger(tmpd_cwd): config = fresh_config() @@ -31,3 +45,84 @@ def test_within_context_manger(tmpd_cwd): with pytest.raises(NoDataFlowKernelError) as excinfo: square(2).result() assert str(excinfo.value) == "Must first load config" + + +@pytest.mark.local +def test_exit_skip(): + config = fresh_config() + config.exit_mode = "skip" + + with parsl.load(config) as dfk: + ev = Event() + fut = wait_for_event(ev) + # deliberately don't wait for this to finish, so that the context + # manager can exit + + assert parsl.dfk() is dfk, "global dfk should be left in place by skip mode" + + assert not fut.done(), "wait_for_event should not be done yet" + ev.set() + + # now we can wait for that result... + fut.result() + assert fut.done(), "wait_for_event should complete outside of context manager in 'skip' mode" + + # now cleanup the DFK that the above `with` block + # deliberately avoided doing... + dfk.cleanup() + + +# 'wait' mode has two cases to test: +# 1. that we wait when there is no exception +# 2. 
that we do not wait when there is an exception +@pytest.mark.local +def test_exit_wait_no_exception(): + config = fresh_config() + config.exit_mode = "wait" + + with parsl.load(config) as dfk: + fut = square(1) + # deliberately don't wait for this to finish, so that the context + # manager can exit + + assert fut.done(), "This future should be marked as done before the context manager exits" + + assert dfk.cleanup_called, "The DFK should have been cleaned up by the context manager" + assert DataFlowKernelLoader._dfk is None, "The global DFK should have been removed" + + +@pytest.mark.local +def test_exit_wait_exception(): + config = fresh_config() + config.exit_mode = "wait" + + with pytest.raises(RuntimeError): + with parsl.load(config) as dfk: + # we'll never fire this future + fut_never = Future() + + fut_raise = raise_app() + + fut_depend = square(fut_never) + + # this should cause an exception, which should cause the context + # manager to exit, without waiting for fut_depend to finish. + fut_raise.result() + + assert dfk.cleanup_called, "The DFK should have been cleaned up by the context manager" + assert DataFlowKernelLoader._dfk is None, "The global DFK should have been removed" + assert fut_raise.exception() is not None, "fut_raise should contain an exception" + assert not fut_depend.done(), "fut_depend should have been left un-done (due to dependency failure)" + + +@pytest.mark.local +def test_exit_wrong_mode(): + + with pytest.raises(Exception) as ex: + Config(exit_mode="wrongmode") + + # with typeguard 4.x this is TypeCheckError, + # with typeguard 2.x this is TypeError + # we can't instantiate TypeCheckError if we're in typeguard 2.x environment + # because it does not exist... so check name using strings. + assert ex.type.__name__ == "TypeCheckError" or ex.type.__name__ == "TypeError" diff --git a/parsl/tests/test_python_apps/test_dependencies_deep.py b/parsl/tests/test_python_apps/test_dependencies_deep.py new file mode 100644 index 0000000000..c728e1246e --- /dev/null +++ b/parsl/tests/test_python_apps/test_dependencies_deep.py @@ -0,0 +1,59 @@ +import inspect +from concurrent.futures import Future +from typing import Any, Callable, Dict + +import pytest + +import parsl +from parsl.executors.base import ParslExecutor + +# N is the number of tasks to chain +# With mid-2024 Parsl, N>140 causes Parsl to hang +N = 100 + +# MAX_STACK is the maximum Python stack depth allowed for either +# task submission to an executor or execution of a task. +# With mid-2024 Parsl, 2-3 stack entries will be used per +# recursively launched parsl task. So this should be smaller than +# 2*N, but big enough to allow regular pytest+parsl stuff to +# happen. 
+MAX_STACK = 50 + + +def local_config(): + return parsl.Config(executors=[ImmediateExecutor()]) + + +class ImmediateExecutor(ParslExecutor): + def start(self): + pass + + def shutdown(self): + pass + + def submit(self, func: Callable, resource_specification: Dict[str, Any], *args: Any, **kwargs: Any) -> Future: + stack_depth = len(inspect.stack()) + assert stack_depth < MAX_STACK, "tasks should not be launched deep in the Python stack" + fut: Future[None] = Future() + res = func(*args, **kwargs) + fut.set_result(res) + return fut + + +@parsl.python_app +def chain(upstream): + stack_depth = len(inspect.stack()) + assert stack_depth < MAX_STACK, "chained dependencies should not be launched deep in the Python stack" + + +@pytest.mark.local +def test_deep_dependency_stack_depth(): + + fut = Future() + here = fut + + for _ in range(N): + here = chain(here) + + fut.set_result(None) + here.result() diff --git a/parsl/tests/test_python_apps/test_inputs_default.py b/parsl/tests/test_python_apps/test_inputs_default.py new file mode 100644 index 0000000000..cf77c1a86b --- /dev/null +++ b/parsl/tests/test_python_apps/test_inputs_default.py @@ -0,0 +1,22 @@ +import pytest + +import parsl +from parsl import python_app +from parsl.executors.threads import ThreadPoolExecutor + + +def local_config(): + return parsl.Config(executors=[ThreadPoolExecutor()]) + + +@pytest.mark.local +def test_default_inputs(): + @python_app + def identity(inp): + return inp + + @python_app + def add_inputs(inputs=[identity(1), identity(2)]): + return sum(inputs) + + assert add_inputs().result() == 3 diff --git a/parsl/tests/test_serialization/test_3495_deserialize_managerlost.py b/parsl/tests/test_serialization/test_3495_deserialize_managerlost.py new file mode 100644 index 0000000000..74c0923108 --- /dev/null +++ b/parsl/tests/test_serialization/test_3495_deserialize_managerlost.py @@ -0,0 +1,47 @@ +import os +import signal + +import pytest + +import parsl +from parsl import Config, HighThroughputExecutor +from parsl.executors.high_throughput.errors import ManagerLost + + +@parsl.python_app +def get_manager_pgid(): + import os + return os.getpgid(os.getpid()) + + +@parsl.python_app +def lose_manager(): + import os + import signal + + manager_pid = os.getppid() + os.kill(manager_pid, signal.SIGSTOP) + + +@pytest.mark.local +def test_manager_lost_system_failure(tmpd_cwd): + hte = HighThroughputExecutor( + label="htex_local", + address="127.0.0.1", + max_workers_per_node=2, + cores_per_worker=1, + worker_logdir_root=str(tmpd_cwd), + heartbeat_period=1, + heartbeat_threshold=1, + ) + c = Config(executors=[hte], strategy='simple', strategy_period=0.1) + + with parsl.load(c): + manager_pgid = get_manager_pgid().result() + try: + with pytest.raises(ManagerLost): + lose_manager().result() + finally: + # Allow process to clean itself up + os.killpg(manager_pgid, signal.SIGCONT) + os.killpg(manager_pgid, signal.SIGTERM) diff --git a/parsl/tests/test_staging/test_file.py b/parsl/tests/test_staging/test_file.py index 4b57884a93..d7897da14a 100644 --- a/parsl/tests/test_staging/test_file.py +++ b/parsl/tests/test_staging/test_file.py @@ -22,11 +22,11 @@ def test_files(): @pytest.mark.local -def test_open(): - with open('test-open.txt', 'w') as tfile: - tfile.write('Hello') +def test_open(tmpd_cwd): + fpath = tmpd_cwd / 'test-open.txt' + fpath.write_text('Hello') - pfile = File('test-open.txt') + pfile = File(fpath) - with open(str(pfile), 'r') as opfile: - assert (opfile.readlines()[0] == 'Hello') + with open(pfile) as opfile: + assert 
(opfile.read() == 'Hello') diff --git a/requirements.txt b/requirements.txt index e89202942e..c60517655f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,6 @@ globus-sdk dill tblib requests -paramiko psutil>=5.5.1 setproctitle filelock>=3.13,<4 diff --git a/setup.py b/setup.py index dae3e64ca4..4934d01e5d 100755 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ 'flux': ['pyyaml', 'cffi', 'jsonschema'], 'proxystore': ['proxystore'], 'radical-pilot': ['radical.pilot==1.60', 'radical.utils==1.60'], + 'ssh': ['paramiko'], # Disabling psi-j since github direct links are not allowed by pypi # 'psij': ['psi-j-parsl@git+https://github.com/ExaWorks/psi-j-parsl'] } @@ -56,6 +57,7 @@ python_requires=">=3.8.0", install_requires=install_requires, scripts = ['parsl/executors/high_throughput/process_worker_pool.py', + 'parsl/executors/high_throughput/interchange.py', 'parsl/executors/workqueue/exec_parsl_function.py', 'parsl/executors/workqueue/parsl_coprocess.py', ], diff --git a/test-requirements.txt b/test-requirements.txt index c735de8d5c..acd670b5e9 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,12 +1,14 @@ flake8==6.1.0 ipyparallel pandas +paramiko pytest>=7.4.0,<8 pytest-cov pytest-random-order nbsphinx sphinx_rtd_theme mypy==1.5.1 +types-mock types-python-dateutil types-requests types-paramiko
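
Note on the packaging changes above: with paramiko dropped from requirements.txt and exposed instead through the new 'ssh' extra in setup.py, an installation that still needs the deprecated SSH channel has to request it explicitly (test-requirements.txt now lists paramiko directly, so the test suite is unaffected). A minimal sketch of the install commands this implies; the editable variant is an assumption about a typical development checkout, not something the patch itself requires:

    # install parsl plus the optional 'ssh' extra, which pulls in paramiko
    pip3 install ".[ssh]"

    # editable development install, including the test dependencies
    pip3 install -e ".[ssh]" -r test-requirements.txt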