diff --git a/parsl/executors/high_throughput/errors.py b/parsl/executors/high_throughput/errors.py
index 002c125c73..4db7907523 100644
--- a/parsl/executors/high_throughput/errors.py
+++ b/parsl/executors/high_throughput/errors.py
@@ -10,3 +10,13 @@ def __repr__(self):
 
     def __str__(self):
         return self.__repr__()
+
+
+class CommandClientTimeoutError(Exception):
+    """Raised when the command client times out waiting for a response.
+    """
+
+
+class CommandClientBadError(Exception):
+    """Raised when the command client is bad from an earlier timeout.
+    """
diff --git a/parsl/executors/high_throughput/zmq_pipes.py b/parsl/executors/high_throughput/zmq_pipes.py
index 72c1422d4d..8a730c1c38 100644
--- a/parsl/executors/high_throughput/zmq_pipes.py
+++ b/parsl/executors/high_throughput/zmq_pipes.py
@@ -3,8 +3,11 @@
 import zmq
 import logging
 import threading
+import time
 
 from parsl import curvezmq
+from parsl.errors import InternalConsistencyError
+from parsl.executors.high_throughput.errors import CommandClientBadError, CommandClientTimeoutError
 
 logger = logging.getLogger(__name__)
 
@@ -31,6 +34,7 @@ def __init__(self, zmq_context: curvezmq.ClientContext, ip_address, port_range):
         self.port = None
         self.create_socket_and_bind()
         self._lock = threading.Lock()
+        self.ok = True
 
     def create_socket_and_bind(self):
         """ Creates socket and binds to a port.
@@ -46,7 +50,7 @@ def create_socket_and_bind(self):
         else:
             self.zmq_socket.bind("tcp://{}:{}".format(self.ip_address, self.port))
 
-    def run(self, message, max_retries=3):
+    def run(self, message, max_retries=3, timeout_s=None):
         """ This function needs to be fast at the same time aware of
         the possibility of ZMQ pipes overflowing.
 
@@ -54,13 +58,43 @@ def run(self, message, max_retries=3):
         in ZMQ sockets reaching a broken state once there are ~10k tasks in flight.
         This issue can be magnified if each the serialized buffer itself is larger.
         """
+        if not self.ok:
+            raise CommandClientBadError()
+
+        start_time_s = time.monotonic()
+
         reply = '__PARSL_ZMQ_PIPES_MAGIC__'
         with self._lock:
             for _ in range(max_retries):
                 try:
                     logger.debug("Sending command client command")
+
+                    if timeout_s is not None:
+                        remaining_time_s = start_time_s + timeout_s - time.monotonic()
+                        poll_result = self.zmq_socket.poll(timeout=remaining_time_s * 1000, flags=zmq.POLLOUT)
+                        if poll_result == zmq.POLLOUT:
+                            pass  # this is OK, so continue
+                        elif poll_result == 0:
+                            raise CommandClientTimeoutError("Waiting for command channel to be ready for a command")
+                        else:
+                            raise InternalConsistencyError(f"ZMQ poll returned unexpected value: {poll_result}")
+
                     self.zmq_socket.send_pyobj(message, copy=True)
-                    logger.debug("Waiting for command client response")
+
+                    if timeout_s is not None:
+                        logger.debug("Polling for command client response or timeout")
+                        remaining_time_s = start_time_s + timeout_s - time.monotonic()
+                        poll_result = self.zmq_socket.poll(timeout=remaining_time_s * 1000, flags=zmq.POLLIN)
+                        if poll_result == zmq.POLLIN:
+                            pass  # this is OK, so continue
+                        elif poll_result == 0:
+                            logger.error("Command timed-out - command client is now bad forever")
+                            self.ok = False
+                            raise CommandClientTimeoutError("Waiting for a reply from command channel")
+                        else:
+                            raise InternalConsistencyError(f"ZMQ poll returned unexpected value: {poll_result}")
+
+                    logger.debug("Receiving command client response")
                     reply = self.zmq_socket.recv_pyobj()
                     logger.debug("Received command client response")
                 except zmq.ZMQError:
diff --git a/parsl/tests/test_htex/test_command_client_timeout.py b/parsl/tests/test_htex/test_command_client_timeout.py
new file mode 100644
index 0000000000..4abdf991bc
--- /dev/null
+++ b/parsl/tests/test_htex/test_command_client_timeout.py
@@ -0,0 +1,69 @@
+import pytest
+import threading
+import time
+import zmq
+from parsl import curvezmq
+from parsl.executors.high_throughput.zmq_pipes import CommandClient
+from parsl.executors.high_throughput.errors import CommandClientTimeoutError, CommandClientBadError
+
+
+# Time constant used for timeout tests: various delays and
+# timeouts will be appropriate multiples of this, but the
+# value of T itself should not matter too much as long as
+# it is big enough for zmq connections to happen successfully.
+T = 0.25
+
+
+@pytest.mark.local
+def test_command_not_sent() -> None:
+    """Tests timeout on command send.
+    """
+    ctx = curvezmq.ClientContext(None)
+
+    # RFC6335 ephemeral port range
+    cc = CommandClient(ctx, "127.0.0.1", (49152, 65535))
+
+    # cc will now wait for a connection, but we won't do anything to make the
+    # other side of the connection exist, so any command given to cc should
+    # timeout.
+
+    with pytest.raises(CommandClientTimeoutError):
+        cc.run("SOMECOMMAND", timeout_s=T)
+
+    cc.close()
+
+
+@pytest.mark.local
+def test_command_ignored() -> None:
+    """Tests timeout on command response.
+    Tests that we timeout after a response and that the command client
+    sets itself into a bad state.
+
+    This only tests sequential access to the command client, even though
+    htex makes multithreaded use of the command client: see issue #3376 about
+    that lack of thread safety.
+    """
+    ctx = curvezmq.ClientContext(None)
+
+    # RFC6335 ephemeral port range
+    cc = CommandClient(ctx, "127.0.0.1", (49152, 65535))
+
+    ic_ctx = curvezmq.ServerContext(None)
+    ic_channel = ic_ctx.socket(zmq.REP)
+    ic_channel.connect(f"tcp://127.0.0.1:{cc.port}")
+
+    with pytest.raises(CommandClientTimeoutError):
+        cc.run("SLOW_COMMAND", timeout_s=T)
+
+    req = ic_channel.recv_pyobj()
+    assert req == "SLOW_COMMAND", "Should have received command on interchange side"
+    assert not cc.ok, "CommandClient should have set itself to bad"
+
+    with pytest.raises(CommandClientBadError):
+        cc.run("ANOTHER_COMMAND")
+
+    cc.close()
+    ctx.term()
+
+    ic_channel.close()
+    ic_ctx.term()