Require type annotations throughout the interchange #2940

Merged · 4 commits · Nov 6, 2023
Changes from all commits
3 changes: 2 additions & 1 deletion mypy.ini
@@ -73,7 +73,8 @@ disallow_untyped_defs = True
disallow_any_expr = True

[mypy-parsl.executors.high_throughput.interchange.*]
check_untyped_defs = True
disallow_untyped_defs = True
warn_unreachable = True

[mypy-parsl.monitoring.*]
disallow_untyped_decorators = True
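
The mypy.ini hunk above tightens checking for the interchange module: check_untyped_defs only type-checks the bodies of unannotated functions, while disallow_untyped_defs makes an unannotated signature itself an error, and warn_unreachable flags statements mypy can prove never execute. A minimal sketch, not part of this PR, of what each new flag catches:

# illustration only; not PR code
def pull(count):              # disallow_untyped_defs:
    return count + 1          #   "Function is missing a type annotation"

def pull_typed(count: int) -> int:
    if isinstance(count, int):
        return count + 1
    return -1                 # warn_unreachable: "Statement is unreachable"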
67 changes: 37 additions & 30 deletions parsl/executors/high_throughput/interchange.py
@@ -1,4 +1,5 @@
#!/usr/bin/env python
import multiprocessing
import zmq
import os
import sys
@@ -13,7 +14,7 @@
import threading
import json

from typing import cast, Any, Dict, Set, Optional
from typing import cast, Any, Dict, NoReturn, Sequence, Set, Optional, Tuple

from parsl.utils import setproctitle
from parsl.version import VERSION as PARSL_VERSION
@@ -36,23 +37,23 @@ class ManagerLost(Exception):
''' Task lost due to manager loss. Manager is considered lost when multiple heartbeats
have been missed.
'''
def __init__(self, manager_id, hostname):
def __init__(self, manager_id: bytes, hostname: str) -> None:
self.manager_id = manager_id
self.tstamp = time.time()
self.hostname = hostname

def __str__(self):
def __str__(self) -> str:
return "Task failure due to loss of manager {} on host {}".format(self.manager_id.decode(), self.hostname)


class VersionMismatch(Exception):
''' Manager and Interchange versions do not match
'''
def __init__(self, interchange_version, manager_version):
def __init__(self, interchange_version: str, manager_version: str):
self.interchange_version = interchange_version
self.manager_version = manager_version

def __str__(self):
def __str__(self) -> str:
return "Manager version info {} does not match interchange version info {}, causing a critical failure".format(
self.manager_version,
self.interchange_version)
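
Manager IDs travel over ZMQ and arrive as bytes, which is why manager_id is now annotated bytes and decoded only for display. A hypothetical usage sketch (the identity and hostname values here are invented):

from parsl.executors.high_throughput.interchange import ManagerLost

err = ManagerLost(b"manager-001", "node042.cluster")
print(err)  # Task failure due to loss of manager manager-001 on host node042.cluster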
@@ -67,17 +68,17 @@ class Interchange:
4. Service single and batch requests from workers
"""
def __init__(self,
client_address="127.0.0.1",
client_address: str = "127.0.0.1",
interchange_address: Optional[str] = None,
client_ports=(50055, 50056, 50057),
worker_ports=None,
worker_port_range=(54000, 55000),
hub_address=None,
hub_port=None,
heartbeat_threshold=60,
logdir=".",
logging_level=logging.INFO,
poll_period=10,
client_ports: Tuple[int, int, int] = (50055, 50056, 50057),
worker_ports: Optional[Tuple[int, int]] = None,
worker_port_range: Tuple[int, int] = (54000, 55000),
hub_address: Optional[str] = None,
hub_port: Optional[int] = None,
heartbeat_threshold: int = 60,
logdir: str = ".",
logging_level: int = logging.INFO,
poll_period: int = 10,
) -> None:
"""
Parameters
@@ -191,7 +192,7 @@ def __init__(self,

logger.info("Platform info: {}".format(self.current_platform))

def get_tasks(self, count):
def get_tasks(self, count: int) -> Sequence[dict]:
""" Obtains a batch of tasks from the internal pending_task_queue

Parameters
@@ -216,7 +217,7 @@ def get_tasks(self, count):
return tasks
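
Annotating the return as Sequence[dict] rather than List[dict] advertises a read-only result: callers can iterate and index, but mypy rejects mutation through that type. A small self-contained sketch of the distinction:

from typing import Sequence

def batch() -> Sequence[dict]:
    return [{"task_id": 1}]      # a list is a valid Sequence

tasks = batch()
print(tasks[0]["task_id"])       # indexing and iteration are fine
# tasks.append({})               # mypy: "Sequence[dict]" has no attribute "append"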

@wrap_with_logs(target="interchange")
def task_puller(self):
def task_puller(self) -> NoReturn:
"""Pull tasks from the incoming tasks zmq pipe onto the internal
pending task queue
"""
@@ -237,7 +238,7 @@ def task_puller(self):
task_counter += 1
logger.debug(f"Fetched {task_counter} tasks so far")

def _create_monitoring_channel(self):
def _create_monitoring_channel(self) -> Optional[zmq.Socket]:
if self.hub_address and self.hub_port:
logger.info("Connecting to monitoring")
hub_channel = self.context.socket(zmq.DEALER)
@@ -248,7 +249,7 @@ def _create_monitoring_channel(self):
else:
return None

def _send_monitoring_info(self, hub_channel, manager: ManagerRecord):
def _send_monitoring_info(self, hub_channel: Optional[zmq.Socket], manager: ManagerRecord) -> None:
if hub_channel:
logger.info("Sending message {} to hub".format(manager))

@@ -259,7 +260,7 @@ def _send_monitoring_info(self, hub_channel, manager: ManagerRecord):
hub_channel.send_pyobj((MessageType.NODE_INFO, d))

@wrap_with_logs(target="interchange")
def _command_server(self):
def _command_server(self) -> NoReturn:
""" Command server to run async command to the interchange
"""
logger.debug("Command Server Starting")
@@ -326,7 +327,7 @@ def _command_server(self):
continue

@wrap_with_logs
def start(self):
def start(self) -> None:
""" Start the interchange
"""

@@ -382,7 +383,7 @@ def start(self):
logger.info("Processed {} tasks in {} seconds".format(self.count, delta))
logger.warning("Exiting")

def process_task_outgoing_incoming(self, interesting_managers, hub_channel, kill_event):
def process_task_outgoing_incoming(self, interesting_managers: Set[bytes], hub_channel: Optional[zmq.Socket], kill_event: threading.Event) -> None:
# Listen for requests for work
if self.task_outgoing in self.socks and self.socks[self.task_outgoing] == zmq.POLLIN:
logger.debug("starting task_outgoing section")
@@ -448,7 +449,7 @@ def process_task_outgoing_incoming(self, interesting_managers, hub_channel, kill
logger.error("Unexpected non-heartbeat message received from manager {}")
logger.debug("leaving task_outgoing section")

def process_tasks_to_send(self, interesting_managers):
def process_tasks_to_send(self, interesting_managers: Set[bytes]) -> None:
# If we had received any requests, check if there are tasks that could be passed

logger.debug("Managers count (interesting/total): {interesting}/{total}".format(
@@ -474,14 +475,14 @@ def process_tasks_to_send(self, interesting_managers):
tids = [t['task_id'] for t in tasks]
m['tasks'].extend(tids)
m['idle_since'] = None
logger.debug("Sent tasks: {} to manager {}".format(tids, manager_id))
logger.debug("Sent tasks: {} to manager {!r}".format(tids, manager_id))
# recompute real_capacity after sending tasks
real_capacity = m['max_capacity'] - tasks_inflight
if real_capacity > 0:
logger.debug("Manager {} has free capacity {}".format(manager_id, real_capacity))
logger.debug("Manager {!r} has free capacity {}".format(manager_id, real_capacity))
# ... so keep it in the interesting_managers list
else:
logger.debug("Manager {} is now saturated".format(manager_id))
logger.debug("Manager {!r} is now saturated".format(manager_id))
interesting_managers.remove(manager_id)
else:
interesting_managers.remove(manager_id)
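
The switch from {} to {!r} for manager_id makes the bytes repr explicit. A plausible motivation (an assumption; the diff does not say): formatting bytes with plain {} goes through str(), which raises BytesWarning under python -bb, while an explicit !r does not. For example:

manager_id = b"e4a1"
print("Manager {} is saturated".format(manager_id))    # implicit str(bytes); warns under python -bb
print("Manager {!r} is saturated".format(manager_id))  # explicit repr: Manager b'e4a1' is saturated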
@@ -490,7 +491,7 @@ def process_tasks_to_send(self, interesting_managers):
else:
logger.debug("either no interesting managers or no tasks, so skipping manager pass")

def process_results_incoming(self, interesting_managers, hub_channel):
def process_results_incoming(self, interesting_managers: Set[bytes], hub_channel: Optional[zmq.Socket]) -> None:
# Receive any results and forward to client
if self.results_incoming in self.socks and self.socks[self.results_incoming] == zmq.POLLIN:
logger.debug("entering results_incoming section")
@@ -508,6 +509,12 @@ def process_results_incoming(self, interesting_managers, hub_channel):
# process this for task ID and forward to executor
b_messages.append((p_message, r))
elif r['type'] == 'monitoring':
# the monitoring code makes the assumption that no
# monitoring messages will be received if monitoring
# is not configured, and that hub_channel will only
be None when monitoring is not configured.
assert hub_channel is not None

hub_channel.send_pyobj(r['payload'])
elif r['type'] == 'heartbeat':
logger.debug(f"Manager {manager_id!r} sent heartbeat via results connection")
@@ -552,7 +559,7 @@ def process_results_incoming(self, interesting_managers, hub_channel):
interesting_managers.add(manager_id)
logger.debug("leaving results_incoming section")

def expire_bad_managers(self, interesting_managers, hub_channel):
def expire_bad_managers(self, interesting_managers: Set[bytes], hub_channel: Optional[zmq.Socket]) -> None:
bad_managers = [(manager_id, m) for (manager_id, m) in self._ready_managers.items() if
time.time() - m['last_heartbeat'] > self.heartbeat_threshold]
for (manager_id, m) in bad_managers:
@@ -576,7 +583,7 @@ def expire_bad_managers(self, interesting_managers, hub_channel):
interesting_managers.remove(manager_id)


def start_file_logger(filename, level=logging.DEBUG, format_string=None):
def start_file_logger(filename: str, level: int = logging.DEBUG, format_string: Optional[str] = None) -> None:
"""Add a stream log handler.

Parameters
@@ -608,7 +615,7 @@ def start_file_logger(filename, level=logging.DEBUG, format_string=None):


@wrap_with_logs(target="interchange")
def starter(comm_q, *args, **kwargs):
def starter(comm_q: multiprocessing.Queue, *args: Any, **kwargs: Any) -> None:
"""Start the interchange process

The executor is expected to call this function. The args, kwargs match that of the Interchange.__init__
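
Note that *args: Any annotates each positional argument as Any, not the tuple itself, and **kwargs: Any annotates each keyword value; inside the body, args is Tuple[Any, ...] and kwargs is Dict[str, Any]. A quick illustration (the function name is invented):

from typing import Any

def starter_sketch(*args: Any, **kwargs: Any) -> None:
    print(args, kwargs)

starter_sketch(1, "two", retries=3)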