diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index cfe3a5f739..069f8fdb71 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -10,7 +10,7 @@ jobs: main-test-suite: strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.11"] runs-on: ubuntu-20.04 timeout-minutes: 30 diff --git a/Makefile b/Makefile index 46308a50be..15da1d1f2c 100644 --- a/Makefile +++ b/Makefile @@ -46,7 +46,7 @@ clean_coverage: .PHONY: mypy mypy: ## run mypy checks - MYPYPATH=$(CWD)/mypy-stubs mypy parsl/ + PYTHONPATH=$(CWD)/mypy-plugins:$(PYTHONPATH) MYPYPATH=$(CWD)/mypy-stubs mypy --no-incremental parsl/ --show-traceback .PHONY: local_thread_test local_thread_test: ## run all tests with local_thread config diff --git a/mypy-plugins/parsl_mypy.py b/mypy-plugins/parsl_mypy.py new file mode 100644 index 0000000000..a88463132d --- /dev/null +++ b/mypy-plugins/parsl_mypy.py @@ -0,0 +1,56 @@ +from mypy.plugin import FunctionContext, Plugin +from mypy.types import Type +import mypy.nodes as nodes + +def plugin(v): + return ParslMypyPlugin + +class ParslMypyPlugin(Plugin): + def get_type_analyze_hook(self, t): + # print("BENC: gtah t={}".format(t)) + return None + + def get_function_hook(self, f): + if f == "parsl.app.app.python_appXXXX": + return python_app_function_hook + else: + return None + +def python_app_function_hook(ctx: FunctionContext) -> Type: + print("inside python_app function_hook") + print("ctx = {}".format(ctx)) + + # if python_app is being called with a function parameter (rather than + # None, the default) then the return type of the python_app decorator + # is a variation (with a Future added on the type of the decorated + # function...) + + if ctx.callee_arg_names[0] == "function": # will this always be at position 0? probably fragile to assume so, but this code does make that assumption + print(f"python_app called with a function supplied: {ctx.args[0]}") + function_node = ctx.args[0][0] + print(f"function node repr is {repr(function_node)} with type {type(function_node)}") + + # return the type of function_node - actually it needs modifying to have the Future wrapper added.... + if isinstance(function_node, nodes.TempNode): + print(f"temporary node has type {function_node.type}") + print(f"Python type of tempnode.type is {type(function_node.type)}") + print(ctx.api) + # return_type = ctx.api.named_type_or_none("concurrent.futures.Future", [function_node.type.ret_type]) + # return_type = ctx.api.named_generic_type("concurrent.futures.Future", [function_node.type.ret_type]) + # return_type = ctx.api.named_generic_type("builtins.list", [function_node.type.ret_type]) + return_type = function_node.type.ret_type + # return_type = ctx.default_return_type + print(f"plugin chosen return type is {return_type}") + return function_node.type.copy_modified(ret_type=return_type) + else: + print("function node is not specified as something this plugin understands") + return_type = ctx.default_return_type + return return_type + else: + print("python_app called without a function supplied") + # TODO: this should return a type that is aligned with the implementation: + # it's the type of the decorator, assuming that it will definitely be given + # a function this time? or something along those lines... + + print("will return ctx.default_return_type = {}".format(ctx.default_return_type)) + return ctx.default_return_type diff --git a/mypy.ini b/mypy.ini index d412ee55f2..4c8c5c71e7 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,11 +1,11 @@ [mypy] -plugins = sqlmypy +plugins = sqlmypy, parsl_mypy # globally disabled error codes: # str-bytes-safe warns that a byte string is formatted into a string. # which is commonly done with manager IDs in the parsl # codebase. -disable_error_code = str-bytes-safe +# disable_error_code = str-bytes-safe enable_error_code = ignore-without-code no_implicit_reexport = True warn_redundant_casts = True @@ -15,6 +15,9 @@ no_implicit_optional = True strict_equality = True warn_unused_ignores = True +# there are some exceptions to this even in the more-strongly-typed sections +warn_unreachable = True + [mypy-non_existent.*] ignore_missing_imports = True @@ -61,18 +64,100 @@ disallow_any_expr = True disallow_any_decorated = True [mypy-parsl.providers.base.*] +disallow_untyped_defs = True disallow_untyped_decorators = True check_untyped_defs = True disallow_subclassing_any = True -warn_unreachable = True + + +# modules to be checked mostly more strongly than default: + +[mypy-parsl.dataflow.dflow.*] +disallow_untyped_defs = True + +warn_unreachable = False +# this is because of some tangle of types +# to do with channel script dirs: script dir +# can be none at the start before initialisation, +# but must be set later on, and I can't +# represent that session type in mypy. +# (or can I?!) + +[mypy-parsl.dataflow.flow_control.*] +disallow_untyped_defs = True +disallow_any_decorated = True + +[mypy-parsl.dataflow.strategy.*] +disallow_untyped_defs = True +disallow_any_decorated = True + +[mypy-parsl.dataflow.job_status_poller.*] +disallow_untyped_defs = True + +# merge of #1877 introduced stuff that violates this so disabling pending perhaps further investigation +# disallow_any_expr = True + +disallow_any_decorated = True + +[mypy-parsl.config.*] +disallow_untyped_defs = True +# Any has to be allowed because TaskRecord now forms part of the type signature of config, +# and task record has Any from the type of tasks args +#disallow_any_expr = True +#disallow_any_decorated = True + +[mypy-parsl.channels.base.*] +disallow_untyped_defs = True +disallow_any_expr = True + +[mypy-parsl.channels.ssh.*] +disallow_untyped_defs = True + +[mypy-parsl.launchers.*] +disallow_untyped_defs = True +disallow_any_decorated = True + + + +[mypy-parsl.executors.base.*] +disallow_untyped_defs = True +disallow_any_expr = True +[mypy-parsl.serialize.*] disallow_untyped_defs = True -[mypy-parsl.executors.high_throughput.interchange.*] -check_untyped_defs = True + +# modules to be checked more weakly than default: + +[mypy-parsl.executors.flux.*] +ignore_errors = True [mypy-parsl.executors.extreme_scale.*] ignore_errors = True +[mypy-parsl.providers.aws.*] +check_untyped_defs = False + +[mypy-parsl.providers.pbspro.pbspro.*] +check_untyped_defs = False + +[mypy-parsl.providers.lsf.lsf.*] +check_untyped_defs = False + +[mypy-parsl.providers.torque.torque.*] +check_untyped_defs = False + +[mypy-parsl.providers.grid_engine.grid_engine.*] +check_untyped_defs = False + +[mypy-parsl.providers.googlecloud.*] +check_untyped_defs = False + +[mypy-parsl.monitoring.db_manager.*] +check_untyped_defs = False + +[mypy-parsl.executors.high_throughput.interchange.*] +check_untyped_defs = True + [mypy-parsl.executors.workqueue.*] check_untyped_defs = True @@ -86,7 +171,6 @@ ignore_errors = True disallow_untyped_decorators = True check_untyped_defs = True disallow_subclassing_any = True -warn_unreachable = True disallow_untyped_defs = True # visualization typechecks much less well than the rest of monitoring, @@ -98,9 +182,10 @@ ignore_errors = True ignore_missing_imports = True [mypy-parsl.utils] -warn_unreachable = True disallow_untyped_defs = True +# imports from elsewhere that there are no stubs for: + [mypy-flask_sqlalchemy.*] ignore_missing_imports = True @@ -164,22 +249,6 @@ ignore_missing_imports = True [mypy-zmq.*] ignore_missing_imports = True -[mypy-mpi4py.*] -ignore_missing_imports = True - -[mypy-flask.*] -ignore_missing_imports = True - -# this is an internal undocumentated package -# of multiprocessing - trying to get Event -# to typecheck in monitoring, but it's not -# a top level class as far as mypy is concerned. -# but... when commented out seems ok? -# so lets see when happens when I try to merge -# in clean CI -#[mypy-multiprocessing.synchronization.*] -#ignore_missing_imports = True - [mypy-pandas.*] ignore_missing_imports = True diff --git a/parsl/app/app.py b/parsl/app/app.py index dd7282cd5f..5d65241426 100644 --- a/parsl/app/app.py +++ b/parsl/app/app.py @@ -11,6 +11,17 @@ from parsl.dataflow.dflow import DataFlowKernel +from typing import TYPE_CHECKING +from typing import Callable + +if TYPE_CHECKING: + from typing import Dict + from typing import Any + +from parsl.dataflow.futures import AppFuture +from concurrent.futures import Future + + logger = logging.getLogger(__name__) @@ -22,7 +33,12 @@ class AppBase(metaclass=ABCMeta): """ - def __init__(self, func, data_flow_kernel=None, executors='all', cache=False, ignore_for_cache=None): + @typeguard.typechecked + def __init__(self, func: Callable, + data_flow_kernel: Optional[DataFlowKernel] = None, + executors: Union[List[str], Literal['all']] = 'all', + cache: bool = False, + ignore_for_cache=None) -> None: """Construct the App object. Args: @@ -45,13 +61,15 @@ def __init__(self, func, data_flow_kernel=None, executors='all', cache=False, ig self.executors = executors self.cache = cache self.ignore_for_cache = ignore_for_cache - if not (isinstance(executors, list) or isinstance(executors, str)): - logger.error("App {} specifies invalid executor option, expects string or list".format( - func.__name__)) + + # unreachable if properly typechecked + # if not (isinstance(executors, list) or isinstance(executors, str)): + # logger.error("App {} specifies invalid executor option, expects string or list".format( + # func.__name__)) params = signature(func).parameters - self.kwargs = {} + self.kwargs = {} # type: Dict[str, Any] if 'stdout' in params: self.kwargs['stdout'] = params['stdout'].default if 'stderr' in params: @@ -64,7 +82,7 @@ def __init__(self, func, data_flow_kernel=None, executors='all', cache=False, ig self.inputs = params['inputs'].default if 'inputs' in params else [] @abstractmethod - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> AppFuture: pass @@ -74,7 +92,8 @@ def python_app(function=None, cache: bool = False, executors: Union[List[str], Literal['all']] = 'all', ignore_for_cache: Optional[List[str]] = None, - join: bool = False): + join: bool = False, + do_not_use_F: Optional[Future] = None) -> Callable: """Decorator function for making python apps. Parameters @@ -113,7 +132,7 @@ def join_app(function=None, data_flow_kernel: Optional[DataFlowKernel] = None, cache: bool = False, executors: Union[List[str], Literal['all']] = 'all', - ignore_for_cache: Optional[List[str]] = None): + ignore_for_cache: Optional[List[str]] = None) -> Callable: """Decorator function for making join apps Parameters @@ -150,7 +169,7 @@ def bash_app(function=None, data_flow_kernel: Optional[DataFlowKernel] = None, cache: bool = False, executors: Union[List[str], Literal['all']] = 'all', - ignore_for_cache: Optional[List[str]] = None): + ignore_for_cache: Optional[List[str]] = None) -> Callable: """Decorator function for making bash apps. Parameters diff --git a/parsl/app/bash.py b/parsl/app/bash.py index 5e0ca4237e..e8690a0cc2 100644 --- a/parsl/app/bash.py +++ b/parsl/app/bash.py @@ -3,20 +3,30 @@ from inspect import signature, Parameter import logging +# for typing +from typing import Any, Callable, Dict, List, Optional, Union + +from typing_extensions import Literal + +from parsl.dataflow.futures import AppFuture from parsl.app.errors import wrap_error from parsl.app.app import AppBase from parsl.dataflow.dflow import DataFlowKernelLoader +from parsl.dataflow.dflow import DataFlowKernel # only for mypy + logger = logging.getLogger(__name__) -def remote_side_bash_executor(func, *args, **kwargs): +def remote_side_bash_executor(func: Callable[..., str], *args, **kwargs) -> int: """Executes the supplied function with *args and **kwargs to get a command-line to run, and then run that command-line using bash. """ import os import subprocess + from typing import List, cast import parsl.app.errors as pe + from parsl.data_provider.files import File from parsl.utils import get_std_fname_mode if hasattr(func, '__name__'): @@ -88,7 +98,8 @@ def open_std_fd(fdname): # TODO : Add support for globs here missing = [] - for outputfile in kwargs.get('outputs', []): + outputs = cast(List[File], kwargs.get('outputs', [])) + for outputfile in outputs: fpath = outputfile.filepath if not os.path.exists(fpath): @@ -102,7 +113,10 @@ def open_std_fd(fdname): class BashApp(AppBase): - def __init__(self, func, data_flow_kernel=None, cache=False, executors='all', ignore_for_cache=None): + def __init__(self, func: Callable[..., str], data_flow_kernel: Optional[DataFlowKernel] = None, + cache: bool = False, + executors: Union[List[str], Literal['all']] = 'all', + ignore_for_cache: Optional[List[str]] = None) -> None: super().__init__(func, data_flow_kernel=data_flow_kernel, executors=executors, cache=cache, ignore_for_cache=ignore_for_cache) self.kwargs = {} @@ -120,10 +134,16 @@ def __init__(self, func, data_flow_kernel=None, cache=False, executors='all', ig # this is done to avoid passing a function type in the args which parsl.serializer # doesn't support remote_fn = partial(update_wrapper(remote_side_bash_executor, self.func), self.func) - remote_fn.__name__ = self.func.__name__ + + # parsl/app/bash.py:145: error: "partial[Any]" has no attribute "__name__" + # but... other parts of the code are relying on getting the __name__ + # of (?) an arbitrary Callable too (which is why we're setting the __name__ + # at all) + remote_fn.__name__ = self.func.__name__ # type: ignore[attr-defined] + self.wrapped_remote_function = wrap_error(remote_fn) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> AppFuture: """Handle the call to a Bash app. Args: @@ -136,7 +156,7 @@ def __call__(self, *args, **kwargs): App_fut """ - invocation_kwargs = {} + invocation_kwargs = {} # type: Dict[str, Any] invocation_kwargs.update(self.kwargs) invocation_kwargs.update(kwargs) diff --git a/parsl/app/errors.py b/parsl/app/errors.py index 6e3562ab23..d971488168 100644 --- a/parsl/app/errors.py +++ b/parsl/app/errors.py @@ -1,6 +1,6 @@ """Exceptions raised by Apps.""" from functools import wraps -from typing import Callable, List, Union, Any, TypeVar, Optional +from typing import List, Any, Optional from types import TracebackType import logging from tblib import Traceback @@ -10,6 +10,14 @@ from parsl.data_provider.files import File from parsl.errors import ParslError + +# vs PR 1846: benc-mypy imports File from the data_provider module +# more directly, potentially to avoid import loops from trying to +# import the top level "parsl" here. +# from parsl import File + +# vs PR 1846: see TypeVar playing in response to PR 1846 TODO + logger = logging.getLogger(__name__) @@ -66,7 +74,10 @@ class MissingOutputs(ParslError): outputs(List of strings/files..) """ - def __init__(self, reason: str, outputs: List[Union[str, File]]) -> None: + # vs PR 1846: I use List[File] for outputs; this PR uses a union of str or File + # That might be because I've done other tidyup work regarding strings and files? + + def __init__(self, reason: str, outputs: List[File]) -> None: super().__init__(reason, outputs) self.reason = reason self.outputs = outputs @@ -132,26 +143,17 @@ def get_exception(self) -> BaseException: return v -R = TypeVar('R') +# vs PR 1846: PR 1846 makes wrap_error go from any callable to any callable +# and typechecks without casts. -# There appears to be no solution to typing this without a mypy plugin. -# The reason is because wrap_error maps a Callable[[X...], R] to a Callable[[X...], Union[R, R2]]. -# However, there is no provision in Python typing for pattern matching all possible types of -# callable arguments. This is because Callable[] is, in the infinite wisdom of the typing module, -# only used for callbacks: "There is no syntax to indicate optional or keyword arguments; such -# function types are rarely used as callback types.". -# The alternative supported by the typing module, of saying Callable[..., R] -> -# Callable[..., Union[R, R2]] results in no pattern matching between the first and second -# ellipsis. -# Yet another bogus solution that was here previously would simply define wrap_error as -# wrap_error(T) -> T, where T was a custom TypeVar. This obviously missed the fact that -# the returned function had its return signature modified. -# Ultimately, the best choice appears to be Callable[..., R] -> Callable[..., Union[R, ?Exception]], -# since it results in the correct type specification for the return value(s) while treating the -# arguments as Any. +# see https://github.com/dry-python/returns/blob/92eda5574a8e41f4f5af4dd29887337886301ee3/returns/contrib/mypy/decorator_plugin.py +# for a mypy plugin to do this in a hacky way +# and this issue for more info on typing decorators: +# https://github.com/python/mypy/issues/3157 -def wrap_error(func: Callable[..., R]) -> Callable[..., Union[R, RemoteExceptionWrapper]]: +# def wrap_error(func: Callable[..., R]) -> Callable[..., Union[R, RemoteExceptionWrapper]]: +def wrap_error(func): @wraps(func) def wrapper(*args: object, **kwargs: object) -> Any: import sys diff --git a/parsl/app/python.py b/parsl/app/python.py index 26eb8167bf..df682efc92 100644 --- a/parsl/app/python.py +++ b/parsl/app/python.py @@ -8,6 +8,7 @@ from parsl.app.app import AppBase from parsl.app.errors import wrap_error from parsl.dataflow.dflow import DataFlowKernelLoader +from parsl.dataflow.futures import AppFuture logger = logging.getLogger(__name__) @@ -38,7 +39,7 @@ def inject_exception(thread): class PythonApp(AppBase): """Extends AppBase to cover the Python App.""" - def __init__(self, func, data_flow_kernel=None, cache=False, executors='all', ignore_for_cache=[], join=False): + def __init__(self, func, data_flow_kernel=None, cache=False, executors='all', ignore_for_cache=[], join: bool = False) -> None: super().__init__( wrap_error(func), data_flow_kernel=data_flow_kernel, @@ -48,7 +49,7 @@ def __init__(self, func, data_flow_kernel=None, cache=False, executors='all', ig ) self.join = join - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> AppFuture: """This is where the call to a python app is handled. Args: diff --git a/parsl/channels/errors.py b/parsl/channels/errors.py index 5bf5a0b280..2b004ecb5b 100644 --- a/parsl/channels/errors.py +++ b/parsl/channels/errors.py @@ -3,15 +3,27 @@ from parsl.errors import ParslError from typing import Optional +# vs PR 1846 +# there are differences in style between calling super.__init__ with +# parameters, or at all +# I think these are only stylistic + +# there are some semantic differences - eg removal of exceptions from +# channel error base and adding into only the subclasses which have +# exceptions (can't remember what motivated this specifically) + class ChannelError(ParslError): """ Base class for all exceptions Only to be invoked when only a more specific error is not available. + + vs PR 1846: differs in calling of super.__init__ and I've removed + the Exception parameter """ - def __init__(self, reason: str, e: Exception, hostname: str) -> None: + def __init__(self, reason: str, hostname: str) -> None: + super().__init__() self.reason = reason - self.e = e self.hostname = hostname def __repr__(self) -> str: @@ -31,9 +43,11 @@ class BadHostKeyException(ChannelError): hostname (string) ''' - def __init__(self, e: Exception, hostname: str) -> None: - super().__init__("SSH channel could not be created since server's host keys could not be " - "verified", e, hostname) + # vs PR 1846: removal of 'e' parameter in superclass + # and stored in this class init instead + def __init__(self, e, hostname): + super().__init__("SSH channel could not be created since server's host keys could not be verified", hostname) + self.e = e class BadScriptPath(ChannelError): @@ -45,8 +59,10 @@ class BadScriptPath(ChannelError): hostname (string) ''' + # vs PR 1846, as with BadHostKeyException: remove e from superclass, store in this class def __init__(self, e: Exception, hostname: str) -> None: - super().__init__("Inaccessible remote script dir. Specify script_dir", e, hostname) + super().__init__("Inaccessible remote script dir. Specify script_dir", hostname) + self.e = e class BadPermsScriptPath(ChannelError): @@ -58,8 +74,10 @@ class BadPermsScriptPath(ChannelError): hostname (string) ''' + "vs PR 1846, store exception locally" def __init__(self, e: Exception, hostname: str) -> None: - super().__init__("User does not have permissions to access the script_dir", e, hostname) + super().__init__("User does not have permissions to access the script_dir", hostname) + self.e = e class FileExists(ChannelError): @@ -72,9 +90,12 @@ class FileExists(ChannelError): hostname (string) ''' + # vs PR 1846, PR 1846 uses .format instead of +filename. I previously kept the original behaviour + # but am adopting the .format style here def __init__(self, e: Exception, hostname: str, filename: Optional[str] = None) -> None: super().__init__("File name collision in channel transport phase: {}".format(filename), - e, hostname) + hostname) + self.e = e class AuthException(ChannelError): @@ -87,7 +108,8 @@ class AuthException(ChannelError): ''' def __init__(self, e: Exception, hostname: str) -> None: - super().__init__("Authentication to remote server failed", e, hostname) + super().__init__("Authentication to remote server failed", hostname) + self.e = e class SSHException(ChannelError): @@ -100,7 +122,8 @@ class SSHException(ChannelError): ''' def __init__(self, e: Exception, hostname: str) -> None: - super().__init__("Error connecting or establishing an SSH session", e, hostname) + super().__init__("Error connecting or establishing an SSH session", hostname) + self.e = e class FileCopyException(ChannelError): @@ -113,4 +136,5 @@ class FileCopyException(ChannelError): ''' def __init__(self, e: Exception, hostname: str) -> None: - super().__init__("File copy failed due to {0}".format(e), e, hostname) + super().__init__("File copy failed due to {0}".format(e), hostname) + self.e = e diff --git a/parsl/channels/local/local.py b/parsl/channels/local/local.py index ee2a7a5088..e912da6c12 100644 --- a/parsl/channels/local/local.py +++ b/parsl/channels/local/local.py @@ -10,6 +10,8 @@ logger = logging.getLogger(__name__) +from typing import Dict, Tuple, Optional + class LocalChannel(Channel, RepresentationMixin): ''' This is not even really a channel, since opening a local shell is not heavy @@ -32,7 +34,7 @@ def __init__(self, userhome=".", envs={}, script_dir=None): self._envs.update(envs) self.script_dir = script_dir - def execute_wait(self, cmd, walltime=None, envs={}): + def execute_wait(self, cmd: str, walltime: Optional[int] = None, envs: Dict[str, str] = {}) -> Tuple[int, Optional[str], Optional[str]]: ''' Synchronously execute a commandline string on the shell. Args: @@ -77,7 +79,7 @@ def execute_wait(self, cmd, walltime=None, envs={}): return (retcode, stdout.decode("utf-8"), stderr.decode("utf-8")) - def push_file(self, source, dest_dir): + def push_file(self, source: str, dest_dir: str) -> str: ''' If the source files dirpath is the same as dest_dir, a copy is not necessary, and nothing is done. Else a copy is made. diff --git a/parsl/channels/ssh/ssh.py b/parsl/channels/ssh/ssh.py index 02c898ad68..1b292a460c 100644 --- a/parsl/channels/ssh/ssh.py +++ b/parsl/channels/ssh/ssh.py @@ -1,6 +1,7 @@ import errno import logging import os +import typeguard import paramiko from parsl.channels.base import Channel @@ -9,11 +10,18 @@ logger = logging.getLogger(__name__) +from typing import Any, Dict, List, Tuple, Optional + class NoAuthSSHClient(paramiko.SSHClient): - def _auth(self, username, *args): - self._transport.auth_none(username) - return + def _auth(self, username: str, *args: List[Any]) -> None: + # swapped _internal variable for get_transport accessor + # method that I'm assuming without checking does the + # same thing. + transport = self.get_transport() + if transport is None: + raise RuntimeError("Expected a transport to be available") + transport.auth_none(username) class SSHChannel(Channel, RepresentationMixin): @@ -26,8 +34,18 @@ class SSHChannel(Channel, RepresentationMixin): ''' - def __init__(self, hostname, username=None, password=None, script_dir=None, envs=None, - gssapi_auth=False, skip_auth=False, port=22, key_filename=None, host_keys_filename=None): + @typeguard.typechecked + def __init__(self, + hostname: str, + username: Optional[str] = None, + password: Optional[str] = None, + script_dir: Optional[str] = None, + envs: Optional[Dict[str, str]] = None, + gssapi_auth: bool = False, + skip_auth: bool = False, + port: int = 22, + key_filename: Optional[str] = None, + host_keys_filename: Optional[str] = None): ''' Initialize a persistent connection to the remote system. We should know at this point whether ssh connectivity is possible @@ -50,29 +68,39 @@ def __init__(self, hostname, username=None, password=None, script_dir=None, envs self.username = username self.password = password self.port = port - self.script_dir = script_dir + + # if script_dir is a `str`, which it is from Channel, then can't + # assign None to it. Here and the property accessors are changed + # in benc-mypy to raise an error ratehr than return a None, + # because Channel-using code assumes that script_dir will always + # return a string and not a None. That assumption is not otherwise + # guaranteed by the type-system... + self._script_dir = None + if script_dir: + self.script_dir = script_dir + self.skip_auth = skip_auth self.gssapi_auth = gssapi_auth self.key_filename = key_filename self.host_keys_filename = host_keys_filename if self.skip_auth: - self.ssh_client = NoAuthSSHClient() + self.ssh_client: paramiko.SSHClient = NoAuthSSHClient() else: self.ssh_client = paramiko.SSHClient() self.ssh_client.load_system_host_keys(filename=host_keys_filename) self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - self.sftp_client = None + self.sftp_client: Optional[paramiko.SFTPClient] = None - self.envs = {} + self.envs = {} # type: Dict[str, str] if envs is not None: self.envs = envs - def _is_connected(self): + def _is_connected(self) -> bool: transport = self.ssh_client.get_transport() if self.ssh_client else None - return transport and transport.is_active() + return bool(transport and transport.is_active()) - def _connect(self): + def _connect(self) -> None: if not self._is_connected(): logger.debug(f"connecting to {self.hostname}:{self.port}") try: @@ -87,6 +115,8 @@ def _connect(self): key_filename=self.key_filename ) transport = self.ssh_client.get_transport() + if not transport: + raise RuntimeError("SSH client transport is None, despite connecting") self.sftp_client = paramiko.SFTPClient.from_transport(transport) except paramiko.BadHostKeyException as e: @@ -101,15 +131,17 @@ def _connect(self): except Exception as e: raise SSHException(e, self.hostname) - def _valid_sftp_client(self): + def _valid_sftp_client(self) -> paramiko.SFTPClient: self._connect() + if self.sftp_client is None: + raise RuntimeError("Internal consistency error: self.sftp_client should be valid but is not") return self.sftp_client - def _valid_ssh_client(self): + def _valid_ssh_client(self) -> paramiko.SSHClient: self._connect() return self.ssh_client - def prepend_envs(self, cmd, env={}): + def prepend_envs(self, cmd: str, env: Dict[str, str] = {}) -> str: env.update(self.envs) if len(env.keys()) > 0: @@ -117,7 +149,7 @@ def prepend_envs(self, cmd, env={}): return 'env {0} {1}'.format(env_vars, cmd) return cmd - def execute_wait(self, cmd, walltime=2, envs={}): + def execute_wait(self, cmd: str, walltime: int = 2, envs: Dict[str, str] = {}) -> Tuple[int, Optional[str], Optional[str]]: ''' Synchronously execute a commandline string on the shell. Args: @@ -144,7 +176,7 @@ def execute_wait(self, cmd, walltime=2, envs={}): exit_status = stdout.channel.recv_exit_status() return exit_status, stdout.read().decode("utf-8"), stderr.read().decode("utf-8") - def push_file(self, local_source, remote_dir): + def push_file(self, local_source: str, remote_dir: str) -> str: ''' Transport a local file to a directory on a remote machine Args: @@ -184,7 +216,7 @@ def push_file(self, local_source, remote_dir): return remote_dest - def pull_file(self, remote_source, local_dir): + def pull_file(self, remote_source: str, local_dir: str) -> str: ''' Transport file on the remote side to a local directory Args: @@ -217,11 +249,12 @@ def pull_file(self, remote_source, local_dir): return local_dest - def close(self): + def close(self) -> bool: if self._is_connected(): - return self.ssh_client.close() + self.ssh_client.close() + return True - def isdir(self, path): + def isdir(self, path: str) -> bool: """Return true if the path refers to an existing directory. Parameters @@ -237,7 +270,7 @@ def isdir(self, path): return result - def makedirs(self, path, mode=0o700, exist_ok=False): + def makedirs(self, path: str, mode: int = 0o700, exist_ok: bool = False) -> None: """Create a directory on the remote side. If intermediate directories do not exist, they will be created. @@ -257,7 +290,7 @@ def makedirs(self, path, mode=0o700, exist_ok=False): self.execute_wait('mkdir -p {}'.format(path)) self._valid_sftp_client().chmod(path, mode) - def abspath(self, path): + def abspath(self, path: str) -> str: """Return the absolute path on the remote side. Parameters @@ -268,9 +301,12 @@ def abspath(self, path): return self._valid_sftp_client().normalize(path) @property - def script_dir(self): - return self._script_dir + def script_dir(self) -> str: + if self._script_dir: + return self._script_dir + else: + raise RuntimeError("scriptdir was not set") @script_dir.setter - def script_dir(self, value): + def script_dir(self, value: Optional[str]) -> None: self._script_dir = value diff --git a/parsl/channels/ssh_il/ssh_il.py b/parsl/channels/ssh_il/ssh_il.py index 0905334329..4bac6712c2 100644 --- a/parsl/channels/ssh_il/ssh_il.py +++ b/parsl/channels/ssh_il/ssh_il.py @@ -67,6 +67,8 @@ def __init__(self, hostname, username=None, password=None, script_dir=None, envs ''' transport = self.ssh_client.get_transport() + if transport is None: + raise RuntimeError("Expected transport to be available") il_password = getpass.getpass('Enter {0} Logon password :'.format(hostname)) transport.auth_password(username, il_password) diff --git a/parsl/config.py b/parsl/config.py index 0f56c5b9cb..4433a45476 100644 --- a/parsl/config.py +++ b/parsl/config.py @@ -92,9 +92,12 @@ def __init__(self, monitoring: Optional[MonitoringHub] = None, usage_tracking: bool = False, initialize_logging: bool = True) -> None: + self._executors: Sequence[ParslExecutor] if executors is None: - executors = [ThreadPoolExecutor()] - self.executors = executors + self._executors = [ThreadPoolExecutor()] + else: + self._validate_executors(executors) + self._executors = executors self.app_cache = app_cache self.checkpoint_files = checkpoint_files self.checkpoint_mode = checkpoint_mode @@ -125,11 +128,9 @@ def __init__(self, def executors(self) -> Sequence[ParslExecutor]: return self._executors - @executors.setter - def executors(self, executors: Sequence[ParslExecutor]): + def _validate_executors(self, executors: Sequence[ParslExecutor]) -> None: labels = [e.label for e in executors] duplicates = [e for n, e in enumerate(labels) if e in labels[:n]] if len(duplicates) > 0: raise ConfigurationError('Executors must have unique labels ({})'.format( ', '.join(['label={}'.format(repr(d)) for d in duplicates]))) - self._executors = executors diff --git a/parsl/data_provider/data_manager.py b/parsl/data_provider/data_manager.py index e9d4cfbadb..b75f49a984 100644 --- a/parsl/data_provider/data_manager.py +++ b/parsl/data_provider/data_manager.py @@ -1,6 +1,6 @@ import logging from concurrent.futures import Future -from typing import Any, Callable, List, Optional, TYPE_CHECKING +from typing import Any, Callable, List, Optional, Sequence, TYPE_CHECKING from parsl.app.futures import DataFuture from parsl.data_provider.files import File @@ -38,7 +38,7 @@ def replace_task_stage_out(self, file: File, func: Callable, executor: str) -> C """This will give staging providers the chance to wrap (or replace entirely!) the task function.""" executor_obj = self.dfk.executors[executor] if hasattr(executor_obj, "storage_access") and executor_obj.storage_access is not None: - storage_access = executor_obj.storage_access # type: List[Staging] + storage_access = executor_obj.storage_access # type: Sequence[Staging] else: storage_access = default_staging diff --git a/parsl/data_provider/globus.py b/parsl/data_provider/globus.py index 69be6925b5..4452e9e04b 100644 --- a/parsl/data_provider/globus.py +++ b/parsl/data_provider/globus.py @@ -6,7 +6,7 @@ import typeguard from functools import partial -from typing import Optional +from typing import Any, Optional from parsl.app.app import python_app from parsl.utils import RepresentationMixin from parsl.data_provider.staging import Staging @@ -59,7 +59,9 @@ class Globus: - monitoring transfers. """ - authorizer = None + authorizer = None # type: Any # otherwise inferred as type: None + + TOKEN_FILE: str @classmethod def init(cls): diff --git a/parsl/dataflow/dflow.py b/parsl/dataflow/dflow.py index 8af8a8ded5..3e0bc0dc86 100644 --- a/parsl/dataflow/dflow.py +++ b/parsl/dataflow/dflow.py @@ -13,12 +13,15 @@ import datetime from getpass import getuser from typeguard import typechecked -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import cast, Any, Callable, Dict, Iterable, Optional, Union, List, Sequence, Tuple from uuid import uuid4 from socket import gethostname from concurrent.futures import Future from functools import partial +# mostly for type checking +from parsl.executors.base import ParslExecutor, FutureWithTaskID + import parsl from parsl.app.errors import RemoteExceptionWrapper from parsl.app.futures import DataFuture @@ -35,12 +38,12 @@ from parsl.dataflow.taskrecord import TaskRecord from parsl.errors import ConfigurationError from parsl.usage_tracking.usage import UsageTracker -from parsl.executors.base import ParslExecutor from parsl.executors.status_handling import BlockProviderExecutor from parsl.executors.threads import ThreadPoolExecutor from parsl.monitoring import MonitoringHub from parsl.process_loggers import wrap_with_logs from parsl.providers.base import ExecutionProvider, JobStatus, JobState +from parsl.providers.base import Channeled, MultiChanneled from parsl.utils import get_version, get_std_fname_mode, get_all_checkpoints, Timer from parsl.monitoring.message_type import MessageType @@ -208,14 +211,29 @@ def _send_task_log_info(self, task_record: TaskRecord) -> None: task_log_info = self._create_task_log_info(task_record) self.monitoring.send(MessageType.TASK_INFO, task_log_info) - def _create_task_log_info(self, task_record): + def _create_task_log_info(self, task_record: TaskRecord) -> Dict[str, Any]: """ Create the dictionary that will be included in the log. """ - info_to_monitor = ['func_name', 'memoize', 'hashsum', 'fail_count', 'fail_cost', 'status', - 'id', 'time_invoked', 'try_time_launched', 'time_returned', 'try_time_returned', 'executor'] - task_log_info = {"task_" + k: task_record[k] for k in info_to_monitor} + # because self.tasks[task_id] is now a TaskRecord not a Dict[str,...], type checking + # can't do enough type checking if just iterating over this list of keys to copy + # and the assignments need to be written out explicitly. + + task_log_info = {} # type: Dict[str, Any] + + task_log_info["task_func_name"] = task_record['func_name'] + task_log_info["task_memoize"] = task_record['memoize'] + task_log_info["task_hashsum"] = task_record['hashsum'] + task_log_info["task_fail_count"] = task_record['fail_count'] + task_log_info["task_fail_cost"] = task_record['fail_cost'] + task_log_info["task_status"] = task_record['status'] + task_log_info["task_id"] = task_record['id'] + task_log_info["task_time_invoked"] = task_record['time_invoked'] + task_log_info["task_try_time_launched"] = task_record['try_time_launched'] + task_log_info["task_time_returned"] = task_record['time_returned'] + task_log_info["task_try_time_returned"] = task_record['try_time_returned'] + task_log_info["task_executor"] = task_record['executor'] task_log_info['run_id'] = self.run_id task_log_info['try_id'] = task_record['try_id'] task_log_info['timestamp'] = datetime.datetime.now() @@ -243,9 +261,8 @@ def _create_task_log_info(self, task_record): task_log_info['task_stderr'] = stderr_name task_log_info['task_fail_history'] = ",".join(task_record['fail_history']) task_log_info['task_depends'] = None - if task_record['depends'] is not None: - task_log_info['task_depends'] = ",".join([str(t.tid) for t in task_record['depends'] - if isinstance(t, AppFuture) or isinstance(t, DataFuture)]) + task_log_info['task_depends'] = ",".join([str(t.tid) for t in task_record['depends'] if isinstance(t, AppFuture) or isinstance(t, DataFuture)]) + task_log_info['task_joins'] = None if isinstance(task_record['joins'], list): @@ -262,9 +279,8 @@ def _count_deps(self, depends: Sequence[Future]) -> int: """ count = 0 for dep in depends: - if isinstance(dep, Future): - if not dep.done(): - count += 1 + if not dep.done(): + count += 1 return count @@ -555,7 +571,7 @@ def update_task_state(self, task_record: TaskRecord, new_state: States) -> None: """ with self.task_state_counts_lock: - if 'status' in task_record: + if hasattr(task_record, 'status'): self.task_state_counts[task_record['status']] -= 1 self.task_state_counts[new_state] += 1 task_record['status'] = new_state @@ -716,7 +732,7 @@ def launch_task(self, task_record: TaskRecord) -> Future: self._send_task_log_info(task_record) - if hasattr(exec_fu, "parsl_executor_task_id"): + if isinstance(exec_fu, FutureWithTaskID): logger.info(f"Parsl task {task_id} try {try_id} launched on executor {executor.label} with executor id {exec_fu.parsl_executor_task_id}") else: logger.info(f"Parsl task {task_id} try {try_id} launched on executor {executor.label}") @@ -734,6 +750,9 @@ def _add_input_deps(self, executor: str, args: Sequence[Any], kwargs: Dict[str, - executor (str) : executor where the app is going to be launched - args (List) : Positional args to app function - kwargs (Dict) : Kwargs to app function + - func : the function that will be invoked + + Returns: args, kwargs, (replacement, wrapping) function """ # Return if the task is a data management task, rather than doing @@ -820,7 +839,9 @@ def check_dep(d: Any) -> None: return depends - def _unwrap_futures(self, args, kwargs): + def _unwrap_futures(self, + args: Sequence[Any], + kwargs: Dict[str, Any]) -> Tuple[Sequence[Any], Dict[str, Any], Sequence[Tuple[Exception, str]]]: """This function should be called when all dependencies have completed. It will rewrite the arguments for that task, replacing each Future @@ -838,6 +859,10 @@ def _unwrap_futures(self, args, kwargs): a rewritten kwargs dict pairs of exceptions, task ids from any Futures which stored exceptions rather than results. + + TODO: mypy note: we take a *tuple* of args but return a *list* of args. + That's an (unintentional?) change of type of arg structure which leads me + to try to represent the args in TaskRecord as a Sequence """ dep_failures = [] @@ -849,7 +874,9 @@ def _unwrap_futures(self, args, kwargs): new_args.extend([dep.result()]) except Exception as e: if hasattr(dep, 'task_def'): - tid = dep.task_def['id'] + # this cast is because hasattr facts don't propagate into if statements - replace with a protocol? + d_tmp = cast(Any, dep) + tid = d_tmp.task_def['id'] else: tid = None dep_failures.extend([(e, tid)]) @@ -864,7 +891,8 @@ def _unwrap_futures(self, args, kwargs): kwargs[key] = dep.result() except Exception as e: if hasattr(dep, 'task_def'): - tid = dep.task_def['id'] + d_tmp = cast(Any, dep) + tid = d_tmp.task_def['id'] else: tid = None dep_failures.extend([(e, tid)]) @@ -878,7 +906,8 @@ def _unwrap_futures(self, args, kwargs): new_inputs.extend([dep.result()]) except Exception as e: if hasattr(dep, 'task_def'): - tid = dep.task_def['id'] + d_tmp = cast(Any, dep) + tid = d_tmp.task_def['id'] else: tid = None dep_failures.extend([(e, tid)]) @@ -1095,23 +1124,27 @@ def _create_remote_dirs_over_channel(self, provider: ExecutionProvider, channel: channel.makedirs(channel.script_dir, exist_ok=True) - def add_executors(self, executors): + def add_executors(self, executors: Sequence[ParslExecutor]) -> None: for executor in executors: executor.run_id = self.run_id executor.run_dir = self.run_dir executor.hub_address = self.hub_address executor.hub_port = self.hub_interchange_port - if hasattr(executor, 'provider'): + if executor.provider is not None: # could be a protocol? if hasattr(executor.provider, 'script_dir'): executor.provider.script_dir = os.path.join(self.run_dir, 'submit_scripts') os.makedirs(executor.provider.script_dir, exist_ok=True) - if hasattr(executor.provider, 'channels'): + if isinstance(executor.provider, MultiChanneled): logger.debug("Creating script_dir across multiple channels") for channel in executor.provider.channels: self._create_remote_dirs_over_channel(executor.provider, channel) - else: + elif isinstance(executor.provider, Channeled): self._create_remote_dirs_over_channel(executor.provider, executor.provider.channel) + else: + raise ValueError(("Assuming executor.provider has channel(s) based on it " + "having provider/script_dir, but actually it isn't a " + "(Multi)Channeled instance. provider = {}").format(executor.provider)) self.executors[executor.label] = executor block_ids = executor.start() @@ -1194,6 +1227,18 @@ def cleanup(self) -> None: if not executor.bad_state_is_set: if isinstance(executor, BlockProviderExecutor): logger.info(f"Scaling in executor {executor.label}") + + # this block catches the complicated type situation + # that an executor is managed but has no provider. + # ideally the various ways of indicating that an + # executor does provider scaling would be rationalised + # to make this neater. + if executor.provider is None: + logger.error("There is no provider to perform scaling in") + continue + + # what's the proof that there's a provider here? + # some claim that "managed" implies that there is a provider? job_ids = executor.provider.resources.keys() block_ids = executor.scale_in(len(job_ids)) if self.monitoring and block_ids: @@ -1236,7 +1281,7 @@ def checkpoint(self, tasks: Optional[Sequence[TaskRecord]] = None) -> str: Kwargs: - tasks (List of task records) : List of task ids to checkpoint. Default=None - if set to None, we iterate over all tasks held by the DFK. + if set to None or [], we iterate over all tasks held by the DFK. .. note:: Checkpointing only works if memoization is enabled @@ -1248,7 +1293,7 @@ def checkpoint(self, tasks: Optional[Sequence[TaskRecord]] = None) -> str: """ with self.checkpoint_lock: if tasks: - checkpoint_queue = tasks + checkpoint_queue = tasks # type: Iterable[TaskRecord] else: checkpoint_queue = self.checkpointable_tasks self.checkpointable_tasks = [] @@ -1283,7 +1328,7 @@ def checkpoint(self, tasks: Optional[Sequence[TaskRecord]] = None) -> str: continue t = {'hash': hashsum, 'exception': None, - 'result': None} + 'result': None} # type: Dict[str, Any] t['result'] = app_fu.result() @@ -1355,7 +1400,12 @@ def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[A len(memo_lookup_table.keys()))) return memo_lookup_table - def load_checkpoints(self, checkpointDirs): + @typeguard.typechecked + def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> 'Dict[str, Future]': + # typeguard 2.10.1 cannot cope with Future[Any], giving a type-not-subscriptable error. + # Future with no subscript is probably equivalent though? I wanted the Any in there as + # an explicit note that I had the possibility of tighter typing but wasn't using it. + # def load_checkpoints(self, checkpointDirs: Optional[List[str]]) -> 'Dict[str, Future[Any]]': """Load checkpoints from the checkpoint files into a dictionary. The results are used to pre-populate the memoizer's lookup_table @@ -1369,14 +1419,11 @@ def load_checkpoints(self, checkpointDirs): """ self.memo_lookup_table = None - if not checkpointDirs: + if checkpointDirs: + return self._load_checkpoints(checkpointDirs) + else: return {} - if type(checkpointDirs) is not list: - raise BadCheckpoint("checkpointDirs expects a list of checkpoints") - - return self._load_checkpoints(checkpointDirs) - @staticmethod def _log_std_streams(task_record: TaskRecord) -> None: if task_record['app_fu'].stdout is not None: @@ -1413,12 +1460,17 @@ def load(cls, config: Optional[Config] = None) -> DataFlowKernel: if cls._dfk is not None: raise RuntimeError('Config has already been loaded') + # using new_dfk as an intermediate variable allows it to have + # the type DataFlowKernel, which is stricter than the type of + # cls._dfk : Optional[DataFlowKernel] and so we can return the + # correct type. if config is None: - cls._dfk = DataFlowKernel(Config()) + new_dfk = DataFlowKernel(Config()) else: - cls._dfk = DataFlowKernel(config) + new_dfk = DataFlowKernel(config) - return cls._dfk + cls._dfk = new_dfk + return new_dfk @classmethod def wait_for_current_tasks(cls) -> None: diff --git a/parsl/dataflow/flow_control.py b/parsl/dataflow/flow_control.py index b1e0cafcb6..7bf7cecca9 100644 --- a/parsl/dataflow/flow_control.py +++ b/parsl/dataflow/flow_control.py @@ -1,11 +1,15 @@ +from __future__ import annotations import logging -from typing import Sequence - from parsl.executors.base import ParslExecutor from parsl.dataflow.job_status_poller import JobStatusPoller from parsl.utils import Timer +from typing import Sequence +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from parsl.dataflow.dflow import DataFlowKernel + logger = logging.getLogger(__name__) @@ -14,7 +18,7 @@ class FlowControl(Timer): to give the block scaling strategy a chance to execute. """ - def __init__(self, dfk): + def __init__(self, dfk: "DataFlowKernel") -> None: """Initialize the flowcontrol object. We start the timer thread here diff --git a/parsl/dataflow/job_status_poller.py b/parsl/dataflow/job_status_poller.py index 20d03af475..183f231b15 100644 --- a/parsl/dataflow/job_status_poller.py +++ b/parsl/dataflow/job_status_poller.py @@ -2,7 +2,7 @@ import parsl # noqa F401 (used in string type annotation) import time import zmq -from typing import Dict, Sequence +from typing import Any, cast, Dict, Sequence, Optional from typing import List # noqa F401 (used in type annotation) from parsl.dataflow.executor_status import ExecutorStatus @@ -26,6 +26,12 @@ def __init__(self, executor: ParslExecutor, dfk: "parsl.dataflow.dflow.DataFlowK # Create a ZMQ channel to send poll status to monitoring self.monitoring_enabled = False + + # mypy 0.790 cannot determine the type for self._dfk.monitoring + # even though it can determine that _dfk is a DFK. Perhaps because of + # the same cyclic import that makes DataFlowKernel need to be quoted + # in the __init__ type signature? + # So explicitly ignore this type problem. if self._dfk.monitoring is not None: self.monitoring_enabled = True hub_address = self._dfk.hub_address @@ -53,7 +59,7 @@ def poll(self, now: float) -> None: if delta_status: self.send_monitoring_info(delta_status) - def send_monitoring_info(self, status: Dict): + def send_monitoring_info(self, status: Dict) -> None: # Send monitoring info for HTEX when monitoring enabled if self.monitoring_enabled: msg = self._executor.create_monitoring_info(status) @@ -72,11 +78,16 @@ def status(self) -> Dict[str, JobStatus]: def executor(self) -> ParslExecutor: return self._executor - def scale_in(self, n, force=True, max_idletime=None): + def scale_in(self, n: int, force: bool = True, max_idletime: Optional[float] = None) -> List[str]: if force and not max_idletime: block_ids = self._executor.scale_in(n) else: - block_ids = self._executor.scale_in(n, force=force, max_idletime=max_idletime) + # this cast is because ParslExecutor.scale_in doesn't have force or max_idletime parameters + # so we just hope that the actual executor happens to have them. + # see some notes in ParslExecutor about making the status handling superclass into a + # class that holds all the scaling methods, so that everything can be specialised + # to work on those. + block_ids = cast(Any, self._executor).scale_in(n, force=force, max_idletime=max_idletime) if block_ids is not None: new_status = {} for block_id in block_ids: @@ -85,14 +96,20 @@ def scale_in(self, n, force=True, max_idletime=None): self.send_monitoring_info(new_status) return block_ids - def scale_out(self, n): + def scale_out(self, n: int) -> List[str]: + logger.debug("BENC: in task status scale out") block_ids = self._executor.scale_out(n) - if block_ids is not None: - new_status = {} - for block_id in block_ids: - new_status[block_id] = JobStatus(JobState.PENDING) - self.send_monitoring_info(new_status) - self._status.update(new_status) + logger.debug("BENC: executor scale out has returned") + + # mypy - remove this if statement: block_ids is always a list according to the types. + # and so the else clause was failing with unreachable code. And this removed `if` + # would always fire, if that type annotation is true. + logger.debug(f"BENC: there were some block ids, {block_ids}, which will now be set to pending") + new_status = {} + for block_id in block_ids: + new_status[block_id] = JobStatus(JobState.PENDING) + self.send_monitoring_info(new_status) + self._status.update(new_status) return block_ids def __repr__(self) -> str: @@ -106,9 +123,15 @@ def __init__(self, dfk: "parsl.dataflow.dflow.DataFlowKernel"): self._strategy = Strategy(dfk) self._error_handler = JobErrorHandler() - def poll(self): + def poll(self) -> None: self._update_state() - self._error_handler.run(self._poll_items) + + # List is invariant, and the type of _poll_items if List[PollItem] + # but run wants a list of ExecutorStatus. + # This cast should be safe *if* .run does not break the reason that + # List is invariant, which is that it does not add anything into the + # the list (otherwise, List[PollItem] might end up with ExecutorStatus not-PollItems in it. + self._error_handler.run(cast(List[ExecutorStatus], self._poll_items)) self._strategy.strategize(self._poll_items) def _update_state(self) -> None: diff --git a/parsl/dataflow/strategy.py b/parsl/dataflow/strategy.py index 3227fe2a7a..ae1f29dbaa 100644 --- a/parsl/dataflow/strategy.py +++ b/parsl/dataflow/strategy.py @@ -1,19 +1,47 @@ +from __future__ import annotations import logging import time import math import warnings from typing import List +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from parsl.dataflow.dflow import DataFlowKernel + from parsl.dataflow.job_status_poller import PollItem + +from parsl.executors.base import ParslExecutor, HasConnectedWorkers, HasOutstanding + +from typing import Dict +from typing import Callable +from typing import Optional +from typing import Sequence +from typing_extensions import TypedDict + +# this is used for testing a class to decide how to +# print a status line. That might be better done inside +# the executor class (i..e put the class specific behaviour +# inside the class, rather than testing class instance-ness +# here) + +# smells: testing class instance; importing a specific instance +# of a thing that should be generic + + from parsl.dataflow.executor_status import ExecutorStatus from parsl.executors import HighThroughputExecutor from parsl.executors.status_handling import BlockProviderExecutor from parsl.providers.base import JobState -from parsl.process_loggers import wrap_with_logs +# from parsl.process_loggers import wrap_with_logs logger = logging.getLogger(__name__) +class ExecutorIdleness(TypedDict): + idle_since: Optional[float] + + class Strategy: """Scaling strategy. @@ -110,30 +138,41 @@ class Strategy: """ - def __init__(self, dfk): + def __init__(self, dfk: "DataFlowKernel") -> None: """Initialize strategy.""" self.dfk = dfk self.config = dfk.config + + self.executors: Dict[str, ExecutorIdleness] self.executors = {} + self.max_idletime = self.dfk.config.max_idletime for e in self.dfk.config.executors: self.executors[e.label] = {'idle_since': None} + self.strategies: Dict[Optional[str], Callable] self.strategies = {None: self._strategy_noop, 'none': self._strategy_noop, 'simple': self._strategy_simple, - 'htex_auto_scale': self._strategy_htex_auto_scale} + 'htex_auto_scale': self._strategy_htex_auto_scale + } if self.config.strategy is None: warnings.warn("literal None for strategy choice is deprecated. Use string 'none' instead.", DeprecationWarning) + # mypy note: with mypy 0.761, the type of self.strategize is + # correctly revealed inside this module, but isn't carried over + # when Strategy is used in other modules unless this specific + # type annotation is used. + + self.strategize: Callable self.strategize = self.strategies[self.config.strategy] logger.debug("Scaling strategy: {0}".format(self.config.strategy)) - def add_executors(self, executors): + def add_executors(self, executors: Sequence[ParslExecutor]) -> None: for executor in executors: self.executors[executor.label] = {'idle_since': None} @@ -142,10 +181,10 @@ def _strategy_noop(self, status: List[ExecutorStatus]) -> None: """ logger.debug("strategy_noop: doing nothing") - def _strategy_simple(self, status_list) -> None: + def _strategy_simple(self, status_list: "List[PollItem]") -> None: self._general_strategy(status_list, strategy_type='simple') - def _strategy_htex_auto_scale(self, status_list) -> None: + def _strategy_htex_auto_scale(self, status_list: "List[PollItem]") -> None: """HTEX specific auto scaling strategy This strategy works only for HTEX. This strategy will scale out by @@ -162,8 +201,10 @@ def _strategy_htex_auto_scale(self, status_list) -> None: """ self._general_strategy(status_list, strategy_type='htex') - @wrap_with_logs - def _general_strategy(self, status_list, *, strategy_type): + # can't do wrap with logs until I learn about paramspecs, because wrap_with_logs + # is not tightly typed enough to be allowed in this module yet. + # @wrap_with_logs + def _general_strategy(self, status_list: "List[PollItem]", strategy_type: str) -> None: logger.debug(f"general strategy starting with strategy_type {strategy_type} for {len(status_list)} executors") for exec_status in status_list: @@ -175,15 +216,21 @@ def _general_strategy(self, status_list, *, strategy_type): logger.debug(f"Strategizing for executor {label}") # Tasks that are either pending completion + assert isinstance(executor, HasOutstanding) active_tasks = executor.outstanding status = exec_status.status + # The provider might not even be defined -- what's the behaviour in + # that case? + if executor.provider is None: + logger.error("Trying to strategize an executor that has no provider") + continue + # FIXME we need to handle case where provider does not define these # FIXME probably more of this logic should be moved to the provider min_blocks = executor.provider.min_blocks max_blocks = executor.provider.max_blocks - tasks_per_node = executor.workers_per_node nodes_per_block = executor.provider.nodes_per_block parallelism = executor.provider.parallelism @@ -191,11 +238,17 @@ def _general_strategy(self, status_list, *, strategy_type): running = sum([1 for x in status.values() if x.state == JobState.RUNNING]) pending = sum([1 for x in status.values() if x.state == JobState.PENDING]) active_blocks = running + pending - active_slots = active_blocks * tasks_per_node * nodes_per_block - logger.debug(f"Slot ratio calculation: active_slots = {active_slots}, active_tasks = {active_tasks}") + # TODO: if this isinstance doesn't fire, tasks_per_node and active_slots won't be + # set this iteration and either will be unset or will contain a previous executor's value. + # in both cases, this is wrong. but apparently mypy doesn't notice. + + if isinstance(executor, HasConnectedWorkers): + tasks_per_node = executor.workers_per_node + + active_slots = active_blocks * tasks_per_node * nodes_per_block + logger.debug(f"Slot ratio calculation: active_slots = {active_slots}, active_tasks = {active_tasks}") - if hasattr(executor, 'connected_workers'): logger.debug('Executor {} has {} active tasks, {}/{} running/pending blocks, and {} connected workers'.format( label, active_tasks, running, pending, executor.connected_workers)) else: @@ -227,9 +280,18 @@ def _general_strategy(self, status_list, *, strategy_type): logger.debug(f"Starting idle timer for executor. If idle time exceeds {self.max_idletime}s, blocks will be scaled in") self.executors[executor.label]['idle_since'] = time.time() + # ... this could be None, type-wise. So why aren't we seeing errors here? + # probably becaues usually if this is None, it will be because active_tasks>0, + # (although I can't see a clear proof that this will always be the case: + # could that setting to None have happened on a previous iteration?) + + # if idle_since is None, then that means not idle, which means should not + # go down the scale_in path idle_since = self.executors[executor.label]['idle_since'] - idle_duration = time.time() - idle_since - if idle_duration > self.max_idletime: + if idle_since is not None and (time.time() - idle_since) > self.max_idletime: + # restored this separate calculation even though making a single one + # ahead of time is better... + idle_duration = time.time() - idle_since # We have resources idle for the max duration, # we have to scale_in now. logger.debug(f"Idle time has reached {self.max_idletime}s for executor {label}; scaling in") diff --git a/parsl/executors/base.py b/parsl/executors/base.py index 1b03f61c7d..3ae8ab718c 100644 --- a/parsl/executors/base.py +++ b/parsl/executors/base.py @@ -1,8 +1,14 @@ -from abc import ABCMeta, abstractmethod +from __future__ import annotations +from abc import ABCMeta, abstractmethod, abstractproperty from concurrent.futures import Future -from typing import Any, Callable, Dict, Optional, List +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union +from parsl.data_provider.staging import Staging + +# for type checking: +from parsl.providers.base import ExecutionProvider from parsl.providers.base import JobStatus +from typing_extensions import runtime_checkable, Protocol import parsl # noqa F401 @@ -30,26 +36,45 @@ class ParslExecutor(metaclass=ABCMeta): An executor may optionally expose: - storage_access: List[parsl.data_provider.staging.Staging] - a list of staging + storage_access: Sequence[parsl.data_provider.staging.Staging] - a sequence of staging providers that will be used for file staging. In the absence of this attribute, or if this attribute is `None`, then a default value of ``parsl.data_provider.staging.default_staging`` will be used by the staging code. - - Typechecker note: Ideally storage_access would be declared on executor - __init__ methods as List[Staging] - however, lists are by default - invariant, not co-variant, and it looks like @typeguard cannot be - persuaded otherwise. So if you're implementing an executor and want to - @typeguard the constructor, you'll have to use List[Any] here. """ - label: str = "undefined" - radio_mode: str = "udp" - - def __enter__(self): + # mypy doesn't actually check that the below are defined by + # concrete subclasses - see github.com/python/mypy/issues/4426 + # and maybe PEP-544 Protocols + + def __init__(self) -> None: + self.label: str + self.radio_mode: str = "udp" + + self.provider: Optional[ExecutionProvider] = None + # this is wrong here. eg thread local executor has no provider. + # perhaps its better attached to the block scaling provider? + # cross-ref with notes of @property provider() in the + # nostatushandlingexecutor. + + # i'm not particularly happy with this default, + # probably would be better specified via an __init__ + # as a mandatory parameter + self.managed: bool = False + + # there's an abstraction problem here - what kind of executor should + # statically have this? for now I'll implement a protocol and assert + # the protocol holds, wherever the code makes that assumption. + # self.outstanding: int = None # what is this? used by strategy + self.working_dir: Optional[str] = None + self.storage_access: Optional[Sequence[Staging]] = None + self.run_id: Optional[str] = None + + def __enter__(self) -> ParslExecutor: return self - def __exit__(self, exc_type, exc_val, exc_tb): + # too lazy to figure out what the three Anys here should be + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: self.shutdown() return False @@ -62,9 +87,8 @@ def start(self) -> Optional[List[str]]: pass @abstractmethod - def submit(self, func: Callable, resource_specification: Dict[str, Any], *args: Any, **kwargs: Any) -> Future: + def submit(self, func: Callable, resource_specification: Dict[str, Any], *args: Any, **kwargs: Dict[str, Any]) -> Future: """Submit. - The executor can optionally set a parsl_executor_task_id attribute on the Future that it returns, and in that case, parsl will log a relationship between the executor's task ID and parsl level try/task @@ -94,6 +118,11 @@ def scale_in(self, blocks: int) -> List[str]: which will have the scaling methods, scale_in itself should be a coroutine, since scaling tasks can be slow. + MYPY branch notes: the scale in calls in strategy expect there to be many + more parameters to this. This maybe could be resolved by treating a + status providing executor as more generally a strategy-scalable + executor, and having strategies statically typed to work on those. + :return: A list of block ids corresponding to the blocks that were removed. """ pass @@ -109,6 +138,13 @@ def shutdown(self) -> bool: def create_monitoring_info(self, status: Dict[str, JobStatus]) -> List[object]: """Create a monitoring message for each block based on the poll status. + TODO: block_id_type should be an enumerated list of valid strings, rather than all strings + + TODO: there shouldn't be any default values for this - when it is invoked, it should be explicit which is needed? + Neither seems more natural to me than the other. + + TODO: internal vs external should be more clearly documented here + :return: a list of dictionaries mapping to the info of each block """ return [] @@ -183,7 +219,7 @@ def handle_errors(self, error_handler: "parsl.dataflow.job_error_handler.JobErro pass @abstractmethod - def set_bad_state_and_fail_all(self, exception: Exception): + def set_bad_state_and_fail_all(self, exception: Exception) -> None: """Allows external error handlers to mark this executor as irrecoverably bad and cause all tasks submitted to it now and in the future to fail. The executor is responsible for checking :method:bad_state_is_set() in the :method:submit() method and raising the @@ -235,3 +271,25 @@ def hub_port(self) -> Optional[int]: @hub_port.setter def hub_port(self, value: Optional[int]) -> None: self._hub_port = value + + +@runtime_checkable +class HasConnectedWorkers(Protocol): + """A marker type to indicate that the executor has a count of connected workers. This maybe should merge into the block executor?""" + connected_workers: int + + @abstractproperty + def workers_per_node(self) -> Union[int, float]: + pass + + +@runtime_checkable +class HasOutstanding(Protocol): + """A marker type to indicate that the executor has a count of outstanding tasks. This maybe should merge into the block executor?""" + outstanding: int + + +class FutureWithTaskID(Future): + def __init__(self, task_id: str) -> None: + super().__init__() + self.parsl_executor_task_id = task_id diff --git a/parsl/executors/high_throughput/executor.py b/parsl/executors/high_throughput/executor.py index c75ec20445..18bbf8d745 100644 --- a/parsl/executors/high_throughput/executor.py +++ b/parsl/executors/high_throughput/executor.py @@ -8,12 +8,13 @@ import warnings from multiprocessing import Queue from typing import Dict, Sequence # noqa F401 (used in type annotation) -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Any import math from parsl.serialize import pack_apply_message, deserialize from parsl.app.errors import RemoteExceptionWrapper from parsl.executors.high_throughput import zmq_pipes +from parsl.executors.base import HasConnectedWorkers, FutureWithTaskID from parsl.executors.high_throughput import interchange from parsl.executors.errors import ( BadMessage, ScalingFailed, @@ -27,7 +28,7 @@ from parsl.addresses import get_all_addresses from parsl.process_loggers import wrap_with_logs -from parsl.multiprocessing import ForkProcess +from parsl.multiprocessing import forkProcess from parsl.utils import RepresentationMixin from parsl.providers import LocalProvider @@ -36,7 +37,7 @@ _start_methods = ['fork', 'spawn', 'thread'] -class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin): +class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin, HasConnectedWorkers): """Executor designed for cluster-scale The HighThroughputExecutor system has the following components: @@ -195,7 +196,7 @@ def __init__(self, worker_ports: Optional[Tuple[int, int]] = None, worker_port_range: Optional[Tuple[int, int]] = (54000, 55000), interchange_port_range: Optional[Tuple[int, int]] = (55000, 56000), - storage_access: Optional[List[Staging]] = None, + storage_access: Optional[Sequence[Staging]] = None, working_dir: Optional[str] = None, worker_debug: bool = False, cores_per_worker: float = 1.0, @@ -212,11 +213,12 @@ def __init__(self, worker_logdir_root: Optional[str] = None, block_error_handler: bool = True): + self._queue_management_thread: Optional[threading.Thread] + logger.debug("Initializing HighThroughputExecutor") BlockProviderExecutor.__init__(self, provider=provider, block_error_handler=block_error_handler) self.label = label - self.launch_cmd = launch_cmd self.worker_debug = worker_debug self.storage_access = storage_access self.working_dir = working_dir @@ -234,12 +236,12 @@ def __init__(self, mem_slots = max_workers cpu_slots = max_workers - if hasattr(self.provider, 'mem_per_node') and \ + if isinstance(self.provider, ExecutionProvider) and \ self.provider.mem_per_node is not None and \ mem_per_worker is not None and \ mem_per_worker > 0: mem_slots = math.floor(self.provider.mem_per_node / mem_per_worker) - if hasattr(self.provider, 'cores_per_node') and \ + if isinstance(self.provider, ExecutionProvider) and \ self.provider.cores_per_node is not None: cpu_slots = math.floor(self.provider.cores_per_node / cores_per_worker) @@ -271,6 +273,7 @@ def __init__(self, self.run_id = None # set to the correct run_id in dfk self.hub_address = None # set to the correct hub address in dfk self.hub_port = None # set to the correct hub port in dfk + self.worker_ports = worker_ports self.worker_port_range = worker_port_range self.interchange_port_range = interchange_port_range @@ -281,7 +284,11 @@ def __init__(self, self.worker_logdir_root = worker_logdir_root self.cpu_affinity = cpu_affinity - if not launch_cmd: + self._executor_exception = None + + if launch_cmd: + self.launch_cmd = launch_cmd + else: self.launch_cmd = ("process_worker_pool.py {debug} {max_workers} " "-a {addresses} " "-p {prefetch_capacity} " @@ -301,7 +308,7 @@ def __init__(self, radio_mode = "htex" - def initialize_scaling(self): + def initialize_scaling(self) -> List[str]: """ Compose the launch command and call the scale_out This should be implemented in the child classes to take care of @@ -317,6 +324,8 @@ def initialize_scaling(self): if self.worker_logdir_root is not None: worker_logdir = "{}/{}".format(self.worker_logdir_root, self.label) + assert self.provider is not None + l_cmd = self.launch_cmd.format(debug=debug_opts, prefetch_capacity=self.prefetch_capacity, address_probe_timeout_string=address_probe_timeout_string, @@ -340,7 +349,7 @@ def initialize_scaling(self): logger.debug("Starting HighThroughputExecutor with provider:\n%s", self.provider) # TODO: why is this a provider property? - block_ids = [] + block_ids = [] # type: List[str] if hasattr(self.provider, 'init_blocks'): try: block_ids = self.scale_out(blocks=self.provider.init_blocks) @@ -349,7 +358,7 @@ def initialize_scaling(self): raise e return block_ids - def start(self): + def start(self) -> Optional[List[str]]: """Create the Interchange process and connect to it. """ self.outgoing_q = zmq_pipes.TasksOutgoing("127.0.0.1", self.interchange_port_range) @@ -368,7 +377,7 @@ def start(self): return block_ids @wrap_with_logs - def _queue_management_worker(self): + def _queue_management_worker(self) -> None: """Listen to the queue for task status messages and handle them. Depending on the message, tasks will be updated with results, exceptions, @@ -467,14 +476,14 @@ def _queue_management_worker(self): break logger.info("queue management worker finished") - def _start_local_interchange_process(self): + def _start_local_interchange_process(self) -> None: """ Starts the interchange process locally Starts the interchange process locally and uses an internal command queue to get the worker task and result ports that the interchange has bound to. """ - comm_q = Queue(maxsize=10) - self.interchange_proc = ForkProcess(target=interchange.starter, + comm_q = Queue(maxsize=10) # type: Queue[Any] + self.interchange_proc = forkProcess(target=interchange.starter, args=(comm_q,), kwargs={"client_ports": (self.outgoing_q.port, self.incoming_q.port, @@ -498,7 +507,7 @@ def _start_local_interchange_process(self): logger.error("Interchange has not completed initialization in 120s. Aborting") raise Exception("Interchange failed to start") - def _start_queue_management_thread(self): + def _start_queue_management_thread(self) -> None: """Method to start the management thread as a daemon. Checks if a thread already exists, then starts it. @@ -514,7 +523,7 @@ def _start_queue_management_thread(self): else: logger.error("Management thread already exists, returning") - def hold_worker(self, worker_id): + def hold_worker(self, worker_id: str) -> None: """Puts a worker on hold, preventing scheduling of additional tasks to it. This is called "hold" mostly because this only stops scheduling of tasks, @@ -526,9 +535,8 @@ def hold_worker(self, worker_id): worker_id : str Worker id to be put on hold """ - c = self.command_client.run("HOLD_WORKER;{}".format(worker_id)) + self.command_client.run("HOLD_WORKER;{}".format(worker_id)) logger.debug("Sent hold request to manager: {}".format(worker_id)) - return c @property def outstanding(self): @@ -544,7 +552,7 @@ def connected_managers(self): managers = self.command_client.run("MANAGERS") return managers - def _hold_block(self, block_id): + def _hold_block(self, block_id: str) -> None: """ Sends hold command to all managers which are in a specific block Parameters @@ -560,7 +568,7 @@ def _hold_block(self, block_id): logger.debug("Sending hold to manager: {}".format(manager['manager'])) self.hold_worker(manager['manager']) - def submit(self, func, resource_specification, *args, **kwargs): + def submit(self, func, resource_specification, *args, **kwargs) -> "Future[Any]": """Submits work to the outgoing_q. The outgoing_q is an external process listens on this @@ -584,7 +592,10 @@ def submit(self, func, resource_specification, *args, **kwargs): raise UnsupportedFeatureError('resource specification', 'HighThroughput Executor', 'WorkQueue Executor') if self.bad_state_is_set: - raise self.executor_exception + if self.executor_exception is None: + raise ValueError("Executor is in bad state, but no exception recorded") + else: + raise self.executor_exception self._task_counter += 1 task_id = self._task_counter @@ -595,8 +606,7 @@ def submit(self, func, resource_specification, *args, **kwargs): args_to_print = tuple([arg if len(repr(arg)) < 100 else (repr(arg)[:100] + '...') for arg in args]) logger.debug("Pushing function {} to queue with args {}".format(func, args_to_print)) - fut = Future() - fut.parsl_executor_task_id = task_id + fut: Future = FutureWithTaskID(str(task_id)) self.tasks[task_id] = fut try: @@ -620,7 +630,7 @@ def create_monitoring_info(self, status): """ msg = [] for bid, s in status.items(): - d = {} + d: Dict[str, Any] = {} d['run_id'] = self.run_id d['status'] = s.status_name d['timestamp'] = datetime.datetime.now() @@ -634,13 +644,15 @@ def create_monitoring_info(self, status): def workers_per_node(self) -> Union[int, float]: return self._workers_per_node - def scale_in(self, blocks=None, block_ids=[], force=True, max_idletime=None): + def scale_in(self, blocks: Optional[int] = None, block_ids: List[str] = [], force: bool = True, max_idletime: Optional[float] = None) -> List[str]: """Scale in the number of active blocks by specified amount. The scale in method here is very rude. It doesn't give the workers the opportunity to finish current tasks or cleanup. This is tracked in issue #530 + Exactly one of blocks or block_ids must be specified. + Parameters ---------- @@ -666,9 +678,21 @@ def scale_in(self, blocks=None, block_ids=[], force=True, max_idletime=None): List of job_ids marked for termination """ logger.debug(f"Scale in called, blocks={blocks}, block_ids={block_ids}") + + assert (block_ids != []) ^ (blocks is not None), "Exactly one of blocks or block IDs must be specified" + assert self.provider is not None + + block_ids_to_kill: List[str] if block_ids: + # these asserts are slightly different than treating + # block_ids as a bool: they distinguish the empty + # list [] differently. + assert block_ids != [] + assert blocks is None block_ids_to_kill = block_ids else: + assert block_ids == [] + assert blocks is not None managers = self.connected_managers() block_info = {} # block id -> list( tasks, idle duration ) for manager in managers: diff --git a/parsl/executors/high_throughput/interchange.py b/parsl/executors/high_throughput/interchange.py index 9a42254603..6292068aa7 100644 --- a/parsl/executors/high_throughput/interchange.py +++ b/parsl/executors/high_throughput/interchange.py @@ -308,7 +308,7 @@ def _command_server(self): elif command_req.startswith("HOLD_WORKER"): cmd, s_manager = command_req.split(';') manager_id = s_manager.encode('utf-8') - logger.info("Received HOLD_WORKER for {}".format(manager_id)) + logger.info("Received HOLD_WORKER for {!r}".format(manager_id)) if manager_id in self._ready_managers: m = self._ready_managers[manager_id] m['active'] = False @@ -398,9 +398,9 @@ def process_task_outgoing_incoming(self, interesting_managers, hub_channel, kill msg = json.loads(message[1].decode('utf-8')) reg_flag = True except Exception: - logger.warning("Got Exception reading registration message from manager: {}".format( + logger.warning("Got Exception reading registration message from manager: {!r}".format( manager_id), exc_info=True) - logger.debug("Message: \n{}\n".format(message[1])) + logger.debug("Message: \n{!r}\n".format(message[1])) else: # We set up an entry only if registration works correctly self._ready_managers[manager_id] = {'last_heartbeat': time.time(), @@ -412,15 +412,15 @@ def process_task_outgoing_incoming(self, interesting_managers, hub_channel, kill 'tasks': []} if reg_flag is True: interesting_managers.add(manager_id) - logger.info("Adding manager: {} to ready queue".format(manager_id)) + logger.info("Adding manager: {!r} to ready queue".format(manager_id)) m = self._ready_managers[manager_id] m.update(msg) - logger.info("Registration info for manager {}: {}".format(manager_id, msg)) + logger.info("Registration info for manager {!r}: {}".format(manager_id, msg)) self._send_monitoring_info(hub_channel, m) if (msg['python_v'].rsplit(".", 1)[0] != self.current_platform['python_v'].rsplit(".", 1)[0] or msg['parsl_v'] != self.current_platform['parsl_v']): - logger.error("Manager {} has incompatible version info with the interchange".format(manager_id)) + logger.error("Manager {!r} has incompatible version info with the interchange".format(manager_id)) logger.debug("Setting kill event") kill_event.set() e = VersionMismatch("py.v={} parsl.v={}".format(self.current_platform['python_v'].rsplit(".", 1)[0], @@ -433,19 +433,19 @@ def process_task_outgoing_incoming(self, interesting_managers, hub_channel, kill self.results_outgoing.send(pkl_package) logger.error("Sent failure reports, shutting down interchange") else: - logger.info("Manager {} has compatible Parsl version {}".format(manager_id, msg['parsl_v'])) - logger.info("Manager {} has compatible Python version {}".format(manager_id, - msg['python_v'].rsplit(".", 1)[0])) + logger.info("Manager {!r} has compatible Parsl version {}".format(manager_id, msg['parsl_v'])) + logger.info("Manager {!r} has compatible Python version {}".format(manager_id, + msg['python_v'].rsplit(".", 1)[0])) else: # Registration has failed. - logger.debug("Suppressing bad registration from manager: {}".format( + logger.debug("Suppressing bad registration from manager: {!r}".format( manager_id)) else: tasks_requested = int.from_bytes(message[1], "little") self._ready_managers[manager_id]['last_heartbeat'] = time.time() if tasks_requested == HEARTBEAT_CODE: - logger.debug("Manager {} sent heartbeat via tasks connection".format(manager_id)) + logger.debug("Manager {!r} sent heartbeat via tasks connection".format(manager_id)) self.task_outgoing.send_multipart([manager_id, b'', PKL_HEARTBEAT_CODE]) else: logger.error("Unexpected non-heartbeat message received from manager {}") @@ -499,9 +499,9 @@ def process_results_incoming(self, interesting_managers, hub_channel): logger.debug("entering results_incoming section") manager_id, *all_messages = self.results_incoming.recv_multipart() if manager_id not in self._ready_managers: - logger.warning("Received a result from a un-registered manager: {}".format(manager_id)) + logger.warning("Received a result from a un-registered manager: {!r}".format(manager_id)) else: - logger.debug(f"Got {len(all_messages)} result items in batch from manager {manager_id}") + logger.debug(f"Got {len(all_messages)} result items in batch from manager {manager_id!r}") b_messages = [] @@ -513,7 +513,7 @@ def process_results_incoming(self, interesting_managers, hub_channel): elif r['type'] == 'monitoring': hub_channel.send_pyobj(r['payload']) elif r['type'] == 'heartbeat': - logger.debug(f"Manager {manager_id} sent heartbeat via results connection") + logger.debug(f"Manager {manager_id!r} sent heartbeat via results connection") b_messages.append((p_message, r)) else: logger.error("Interchange discarding result_queue message of unknown type: {}".format(r['type'])) @@ -525,11 +525,11 @@ def process_results_incoming(self, interesting_managers, hub_channel): if r['type'] == 'result': got_result = True try: - logger.debug(f"Removing task {r['task_id']} from manager record {manager_id}") + logger.debug(f"Removing task {r['task_id']} from manager record {manager_id!r}") m['tasks'].remove(r['task_id']) except Exception: # If we reach here, there's something very wrong. - logger.exception("Ignoring exception removing task_id {} for manager {} with task list {}".format( + logger.exception("Ignoring exception removing task_id {} for manager {!r} with task list {}".format( r['task_id'], manager_id, m['tasks'])) @@ -543,7 +543,7 @@ def process_results_incoming(self, interesting_managers, hub_channel): self.results_outgoing.send_multipart(b_messages_to_send) logger.debug("Sent messages on results_outgoing") - logger.debug(f"Current tasks on manager {manager_id}: {m['tasks']}") + logger.debug(f"Current tasks on manager {manager_id!r}: {m['tasks']}") if len(m['tasks']) == 0 and m['idle_since'] is None: m['idle_since'] = time.time() @@ -560,7 +560,7 @@ def expire_bad_managers(self, interesting_managers, hub_channel): time.time() - m['last_heartbeat'] > self.heartbeat_threshold] for (manager_id, m) in bad_managers: logger.debug("Last: {} Current: {}".format(m['last_heartbeat'], time.time())) - logger.warning(f"Too many heartbeats missed for manager {manager_id} - removing manager") + logger.warning(f"Too many heartbeats missed for manager {manager_id!r} - removing manager") if m['active']: m['active'] = False self._send_monitoring_info(hub_channel, m) diff --git a/parsl/executors/high_throughput/process_worker_pool.py b/parsl/executors/high_throughput/process_worker_pool.py index 7892070ef2..d0d7d36a3b 100755 --- a/parsl/executors/high_throughput/process_worker_pool.py +++ b/parsl/executors/high_throughput/process_worker_pool.py @@ -19,8 +19,9 @@ import psutil import multiprocessing -from parsl.process_loggers import wrap_with_logs +from typing import Any, Dict +from parsl.process_loggers import wrap_with_logs from parsl.version import VERSION as PARSL_VERSION from parsl.app.errors import RemoteExceptionWrapper from parsl.executors.high_throughput.errors import WorkerLost @@ -28,10 +29,13 @@ from parsl.multiprocessing import ForkProcess as mpForkProcess from parsl.multiprocessing import SpawnProcess as mpSpawnProcess -from parsl.multiprocessing import SizedQueue as mpQueue +from parsl.multiprocessing import sizedQueue +from multiprocessing import Queue from parsl.serialize import unpack_apply_message, serialize +logger = logging.getLogger("parsl") + HEARTBEAT_CODE = (2 ** 32) - 1 @@ -188,6 +192,7 @@ def __init__(self, # Determine which start method to use start_method = start_method.lower() + self.mpProcess: Any # some protocol abstraction of two process types or thread... if start_method == "fork": self.mpProcess = mpForkProcess elif start_method == "spawn": @@ -197,9 +202,9 @@ def __init__(self, else: raise ValueError(f'HTEx does not support start method: "{start_method}"') - self.pending_task_queue = mpQueue() - self.pending_result_queue = mpQueue() - self.ready_worker_queue = mpQueue() + self.pending_task_queue = sizedQueue() # type: Queue[Any] + self.pending_result_queue = sizedQueue() # type: Queue[Any] + self.ready_worker_queue = sizedQueue() # type: Queue[Any] self.max_queue_size = self.prefetch_capacity + self.worker_count @@ -413,16 +418,25 @@ def worker_watchdog(self, kill_event): logger.critical("Exiting") - def start(self): + def start(self) -> None: """ Start the worker processes. TODO: Move task receiving to a thread """ start = time.time() self._kill_event = threading.Event() + + # When upgrading from mypy 0.961 to 0.981, this change happens: + + # multiprocessing.Manager().dict() according to mypy, does not + # return a Dict, but instead a multiprocessing.managers.DictProxy + # parsl/executors/high_throughput/process_worker_pool.py:416: note: Revealed type is "multiprocessing.managers.DictProxy[Any, Any]" + # + # but this type inference gets figured out, so no need for explicit annotation, + # I think self._tasks_in_progress = multiprocessing.Manager().dict() - self.procs = {} + self.procs = {} # type: Dict[Any, Any] for worker_id in range(self.worker_count): p = self.mpProcess(target=worker, args=(worker_id, @@ -508,7 +522,7 @@ def execute_task(bufs): @wrap_with_logs(target="worker_log") -def worker(worker_id, pool_id, pool_size, task_queue, result_queue, worker_queue, tasks_in_progress, cpu_affinity, accelerator: Optional[str]): +def worker(worker_id, pool_id, pool_size, task_queue, result_queue, worker_queue, tasks_in_progress, cpu_affinity, accelerator: Optional[str]) -> None: """ Put request token into queue @@ -594,7 +608,7 @@ def worker(worker_id, pool_id, pool_size, task_queue, result_queue, worker_queue try: result = execute_task(req['buffer']) - serialized_result = serialize(result, buffer_threshold=1e6) + serialized_result = serialize(result, buffer_threshold=int(1e6)) except Exception as e: logger.info('Caught an exception: {}'.format(e)) result_package = {'type': 'result', 'task_id': tid, 'exception': serialize(RemoteExceptionWrapper(*sys.exc_info()))} diff --git a/parsl/executors/status_handling.py b/parsl/executors/status_handling.py index 4215e1fc9a..fc4c6974b5 100644 --- a/parsl/executors/status_handling.py +++ b/parsl/executors/status_handling.py @@ -47,7 +47,9 @@ def __init__(self, *, provider: ExecutionProvider, block_error_handler: bool): super().__init__() + # TODO: untangle having two provider attributes self._provider = provider + self.provider = provider self.block_error_handler = block_error_handler # errors can happen during the submit call to the provider; this is used # to keep track of such errors so that they can be handled in one place @@ -80,10 +82,11 @@ def _make_status_dict(self, block_ids: List[str], status_list: List[JobStatus]) @property def status_polling_interval(self): - if self._provider is None: - return 0 - else: - return self._provider.status_polling_interval + # this codepath is unreachable because execution provider is always set in init + # if self._provider is None: + # return 0 + # else: + return self._provider.status_polling_interval def _fail_job_async(self, block_id: Any, message: str): """Marks a job that has failed to start but would not otherwise be included in status() @@ -94,12 +97,14 @@ def _fail_job_async(self, block_id: Any, message: str): logger.info(f"Allocated block ID {block_id} for simulated failure") self._simulated_status[block_id] = JobStatus(JobState.FAILED, message) - @abstractproperty - def outstanding(self) -> int: - """This should return the number of tasks that the executor has been given to run (waiting to run, and running now)""" - - raise NotImplementedError("Classes inheriting from BlockProviderExecutor must implement " - "outstanding()") + # There's an abstraction problem here: is this a property of all executors or only executors + # which can be scaled with blocks? + # @abstractproperty + # def outstanding(self) -> int: + # """This should return the number of tasks that the executor has been given to run (waiting to run, and running now)""" + # + # raise NotImplementedError("Classes inheriting from BlockProviderExecutor must implement " + # "outstanding()") def status(self) -> Dict[str, JobStatus]: """Return status of all blocks.""" @@ -140,6 +145,11 @@ def handle_errors(self, error_handler: "parsl.dataflow.job_error_handler.JobErro if not self.block_error_handler: return init_blocks = 3 + + # this code path assumes there is self.provider, but there's no + # type-checking guarantee of that. + assert self.provider is not None # for type checking + if hasattr(self.provider, 'init_blocks'): init_blocks = self.provider.init_blocks if init_blocks < 1: @@ -150,9 +160,11 @@ def handle_errors(self, error_handler: "parsl.dataflow.job_error_handler.JobErro def tasks(self) -> Dict[object, Future]: return self._tasks - @property - def provider(self): - return self._provider + # this is defined as a regular attribute at the superclass level, + # which may or may not be correct. + # @property + # def provider(self): + # return self._provider def _filter_scale_in_ids(self, to_kill, killed): """ Filter out job id's that were not killed @@ -166,7 +178,7 @@ def scale_out(self, blocks: int = 1) -> List[str]: """ if not self.provider: raise ScalingFailed(self, "No execution provider available") - block_ids = [] + block_ids: List[str] = [] # is this true? is a block ID always a string (vs eg a POpen object?) logger.info(f"Scaling out by {blocks} blocks") for i in range(blocks): block_id = str(self._block_id_counter.get_id()) @@ -182,6 +194,13 @@ def scale_out(self, blocks: int = 1) -> List[str]: return block_ids def _launch_block(self, block_id: str) -> Any: + + # there's no static type guarantee that there is a provider here but + # the code assumes there is, so to pass type checking, this assert + # will catch violations of that assumption, that otherwise would appear + # in later references to self.provider + assert self.provider is not None + launch_cmd = self._get_launch_command(block_id) job_id = self.provider.submit(launch_cmd, 1) if job_id: @@ -237,6 +256,9 @@ def handle_errors(self, error_handler: "parsl.dataflow.job_error_handler.JobErro status: Dict[str, JobStatus]) -> None: pass - @property - def provider(self): - return self._provider + # this property seems to be unimplemented and unused - providers only make + # sense in the context of a block handling executor, not executors in + # general + # @property + # def provider(self): + # return self._provider diff --git a/parsl/executors/threads.py b/parsl/executors/threads.py index 124d47b9fa..f2933c5871 100644 --- a/parsl/executors/threads.py +++ b/parsl/executors/threads.py @@ -2,7 +2,7 @@ import typeguard import concurrent.futures as cf -from typing import List, Optional +from typing import List, Optional, Sequence from parsl.data_provider.staging import Staging from parsl.executors.status_handling import NoStatusHandlingExecutor @@ -28,7 +28,7 @@ class ThreadPoolExecutor(NoStatusHandlingExecutor, RepresentationMixin): @typeguard.typechecked def __init__(self, label: str = 'threads', max_threads: int = 2, - thread_name_prefix: str = '', storage_access: Optional[List[Staging]] = None, + thread_name_prefix: str = '', storage_access: Optional[Sequence[Staging]] = None, working_dir: Optional[str] = None): NoStatusHandlingExecutor.__init__(self) self.label = label @@ -60,7 +60,7 @@ def submit(self, func, resource_specification, *args, **kwargs): return self.executor.submit(func, *args, **kwargs) - def scale_out(self, workers=1): + def scale_out(self, workers: int = 1) -> List[str]: """Scales out the number of active workers by 1. This method is notImplemented for threads and will raise the error if called. @@ -71,7 +71,7 @@ def scale_out(self, workers=1): raise NotImplementedError - def scale_in(self, blocks): + def scale_in(self, blocks) -> List[str]: """Scale in the number of active blocks by specified amount. This method is not implemented for threads and will raise the error if called. @@ -95,9 +95,9 @@ def shutdown(self, block=True): """ logger.debug("Shutting down executor, which involves waiting for running tasks to complete") - x = self.executor.shutdown(wait=block) + self.executor.shutdown(wait=block) logger.debug("Done with executor shutdown") - return x + return True def monitor_resources(self): """Resource monitoring sometimes deadlocks when using threads, so this function diff --git a/parsl/executors/workqueue/executor.py b/parsl/executors/workqueue/executor.py index c8272e8bff..56e76438d1 100644 --- a/parsl/executors/workqueue/executor.py +++ b/parsl/executors/workqueue/executor.py @@ -6,7 +6,7 @@ import threading import multiprocessing import logging -from concurrent.futures import Future +from parsl.executors.base import FutureWithTaskID from ctypes import c_bool import tempfile @@ -34,7 +34,7 @@ from parsl.utils import setproctitle import typeguard -from typing import Dict, List, Optional, Set, Union +from typing import Dict, List, Optional, Sequence, Set, Union from parsl.data_provider.staging import Staging from .errors import WorkQueueTaskFailure @@ -226,7 +226,7 @@ def __init__(self, port: int = WORK_QUEUE_DEFAULT_PORT, env: Optional[Dict] = None, shared_fs: bool = False, - storage_access: Optional[List[Staging]] = None, + storage_access: Optional[Sequence[Staging]] = None, use_cache: bool = False, source: bool = False, pack: bool = False, @@ -467,8 +467,7 @@ def submit(self, func, resource_specification, *args, **kwargs): input_files.append(self._register_file(maybe_file)) # Create a Future object and have it be mapped from the task ID in the tasks dictionary - fu = Future() - fu.parsl_executor_task_id = executor_task_id + fu = FutureWithTaskID(str(executor_task_id)) logger.debug("Getting tasks_lock to set WQ-level task entry") with self.tasks_lock: logger.debug("Got tasks_lock to set WQ-level task entry") @@ -654,12 +653,21 @@ def initialize_scaling(self): # Start scaling in/out logger.debug("Starting WorkQueueExecutor with provider: %s", self.provider) self._patch_providers() - if hasattr(self.provider, 'init_blocks'): + + # self.provider always has init_blocks - this check only needs to + # check that there is actually a provider specified. + if self.provider is not None: try: self.scale_out(blocks=self.provider.init_blocks) except Exception as e: logger.error("Initial block scaling out failed: {}".format(e)) raise e + # if hasattr(self.provider, 'init_blocks'): + # try: + # self.scale_out(blocks=self.provider.init_blocks) + # except Exception as e: + # logger.error("Initial block scaling out failed: {}".format(e)) + # raise e @property def outstanding(self) -> int: diff --git a/parsl/launchers/errors.py b/parsl/launchers/errors.py index f05a527792..4ef5a6d26d 100644 --- a/parsl/launchers/errors.py +++ b/parsl/launchers/errors.py @@ -1,13 +1,14 @@ from parsl.errors import ParslError +from parsl.launchers.launchers import Launcher class BadLauncher(ParslError): """Error raised when a non callable object is provider as Launcher """ - def __init__(self, launcher, reason): + def __init__(self, launcher: Launcher, reason: str): self.launcher = launcher self.reason = reason - def __repr__(self): + def __repr__(self) -> str: return "Bad Launcher provided:{0} Reason:{1}".format(self.launcher, self.reason) diff --git a/parsl/monitoring/db_manager.py b/parsl/monitoring/db_manager.py index 76189cd937..19ad56483b 100644 --- a/parsl/monitoring/db_manager.py +++ b/parsl/monitoring/db_manager.py @@ -69,7 +69,7 @@ def __init__(self, def _get_mapper(self, table_obj: Table) -> Mapper: all_mappers = set() - for mapper_registry in mapperlib._all_registries(): # type: ignore + for mapper_registry in mapperlib._all_registries(): # type: ignore[attr-defined] all_mappers.update(mapper_registry.mappers) mapper_gen = ( mapper for mapper in all_mappers diff --git a/parsl/monitoring/monitoring.py b/parsl/monitoring/monitoring.py index 357fb670fb..54912598cb 100644 --- a/parsl/monitoring/monitoring.py +++ b/parsl/monitoring/monitoring.py @@ -7,11 +7,11 @@ import zmq import queue - import parsl.monitoring.remote -from parsl.multiprocessing import ForkProcess, SizedQueue -from multiprocessing import Process, Queue +from multiprocessing import Queue +from parsl.multiprocessing import forkProcess, sizedQueue + from parsl.utils import RepresentationMixin from parsl.process_loggers import wrap_with_logs from parsl.utils import setproctitle @@ -130,7 +130,8 @@ def __init__(self, The time interval, in seconds, at which the monitoring records the resource usage of each task. Default: 30 seconds """ - self.logger = logger + # previously this was set in start() but logger exists at import so it can be set here and remove the optionality of self.logger's type + self.logger = logger # type: logging.Logger # Any is used to disable typechecking on uses of _dfk_channel, # because it is used in the code as if it points to a channel, but @@ -170,14 +171,14 @@ def start(self, run_id: str, run_dir: str) -> int: self.logger.debug("Initializing ZMQ Pipes to client") self.monitoring_hub_active = True - comm_q = SizedQueue(maxsize=10) # type: Queue[Union[Tuple[int, int], str]] - self.exception_q = SizedQueue(maxsize=10) # type: Queue[Tuple[str, str]] - self.priority_msgs = SizedQueue() # type: Queue[Tuple[Any, int]] - self.resource_msgs = SizedQueue() # type: Queue[AddressedMonitoringMessage] - self.node_msgs = SizedQueue() # type: Queue[AddressedMonitoringMessage] - self.block_msgs = SizedQueue() # type: Queue[AddressedMonitoringMessage] + comm_q = sizedQueue(maxsize=10) # type: Queue[Union[Tuple[int, int], str]] + self.exception_q = sizedQueue(maxsize=10) # type: Queue[Tuple[str, str]] + self.priority_msgs = sizedQueue() # type: Queue[Tuple[Any, int]] + self.resource_msgs = sizedQueue() # type: Queue[AddressedMonitoringMessage] + self.node_msgs = sizedQueue() # type: Queue[AddressedMonitoringMessage] + self.block_msgs = sizedQueue() # type: Queue[AddressedMonitoringMessage] - self.router_proc = ForkProcess(target=router_starter, + self.router_proc = forkProcess(target=router_starter, args=(comm_q, self.exception_q, self.priority_msgs, self.node_msgs, self.block_msgs, self.resource_msgs), kwargs={"hub_address": self.hub_address, "hub_port": self.hub_port, @@ -191,7 +192,7 @@ def start(self, run_id: str, run_dir: str) -> int: ) self.router_proc.start() - self.dbm_proc = ForkProcess(target=dbm_starter, + self.dbm_proc = forkProcess(target=dbm_starter, args=(self.exception_q, self.priority_msgs, self.node_msgs, self.block_msgs, self.resource_msgs,), kwargs={"logdir": self.logdir, "logging_level": logging.DEBUG if self.monitoring_debug else logging.INFO, @@ -203,10 +204,10 @@ def start(self, run_id: str, run_dir: str) -> int: self.dbm_proc.start() self.logger.info("Started the router process {} and DBM process {}".format(self.router_proc.pid, self.dbm_proc.pid)) - self.filesystem_proc = Process(target=filesystem_receiver, - args=(self.logdir, self.resource_msgs, run_dir), - name="Monitoring-Filesystem-Process", - daemon=True + self.filesystem_proc = forkProcess(target=filesystem_receiver, + args=(self.logdir, self.resource_msgs, run_dir), + name="Monitoring-Filesystem-Process", + daemon=True ) self.filesystem_proc.start() self.logger.info(f"Started filesystem radio receiver process {self.filesystem_proc.pid}") diff --git a/parsl/monitoring/remote.py b/parsl/monitoring/remote.py index 073b44f713..5f16f5e6df 100644 --- a/parsl/monitoring/remote.py +++ b/parsl/monitoring/remote.py @@ -5,7 +5,7 @@ from functools import wraps from parsl.multiprocessing import ForkProcess -from multiprocessing import Event, Process +from multiprocessing import Event from parsl.process_loggers import wrap_with_logs from parsl.monitoring.message_type import MessageType @@ -57,7 +57,7 @@ def wrapped(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: radio_mode, run_dir) - p: Optional[Process] + p: Optional[ForkProcess] if monitor_resources: # create the monitor process and start pp = ForkProcess(target=monitor, diff --git a/parsl/multiprocessing.py b/parsl/multiprocessing.py index 15c48fbbeb..28c96c6947 100644 --- a/parsl/multiprocessing.py +++ b/parsl/multiprocessing.py @@ -6,14 +6,22 @@ import multiprocessing.queues import platform -from typing import Callable, Type - logger = logging.getLogger(__name__) -# maybe ForkProcess should be: Callable[..., Process] so as to make -# it clear that it returns a Process always to the type checker? -ForkProcess: Type = multiprocessing.get_context('fork').Process -SpawnProcess: Type = multiprocessing.get_context('spawn').Process +ForkProcess = multiprocessing.context.ForkProcess +SpawnProcess = multiprocessing.context.SpawnProcess + + +def forkProcess(*args, **kwargs) -> ForkProcess: + P = multiprocessing.get_context('fork').Process + # reveal_type(P) + return P(*args, **kwargs) + + +def spawnProcess(*args, **kwargs) -> SpawnProcess: + P = multiprocessing.get_context('spawn').Process + # reveal_type(P) + return P(*args, **kwargs) class MacSafeQueue(multiprocessing.queues.Queue): @@ -54,11 +62,10 @@ def empty(self): # SizedQueue should be constructable using the same calling # convention as multiprocessing.Queue but that entire signature # isn't expressible in mypy 0.790 -SizedQueue: Callable[..., multiprocessing.Queue] - +# SizedQueue: Callable[..., multiprocessing.Queue] -if platform.system() != 'Darwin': - import multiprocessing - SizedQueue = multiprocessing.Queue -else: - SizedQueue = MacSafeQueue +def sizedQueue(*args, **kwargs) -> multiprocessing.queues.Queue: + if platform.system() != 'Darwin': + return multiprocessing.Queue(*args, **kwargs) + else: + return MacSafeQueue(*args, **kwargs) diff --git a/parsl/providers/ad_hoc/ad_hoc.py b/parsl/providers/ad_hoc/ad_hoc.py index 904ec3b134..00d5e829d4 100644 --- a/parsl/providers/ad_hoc/ad_hoc.py +++ b/parsl/providers/ad_hoc/ad_hoc.py @@ -4,14 +4,16 @@ from parsl.channels import LocalChannel from parsl.launchers import SimpleLauncher -from parsl.providers.base import ExecutionProvider, JobStatus, JobState +from parsl.providers.base import ExecutionProvider, JobStatus, JobState, MultiChanneled from parsl.providers.errors import ScriptPathError from parsl.utils import RepresentationMixin +from typing import Dict, Any, List + logger = logging.getLogger(__name__) -class AdHocProvider(ExecutionProvider, RepresentationMixin): +class AdHocProvider(ExecutionProvider, MultiChanneled, RepresentationMixin): """ Ad-hoc execution provider This provider is used to provision execution resources over one or more ad hoc nodes @@ -61,7 +63,7 @@ def __init__(self, self.nodes_per_block = 1 # Dictionary that keeps track of jobs, keyed on job_id - self.resources = {} + self.resources = {} # type: Dict[Any, Dict[str, Any]] self.least_loaded = self._least_loaded() logger.debug("AdHoc provider initialized") @@ -183,15 +185,17 @@ def submit(self, command, tasks_per_node, job_name="parsl.adhoc"): if job_id is None: logger.warning("Channel failed to start remote command/retrieve PID") - self.resources[job_id] = {'job_id': job_id, - 'status': JobStatus(JobState.RUNNING), - 'cmd': final_cmd, - 'channel': channel, - 'remote_pid': remote_pid} + d = {'job_id': job_id, + 'status': JobStatus(JobState.RUNNING), + 'cmd': final_cmd, + 'channel': channel, + 'remote_pid': remote_pid} # type: Dict[str, Any] + + self.resources[job_id] = d return job_id - def status(self, job_ids): + def status(self, job_ids: List[Any]) -> List[JobStatus]: """ Get status of the list of jobs with job_ids Parameters diff --git a/parsl/providers/aws/aws.py b/parsl/providers/aws/aws.py index 7212fc362b..3a2dfcc85c 100644 --- a/parsl/providers/aws/aws.py +++ b/parsl/providers/aws/aws.py @@ -3,6 +3,7 @@ import os import time from string import Template +from typing import List from parsl.errors import ConfigurationError from parsl.providers.aws.template import template_string @@ -601,7 +602,7 @@ def get_instance_state(self, instances=None): self.instance_states[instance['InstanceId']] = instance['State']['Name'] return self.instance_states - def status(self, job_ids): + def status(self, job_ids) -> List[JobStatus]: """Get the status of a list of jobs identified by their ids. Parameters diff --git a/parsl/providers/azure/azure.py b/parsl/providers/azure/azure.py index 68c1b155f8..f71c768114 100644 --- a/parsl/providers/azure/azure.py +++ b/parsl/providers/azure/azure.py @@ -3,6 +3,7 @@ import os import time from string import Template +from typing import Any, List from parsl.errors import ConfigurationError from parsl.providers.azure.template import template_string @@ -156,7 +157,7 @@ def __init__(self, self.launcher = launcher self.linger = linger self.resources = {} - self.instances = [] + self.instances = [] # type: List[Any] env_specified = os.getenv("AZURE_CLIENT_ID") is not None and os.getenv( "AZURE_CLIENT_SECRET") is not None and os.getenv( @@ -299,7 +300,7 @@ def submit(self, return vm_info.name - def status(self, job_ids): + def status(self, job_ids) -> List[JobStatus]: """Get the status of a list of jobs identified by their ids. Parameters ---------- diff --git a/parsl/providers/base.py b/parsl/providers/base.py index fea4ae7063..a467def805 100644 --- a/parsl/providers/base.py +++ b/parsl/providers/base.py @@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod, abstractproperty from enum import IntEnum import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Sequence from parsl.channels.base import Channel @@ -160,6 +160,9 @@ def __init__(self) -> None: self._mem_per_node: Optional[float] = None pass + # TODO: how about make this always return a job ID and must raise an exception on + # failure? Potentially it could make failure path handling simpler? because right now, you + # should catch an exception or look for None and handle them both the same? @abstractmethod def submit(self, command: str, tasks_per_node: int, job_name: str = "parsl.auto") -> object: ''' The submit method takes the command string to be executed upon @@ -180,6 +183,7 @@ def submit(self, command: str, tasks_per_node: int, job_name: str = "parsl.auto" Raises: - ExecutionProviderException or its subclasses + ^ is this true? I think we can raise anything... ''' pass @@ -203,7 +207,7 @@ def status(self, job_ids: List[object]) -> List[JobStatus]: pass @abstractmethod - def cancel(self, job_ids: List[object]) -> List[bool]: + def cancel(self, job_ids: Sequence[object]) -> Sequence[bool]: ''' Cancels the resources identified by the job_ids provided by the user. Args: @@ -233,8 +237,14 @@ def mem_per_node(self) -> Optional[float]: If this property is set, executors may use it to calculate how many tasks can run concurrently per node. + + This property, and cores_per_node, might become a HasCoresMem protocol, on + the way to detangling what is optional? """ - return self._mem_per_node + if hasattr(self, "_mem_per_node"): + return self._mem_per_node + else: + return None @mem_per_node.setter def mem_per_node(self, value: float) -> None: @@ -251,7 +261,10 @@ def cores_per_node(self) -> Optional[int]: If this property is set, executors may use it to calculate how many tasks can run concurrently per node. """ - return self._cores_per_node + if hasattr(self, "_cores_per_node"): + return self._cores_per_node + else: + return None @cores_per_node.setter def cores_per_node(self, value: int) -> None: diff --git a/parsl/providers/cluster_provider.py b/parsl/providers/cluster_provider.py index 8b28bbb021..86c478566d 100644 --- a/parsl/providers/cluster_provider.py +++ b/parsl/providers/cluster_provider.py @@ -4,12 +4,16 @@ from parsl.providers.errors import SchedulerMissingArgs, ScriptPathError from parsl.launchers.errors import BadLauncher -from parsl.providers.base import ExecutionProvider +from parsl.providers.base import ExecutionProvider, Channeled, JobStatus logger = logging.getLogger(__name__) +from typing import Any, Dict, List +from parsl.channels.base import Channel +from parsl.launchers.launchers import Launcher -class ClusterProvider(ExecutionProvider): + +class ClusterProvider(ExecutionProvider, Channeled): """ This class defines behavior common to all cluster/supercompute-style scheduler systems. Parameters @@ -45,16 +49,16 @@ class ClusterProvider(ExecutionProvider): """ def __init__(self, - label, - channel, - nodes_per_block, - init_blocks, - min_blocks, - max_blocks, - parallelism, - walltime, - launcher, - cmd_timeout=10): + label: str, + channel: Channel, + nodes_per_block: int, + init_blocks: int, + min_blocks: int, + max_blocks: int, + parallelism: float, # nb. the member field for this is used by strategy, so maybe this should be exposed at the layer above as a property? + walltime: str, + launcher: Launcher, + cmd_timeout: int = 10) -> None: self._label = label self.channel = channel @@ -66,15 +70,24 @@ def __init__(self, self.launcher = launcher self.walltime = walltime self.cmd_timeout = cmd_timeout - if not callable(self.launcher): + + # TODO: this test should be for being a launcher, not being callable + if not isinstance(self.launcher, Launcher): raise(BadLauncher(self.launcher, - "Launcher for executor: {} is of type: {}. Expects a parsl.launcher.launcher.Launcher or callable".format( - label, type(self.launcher)))) + "Launcher for executor: {} is of type: {}. Expects a parsl.launcher.launcher.Launcher".format(label, type(self.launcher)))) self.script_dir = None # Dictionary that keeps track of jobs, keyed on job_id - self.resources = {} + self.resources = {} # type: Dict[Any, Any] + + # This annotation breaks slurm: + # parsl/providers/slurm/slurm.py:201: error: Item "None" of "Optional[str]" has no attribute "split" + # parsl/providers/slurm/slurm.py:207: error: Item "None" of "Optional[str]" has no attribute "strip" + # Theres a dependent type at work here which I can't describe in the type system: + # the optional strs are None when int != 0, for some providers. + # and when int == 0, the optional strs are strs + # def execute_wait(self, cmd, timeout=None) -> Tuple[int, Optional[str], Optional[str]]: def execute_wait(self, cmd, timeout=None): t = self.cmd_timeout @@ -124,7 +137,7 @@ def _write_submit_script(self, template, script_filename, job_name, configs): def _status(self): pass - def status(self, job_ids): + def status(self, job_ids: List[Any]) -> List[JobStatus]: """ Get the status of a list of jobs identified by the job identifiers returned from the submit request. diff --git a/parsl/providers/condor/condor.py b/parsl/providers/condor/condor.py index 7539e3b694..f84e2ef190 100644 --- a/parsl/providers/condor/condor.py +++ b/parsl/providers/condor/condor.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from parsl.channels.base import Channel from parsl.launchers.launchers import Launcher @@ -155,7 +155,7 @@ def _status(self): if job_id in self.resources: self.resources[job_id]['status'] = JobStatus(state) - def status(self, job_ids): + def status(self, job_ids: List[Any]) -> List[JobStatus]: """Get the status of a list of jobs identified by their ids. Parameters @@ -261,7 +261,7 @@ def submit(self, command, tasks_per_node, job_name="parsl.condor"): except Exception as e: raise ScaleOutFailed(self.label, str(e)) - job_id = [] + job_id = [] # type: List[str] if retcode == 0: for line in stdout.split('\n'): diff --git a/parsl/providers/googlecloud/googlecloud.py b/parsl/providers/googlecloud/googlecloud.py index af5252d180..5131420edc 100644 --- a/parsl/providers/googlecloud/googlecloud.py +++ b/parsl/providers/googlecloud/googlecloud.py @@ -1,6 +1,7 @@ import atexit import logging import os +from typing import List from parsl.launchers import SingleNodeLauncher from parsl.providers.base import JobState, JobStatus @@ -118,7 +119,7 @@ def submit(self, command, tasks_per_node, job_name="parsl.gcs"): self.resources[name] = {"job_id": name, "status": JobStatus(translate_table[instance['status']])} return name - def status(self, job_ids): + def status(self, job_ids) -> List[JobStatus]: ''' Get the status of a list of jobs identified by the job identifiers returned from the submit request. diff --git a/parsl/providers/kubernetes/kube.py b/parsl/providers/kubernetes/kube.py index bd9acbeaee..9477d63d36 100644 --- a/parsl/providers/kubernetes/kube.py +++ b/parsl/providers/kubernetes/kube.py @@ -130,7 +130,7 @@ def __init__(self, # Dictionary that keeps track of jobs, keyed on job_id self.resources = {} # type: Dict[object, Dict[str, Any]] - def submit(self, cmd_string, tasks_per_node, job_name="parsl"): + def submit(self, cmd_string, tasks_per_node, job_name="parsl") -> Optional[str]: """ Submit a job Args: - cmd_string :(String) - Name of the container to initiate @@ -166,7 +166,7 @@ def submit(self, cmd_string, tasks_per_node, job_name="parsl"): return pod_name - def status(self, job_ids): + def status(self, job_ids) -> List[JobStatus]: """ Get the status of a list of jobs identified by the job identifiers returned from the submit request. Args: @@ -180,7 +180,7 @@ def status(self, job_ids): self._status() return [self.resources[jid]['status'] for jid in job_ids] - def cancel(self, job_ids): + def cancel(self, job_ids) -> List[bool]: """ Cancels the jobs specified by a list of job ids Args: job_ids : [ ...] @@ -230,7 +230,7 @@ def _create_pod(self, job_name, port=80, cmd_string=None, - volumes=[]): + volumes=[]) -> None: """ Create a kubernetes pod for the job. Args: - image (string) : Docker image to launch @@ -299,7 +299,7 @@ def _create_pod(self, body=pod) logger.debug("Pod created. status='{0}'".format(str(api_response.status))) - def _delete_pod(self, pod_name): + def _delete_pod(self, pod_name) -> None: """Delete a pod""" api_response = self.kube_client.delete_namespaced_pod(name=pod_name, @@ -308,7 +308,7 @@ def _delete_pod(self, pod_name): logger.debug("Pod deleted. status='{0}'".format(str(api_response.status))) @property - def label(self): + def label(self) -> str: return "kubernetes" @property diff --git a/parsl/providers/local/local.py b/parsl/providers/local/local.py index f827d57efd..b0ae187282 100644 --- a/parsl/providers/local/local.py +++ b/parsl/providers/local/local.py @@ -2,16 +2,20 @@ import os import time +from parsl.channels.base import Channel from parsl.channels import LocalChannel from parsl.launchers import SingleNodeLauncher -from parsl.providers.base import ExecutionProvider, JobState, JobStatus +from parsl.providers.base import Channeled, ExecutionProvider, JobState, JobStatus from parsl.providers.errors import SchedulerMissingArgs, ScriptPathError, SubmitException from parsl.utils import RepresentationMixin +from typing import Any, List + logger = logging.getLogger(__name__) -class LocalProvider(ExecutionProvider, RepresentationMixin): +class LocalProvider(ExecutionProvider, RepresentationMixin, Channeled): + """ Local Execution Provider This provider is used to provide execution resources from the localhost. @@ -34,7 +38,7 @@ class LocalProvider(ExecutionProvider, RepresentationMixin): """ def __init__(self, - channel=LocalChannel(), + channel: Channel = LocalChannel(), nodes_per_block=1, launcher=SingleNodeLauncher(), init_blocks=1, @@ -60,7 +64,7 @@ def __init__(self, # Dictionary that keeps track of jobs, keyed on job_id self.resources = {} - def status(self, job_ids): + def status(self, job_ids: List[Any]) -> List[JobStatus]: ''' Get the status of a list of jobs identified by their ids. Args: @@ -117,22 +121,32 @@ def status(self, job_ids): return [self.resources[jid]['status'] for jid in job_ids] - def _is_alive(self, job_dict): + def _is_alive(self, job_dict) -> bool: retcode, stdout, stderr = self.channel.execute_wait( 'ps -p {} > /dev/null 2> /dev/null; echo "STATUS:$?" '.format( job_dict['remote_pid']), self.cmd_timeout) - for line in stdout.split('\n'): - if line.startswith("STATUS:"): - status = line.split("STATUS:")[1].strip() - if status == "0": - return True - else: - return False + if stdout: + for line in stdout.split('\n'): + if line.startswith("STATUS:"): + status = line.split("STATUS:")[1].strip() + if status == "0": + return True + else: + return False + raise RuntimeError("Hit end of stdout scan without finding STATUS. Unclear what the correct default behaviour is here, so raising exception") + else: + raise RuntimeError("no stdout. Unclear what the correct default behaviour is here, so raising exception.") def _job_file_path(self, script_path: str, suffix: str) -> str: path = '{0}{1}'.format(script_path, suffix) if self._should_move_files(): - path = self.channel.pull_file(path, self.script_dir) + if not self.script_dir: + raise RuntimeError("want to pull_file but script_dir is not defined - unclear what the correct behaviour is so raising exception") + new_path = self.channel.pull_file(path, self.script_dir) + if path is None: + raise RuntimeError("pull_file returned None - unclear what the correct behaviour is so raising exception") + else: + path = new_path return path def _read_job_file(self, script_path: str, suffix: str) -> str: @@ -172,7 +186,7 @@ def _write_submit_script(self, script_string, script_filename): return True - def submit(self, command, tasks_per_node, job_name="parsl.localprovider"): + def submit(self, command: str, tasks_per_node: int, job_name: str = "parsl.localprovider") -> object: ''' Submits the command onto an Local Resource Manager job. Submit returns an ID that corresponds to the task that was just submitted. @@ -208,7 +222,7 @@ def submit(self, command, tasks_per_node, job_name="parsl.localprovider"): self._write_submit_script(wrap_command, script_path) - job_id = None + job_id = None # type: Any remote_pid = None if self._should_move_files(): logger.debug("Pushing start script") @@ -236,10 +250,13 @@ def submit(self, command, tasks_per_node, job_name="parsl.localprovider"): if retcode != 0: raise SubmitException(job_name, "Launch command exited with code {0}".format(retcode), stdout, stderr) - for line in stdout.split('\n'): - if line.startswith("PID:"): - remote_pid = line.split("PID:")[1].strip() - job_id = remote_pid + if stdout: + for line in stdout.split('\n'): + if line.startswith("PID:"): + remote_pid = line.split("PID:")[1].strip() + job_id = remote_pid + else: + logger.debug("no stdout, which would caused a runtime type error splitting stdout. Acting as if stdout has no lines.") if job_id is None: raise SubmitException(job_name, "Channel failed to start remote command/retrieve PID") diff --git a/parsl/providers/slurm/slurm.py b/parsl/providers/slurm/slurm.py index 8d577dedde..d4f5c6c4f9 100644 --- a/parsl/providers/slurm/slurm.py +++ b/parsl/providers/slurm/slurm.py @@ -16,6 +16,8 @@ from parsl.providers.slurm.template import template_string from parsl.utils import RepresentationMixin, wtime_to_minutes +from typing import Any, Dict + logger = logging.getLogger(__name__) translate_table = { @@ -181,7 +183,7 @@ def _status(self): logger.debug("Updating missing job {} to completed status".format(missing_job)) self.resources[missing_job]['status'] = JobStatus(JobState.COMPLETED) - def submit(self, command, tasks_per_node, job_name="parsl.slurm"): + def submit(self, command, tasks_per_node, job_name="parsl.slurm") -> Optional[str]: """Submit the command as a slurm job. Parameters @@ -215,7 +217,7 @@ def submit(self, command, tasks_per_node, job_name="parsl.slurm"): logger.debug("Requesting one block with {} nodes".format(self.nodes_per_block)) - job_config = {} + job_config = {} # type: Dict[str, Any] job_config["submit_script_dir"] = self.channel.script_dir job_config["nodes"] = self.nodes_per_block job_config["tasks_per_node"] = tasks_per_node diff --git a/parsl/providers/torque/torque.py b/parsl/providers/torque/torque.py index e5b8c31d6d..ddc1b66126 100644 --- a/parsl/providers/torque/torque.py +++ b/parsl/providers/torque/torque.py @@ -11,6 +11,10 @@ logger = logging.getLogger(__name__) +from typing import Optional +from parsl.channels.base import Channel +from parsl.launchers.launchers import Launcher + # From the man pages for qstat for PBS/Torque systems translate_table = { 'B': JobState.RUNNING, # This state is returned for running array jobs @@ -68,19 +72,19 @@ class TorqueProvider(ClusterProvider, RepresentationMixin): """ def __init__(self, - channel=LocalChannel(), - account=None, - queue=None, - scheduler_options='', - worker_init='', - nodes_per_block=1, - init_blocks=1, - min_blocks=0, - max_blocks=1, - parallelism=1, - launcher=AprunLauncher(), - walltime="00:20:00", - cmd_timeout=120): + channel: Channel = LocalChannel(), + account: Optional[str] = None, + queue: Optional[str] = None, + scheduler_options: str = '', + worker_init: str = '', + nodes_per_block: int = 1, + init_blocks: int = 1, + min_blocks: int = 0, + max_blocks: int = 1, + parallelism: float = 1, + launcher: Launcher = AprunLauncher(), + walltime: str = "00:20:00", + cmd_timeout: int = 120) -> None: label = 'torque' super().__init__(label, channel, diff --git a/parsl/serialize/base.py b/parsl/serialize/base.py index 70bbe7a18c..fa898c455c 100644 --- a/parsl/serialize/base.py +++ b/parsl/serialize/base.py @@ -2,6 +2,8 @@ import logging import functools +from typing import Any + logger = logging.getLogger(__name__) # GLOBALS @@ -13,18 +15,22 @@ class SerializerBase: """ Adds shared functionality for all serializer implementations """ - def __init_subclass__(cls, *args, **kwargs): + def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: """ This forces all child classes to register themselves as methods for serializing code or data """ - super().__init_subclass__(*args, **kwargs) + super().__init_subclass__() if cls._for_code: METHODS_MAP_CODE[cls._identifier] = cls if cls._for_data: METHODS_MAP_DATA[cls._identifier] = cls + _identifier: bytes + _for_code: bool + _for_data: bool + @property - def identifier(self): + def identifier(self) -> bytes: """ Get the identifier of the serialization method Returns @@ -33,7 +39,7 @@ def identifier(self): """ return self._identifier - def chomp(self, payload): + def chomp(self, payload: bytes) -> bytes: """ If the payload starts with the identifier, return the remaining block Parameters @@ -43,21 +49,23 @@ def chomp(self, payload): """ s_id, payload = payload.split(b'\n', 1) if (s_id + b'\n') != self.identifier: - raise TypeError("Buffer does not start with parsl.serialize identifier:{}".format(self.identifier)) + raise TypeError("Buffer does not start with parsl.serialize identifier:{!r}".format(self.identifier)) return payload - def enable_caching(self, maxsize=128): + def enable_caching(self, maxsize: int = 128) -> None: """ Add functools.lru_cache onto the serialize, deserialize methods """ - self.serialize = functools.lru_cache(maxsize=maxsize)(self.serialize) - self.deserialize = functools.lru_cache(maxsize=maxsize)(self.deserialize) + # ignore types here because mypy at the moment is not fond of monkeypatching + self.serialize = functools.lru_cache(maxsize=maxsize)(self.serialize) # type: ignore[method-assign] + self.deserialize = functools.lru_cache(maxsize=maxsize)(self.deserialize) # type: ignore[method-assign] + return @abstractmethod - def serialize(self, data): + def serialize(self, data: Any) -> bytes: pass @abstractmethod - def deserialize(self, payload): + def deserialize(self, payload: bytes) -> Any: pass diff --git a/parsl/serialize/concretes.py b/parsl/serialize/concretes.py index 4d1a4425ab..e66ab17a73 100644 --- a/parsl/serialize/concretes.py +++ b/parsl/serialize/concretes.py @@ -5,6 +5,8 @@ logger = logging.getLogger(__name__) from parsl.serialize.base import SerializerBase +from typing import Any + class PickleSerializer(SerializerBase): """ Pickle serialization covers most python objects, with some notable exceptions: @@ -19,11 +21,11 @@ class PickleSerializer(SerializerBase): _for_code = True _for_data = True - def serialize(self, data): + def serialize(self, data: Any) -> bytes: x = pickle.dumps(data) return self.identifier + x - def deserialize(self, payload): + def deserialize(self, payload: bytes) -> Any: chomped = self.chomp(payload) data = pickle.loads(chomped) return data @@ -45,11 +47,11 @@ class DillSerializer(SerializerBase): _for_code = True _for_data = True - def serialize(self, data): + def serialize(self, data: Any) -> bytes: x = dill.dumps(data) return self.identifier + x - def deserialize(self, payload): + def deserialize(self, payload: bytes) -> Any: chomped = self.chomp(payload) data = dill.loads(chomped) return data diff --git a/parsl/serialize/facade.py b/parsl/serialize/facade.py index 45235616ef..c628c5bdc6 100644 --- a/parsl/serialize/facade.py +++ b/parsl/serialize/facade.py @@ -1,7 +1,9 @@ from parsl.serialize.concretes import * # noqa: F403,F401 -from parsl.serialize.base import METHODS_MAP_DATA, METHODS_MAP_CODE +from parsl.serialize.base import METHODS_MAP_DATA, METHODS_MAP_CODE, SerializerBase import logging +from typing import Any, Dict, List, Tuple, Union + logger = logging.getLogger(__name__) @@ -21,11 +23,11 @@ methods_for_data[key] = METHODS_MAP_DATA[key]() -def _list_methods(): +def _list_methods() -> Tuple[Dict[bytes, SerializerBase], Dict[bytes, SerializerBase]]: return methods_for_code, methods_for_data -def pack_apply_message(func, args, kwargs, buffer_threshold=128 * 1e6): +def pack_apply_message(func: Any, args: Any, kwargs: Any, buffer_threshold: int = int(128 * 1e6)) -> bytes: """Serialize and pack function and parameters Parameters @@ -51,53 +53,48 @@ def pack_apply_message(func, args, kwargs, buffer_threshold=128 * 1e6): return packed_buffer -def unpack_apply_message(packed_buffer, user_ns=None, copy=False): +def unpack_apply_message(packed_buffer: bytes, user_ns: Any = None, copy: Any = False) -> List[Any]: """ Unpack and deserialize function and parameters """ return [deserialize(buf) for buf in unpack_buffers(packed_buffer)] -def serialize(obj, buffer_threshold=1e6): +def serialize(obj: Any, buffer_threshold: int = int(1e6)) -> bytes: """ Try available serialization methods one at a time Individual serialization methods might raise a TypeError (eg. if objects are non serializable) This method will raise the exception from the last method that was tried, if all methods fail. """ - serialized = None - serialized_flag = False - last_exception = None + result: Union[bytes, Exception] if callable(obj): for method in methods_for_code.values(): try: - serialized = method.serialize(obj) + result = method.serialize(obj) except Exception as e: - last_exception = e + result = e continue else: - serialized_flag = True break else: for method in methods_for_data.values(): try: - serialized = method.serialize(obj) + result = method.serialize(obj) except Exception as e: - last_exception = e + result = e continue else: - serialized_flag = True break - if serialized_flag is False: - # TODO : Replace with a SerializationError - raise last_exception - - if len(serialized) > buffer_threshold: - logger.warning(f"Serialized object exceeds buffer threshold of {buffer_threshold} bytes, this could cause overflows") - return serialized + if isinstance(result, BaseException): + raise result + else: + if len(result) > buffer_threshold: + logger.warning(f"Serialized object exceeds buffer threshold of {buffer_threshold} bytes, this could cause overflows") + return result -def deserialize(payload): +def deserialize(payload: bytes) -> Any: """ Parameters ---------- @@ -111,16 +108,16 @@ def deserialize(payload): elif header in methods_for_data: result = methods_for_data[header].deserialize(payload) else: - raise TypeError("Invalid header: {} in data payload. Buffer is either corrupt or not created by ParslSerializer".format(header)) + raise TypeError("Invalid header: {!r} in data payload. Buffer is either corrupt or not created by ParslSerializer".format(header)) return result -def pack_buffers(buffers): +def pack_buffers(buffers: List[bytes]) -> bytes: """ Parameters ---------- - buffers : list of \n terminated strings + buffers : list of \n terminated (byte?) strings """ packed = b'' for buf in buffers: @@ -130,11 +127,11 @@ def pack_buffers(buffers): return packed -def unpack_buffers(packed_buffer): +def unpack_buffers(packed_buffer: bytes) -> List[bytes]: """ Parameters ---------- - packed_buffers : packed buffer as string + packed_buffers : packed buffer as (byte?) string """ unpacked = [] while packed_buffer: @@ -146,11 +143,12 @@ def unpack_buffers(packed_buffer): return unpacked -def unpack_and_deserialize(packed_buffer): +def unpack_and_deserialize(packed_buffer: bytes) -> Any: """ Unpacks a packed buffer and returns the deserialized contents + Only works for packings of 3 buffers - presumably because of specific use of this call? Parameters ---------- - packed_buffers : packed buffer as string + packed_buffers : packed buffer as string... apparently expecting exactly 3 buffers? """ unpacked = [] while packed_buffer: diff --git a/parsl/tests/manual_tests/test_worker_count.py b/parsl/tests/manual_tests/test_worker_count.py index cf3a89f262..1a71fc552e 100644 --- a/parsl/tests/manual_tests/test_worker_count.py +++ b/parsl/tests/manual_tests/test_worker_count.py @@ -14,8 +14,10 @@ from parsl.tests.manual_tests.htex_local import config from parsl.executors import HighThroughputExecutor +from parsl.providers.base import ExecutionProvider assert isinstance(config.executors[0], HighThroughputExecutor) config.executors[0].cores_per_worker = CORES_PER_WORKER +assert isinstance(config.executors[0].provider, ExecutionProvider) config.executors[0].provider.init_blocks = 1 # from htex_midway import config diff --git a/parsl/tests/test_checkpointing/test_periodic.py b/parsl/tests/test_checkpointing/test_periodic.py index 607ca6cbef..b4c5fcdd09 100644 --- a/parsl/tests/test_checkpointing/test_periodic.py +++ b/parsl/tests/test_checkpointing/test_periodic.py @@ -7,6 +7,8 @@ from parsl.app.app import python_app from parsl.tests.configs.local_threads_checkpoint_periodic import config +dfk: parsl.DataFlowKernel + def local_setup(): global dfk diff --git a/parsl/tests/test_data/test_file_staging.py b/parsl/tests/test_data/test_file_staging.py index 0e9df68be0..c14c65d30e 100644 --- a/parsl/tests/test_data/test_file_staging.py +++ b/parsl/tests/test_data/test_file_staging.py @@ -47,7 +47,8 @@ def test_regression_200(): f.write("Hello World") fu = cat(inputs=[File("test.txt")], - outputs=[File("test_output.txt")]) + outputs=[File("test_output.txt")], + stdout='r200.out', stderr='r200.err') fu.result() fi = fu.outputs[0].result() diff --git a/parsl/tests/test_python_apps/test_basic.py b/parsl/tests/test_python_apps/test_basic.py index 3ebb41a902..b4bfb512d6 100644 --- a/parsl/tests/test_python_apps/test_basic.py +++ b/parsl/tests/test_python_apps/test_basic.py @@ -4,6 +4,9 @@ from parsl.app.app import python_app +# for mypy hacking... shouldn't stay around +import concurrent.futures + @python_app def double(x: float) -> float: @@ -11,9 +14,14 @@ def double(x: float) -> float: @python_app -def echo(x, string, stdout=None): +def echo(x_woo_woo_find_me: float, string: str, stdout=None) -> float: print(string) - return x * 5 + return x_woo_woo_find_me * 5 + + +def test_echo() -> None: + f = echo(x_woo_woo_find_me=3, string="hi") + f.result() @python_app diff --git a/parsl/tests/test_python_apps/test_simple.py b/parsl/tests/test_python_apps/test_simple.py index 007db7b923..351d525146 100644 --- a/parsl/tests/test_python_apps/test_simple.py +++ b/parsl/tests/test_python_apps/test_simple.py @@ -1,5 +1,9 @@ from parsl.app.app import python_app +from concurrent.futures import Future + +from typing import Dict, Union + @python_app def increment(x): @@ -14,19 +18,22 @@ def slow_increment(x, dur): def test_increment(depth=5): - futs = {0: 0} + futs = {0: 0} # type: Dict[int, Union[int, Future]] for i in range(1, depth): futs[i] = increment(futs[i - 1]) - x = sum([futs[i].result() for i in futs if not isinstance(futs[i], int)]) + # this is a slightly awkward rearrangement: we need to bind f so that mypy + # can take the type property proved by isinstance and carry it over to + # reason about if f.result() valid. + x = sum([f.result() for i in futs for f in [futs[i]] if isinstance(f, Future)]) assert x == sum(range(1, depth)), "[TEST] increment [FAILED]" def test_slow_increment(depth=5): - futs = {0: 0} + futs = {0: 0} # type: Dict[int, Union[int, Future]] for i in range(1, depth): futs[i] = slow_increment(futs[i - 1], 0.01) - x = sum([futs[i].result() for i in futs if not isinstance(futs[i], int)]) + x = sum([f.result() for i in futs for f in [futs[i]] if isinstance(f, Future)]) assert x == sum(range(1, depth)), "[TEST] slow_increment [FAILED]" diff --git a/parsl/usage_tracking/usage.py b/parsl/usage_tracking/usage.py index c2ed165e82..019832e00e 100644 --- a/parsl/usage_tracking/usage.py +++ b/parsl/usage_tracking/usage.py @@ -1,3 +1,4 @@ +from __future__ import annotations import uuid import time import hashlib @@ -9,19 +10,23 @@ import sys import platform +from typing import List + +from parsl.multiprocessing import forkProcess, ForkProcess from parsl.utils import setproctitle -from parsl.multiprocessing import ForkProcess from parsl.dataflow.states import States from parsl.version import VERSION as PARSL_VERSION +import parsl.dataflow.dflow # can't import just the symbol for DataFlowKernel because of mutually-recursive imports + logger = logging.getLogger(__name__) def async_process(fn): """ Decorator function to launch a function as a separate process """ - def run(*args, **kwargs): - proc = ForkProcess(target=fn, args=args, kwargs=kwargs, name="Usage-Tracking") + def run(*args, **kwargs) -> ForkProcess: + proc = forkProcess(target=fn, args=args, kwargs=kwargs, name="Usage-Tracking") proc.start() return proc @@ -90,8 +95,11 @@ class UsageTracker: """ - def __init__(self, dfk, ip='52.3.111.203', port=50077, - domain_name='tracking.parsl-project.org'): + def __init__(self, + dfk: parsl.dataflow.dflow.DataFlowKernel, + ip: str = '52.3.111.203', + port: int = 50077, + domain_name: str = 'tracking.parsl-project.org') -> None: """Initialize usage tracking unless the user has opted-out. We will try to resolve the hostname specified in kwarg:domain_name @@ -117,7 +125,7 @@ def __init__(self, dfk, ip='52.3.111.203', port=50077, self.sock_timeout = 5 self.UDP_PORT = port self.UDP_IP = None - self.procs = [] + self.procs: List[ForkProcess] = [] self.dfk = dfk self.config = self.dfk.config self.uuid = str(uuid.uuid4()) diff --git a/parsl/utils.py b/parsl/utils.py index fa83c03120..67622f835c 100644 --- a/parsl/utils.py +++ b/parsl/utils.py @@ -7,7 +7,8 @@ import time import typeguard from contextlib import contextmanager -from typing import List, Tuple, Union, Generator, IO, AnyStr, Dict +from typing import Callable, List, Sequence, Tuple, Union, Generator, IO, AnyStr, Dict, Any, Optional +from typing_extensions import Protocol, runtime_checkable import parsl from parsl.version import VERSION @@ -44,7 +45,7 @@ def get_version() -> str: @typeguard.typechecked -def get_all_checkpoints(rundir: str = "runinfo") -> List[str]: +def get_all_checkpoints(rundir: str = "runinfo") -> Sequence[str]: """Finds the checkpoints from all runs in the rundir. Kwargs: @@ -73,7 +74,7 @@ def get_all_checkpoints(rundir: str = "runinfo") -> List[str]: @typeguard.typechecked -def get_last_checkpoint(rundir: str = "runinfo") -> List[str]: +def get_last_checkpoint(rundir: str = "runinfo") -> Sequence[str]: """Finds the checkpoint from the last run, if one exists. Note that checkpoints are incremental, and this helper will not find @@ -159,6 +160,11 @@ def wtime_to_minutes(time_string: str) -> int: return total_mins +@runtime_checkable +class IsWrapper(Protocol): + __wrapped__: Callable + + class RepresentationMixin: """A mixin class for adding a __repr__ method. @@ -185,8 +191,12 @@ def __init__(self, first, second, third='three', fourth='fourth'): """ __max_width__ = 80 + # vs PR 1846: this has a type: ignore where I have more invasively changed the code to + # use type(self).__init__ and not checked if that works def __repr__(self) -> str: init = self.__init__ # type: ignore[misc] + # init: Any # to override something I don't understand with myppy vs the init, iswrapper test below + # init = type(self).__init__ # does this change from self.__init__ work? # This test looks for a single layer of wrapping performed by # functools.update_wrapper, commonly used in decorators. This will @@ -196,7 +206,7 @@ def __repr__(self) -> str: # decorators, or cope with other decorators which do not use # functools.update_wrapper. - if hasattr(init, '__wrapped__'): + if isinstance(init, IsWrapper): init = init.__wrapped__ argspec = inspect.getfullargspec(init) @@ -284,7 +294,8 @@ class Timer: """ - def __init__(self, callback, *args, interval=5, name=None): + # TODO: some kind of dependentish type here? eg Callable[X] and args has type X? + def __init__(self, callback: Callable, *args: Tuple[Any, ...], interval: float = 5, name: Optional[str] = None) -> None: """Initialize the Timer object We start the timer thread here diff --git a/requirements.txt b/requirements.txt index 07cdba8ba5..d1e94b2417 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ pyzmq>=17.1.2 typeguard>=2.10,<3 typing-extensions +globus-sdk<3 types-paramiko types-requests types-six six -globus-sdk dill tblib requests