From 9a41bf8ffc5acc744fbbaf752d6fe1e4a4cc8761 Mon Sep 17 00:00:00 2001 From: xjules Date: Fri, 20 Dec 2024 09:46:47 +0100 Subject: [PATCH] Reviewer comments --- src/_ert/forward_model_runner/client.py | 14 ++++++++------ src/ert/ensemble_evaluator/evaluator.py | 14 ++++++-------- .../ensemble_evaluator/test_ensemble_evaluator.py | 8 ++++---- .../unit_tests/ensemble_evaluator/test_monitor.py | 8 ++------ tests/ert/utils.py | 3 +-- 5 files changed, 21 insertions(+), 26 deletions(-) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 56b2cc9cb37..9e16e7b42b2 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -15,15 +15,14 @@ class ClientConnectionError(Exception): pass -CONNECT_MSG = "CONNECT" -DISCONNECT_MSG = "DISCONNECT" +CONNECT_MSG = b"CONNECT" +DISCONNECT_MSG = b"DISCONNECT" ACK_MSG = b"ACK" class Client: DEFAULT_MAX_RETRIES = 5 DEFAULT_ACK_TIMEOUT = 5 - _receiver_task: asyncio.Task[None] | None def __init__( self, @@ -51,7 +50,7 @@ def __init__( self.socket.curve_publickey = client_public self.socket.curve_serverkey = token.encode("utf-8") - self._receiver_task = None + self._receiver_task: asyncio.Task[None] | None = None async def __aenter__(self) -> Self: await self.connect() @@ -108,15 +107,18 @@ async def _receiver(self) -> None: await asyncio.sleep(1) self.socket.connect(self.url) - async def send(self, message: str, retries: int | None = None) -> None: + async def send(self, message: str | bytes, retries: int | None = None) -> None: self._ack_event.clear() + if isinstance(message, str): + message = message.encode("utf-8") + backoff = 1 if retries is None: retries = self.DEFAULT_MAX_RETRIES while retries >= 0: try: - await self.socket.send_multipart([b"", message.encode("utf-8")]) + await self.socket.send_multipart([b"", message]) try: await asyncio.wait_for( self._ack_event.wait(), timeout=self._ack_timeout diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index fc1a3e11556..dd421db19e7 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -197,8 +197,7 @@ def ensemble(self) -> Ensemble: return self._ensemble async def handle_client(self, dealer: bytes, frame: bytes) -> None: - raw_msg = frame.decode("utf-8") - if raw_msg == CONNECT_MSG: + if frame == CONNECT_MSG: self._clients_connected.add(dealer) self._clients_empty.clear() current_snapshot_dict = self._ensemble.snapshot.to_dict() @@ -209,12 +208,12 @@ async def handle_client(self, dealer: bytes, frame: bytes) -> None: await self._router_socket.send_multipart( [dealer, b"", event_to_json(event).encode("utf-8")] ) - elif raw_msg == DISCONNECT_MSG: + elif frame == DISCONNECT_MSG: self._clients_connected.discard(dealer) if not self._clients_connected: self._clients_empty.set() else: - event = event_from_json(raw_msg) + event = event_from_json(frame.decode("utf-8")) if type(event) is EEUserCancel: logger.debug("Client asked to cancel.") self._signal_cancel() @@ -223,16 +222,15 @@ async def handle_client(self, dealer: bytes, frame: bytes) -> None: self.stop() async def handle_dispatch(self, dealer: bytes, frame: bytes) -> None: - raw_msg = frame.decode("utf-8") - if raw_msg == CONNECT_MSG: + if frame == CONNECT_MSG: self._dispatchers_connected.add(dealer) self._dispatchers_empty.clear() - elif raw_msg == DISCONNECT_MSG: + elif frame == DISCONNECT_MSG: self._dispatchers_connected.discard(dealer) if not self._dispatchers_connected: self._dispatchers_empty.set() else: - event = dispatch_event_from_json(raw_msg) + event = dispatch_event_from_json(frame.decode("utf-8")) if event.ensemble != self.ensemble.id_: logger.info( "Got event from evaluator " diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py index 8799d62c515..3959beaa273 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py @@ -81,12 +81,12 @@ async def test_evaluator_handles_dispatchers_connected( ): evaluator = EnsembleEvaluator(TestEnsemble(0, 2, 2, id_="0"), make_ee_config()) - await evaluator.handle_dispatch(b"dispatcher-1", CONNECT_MSG.encode("utf-8")) - await evaluator.handle_dispatch(b"dispatcher-2", CONNECT_MSG.encode("utf-8")) + await evaluator.handle_dispatch(b"dispatcher-1", CONNECT_MSG) + await evaluator.handle_dispatch(b"dispatcher-2", CONNECT_MSG) assert not evaluator._dispatchers_empty.is_set() assert evaluator._dispatchers_connected == {b"dispatcher-1", b"dispatcher-2"} - await evaluator.handle_dispatch(b"dispatcher-1", DISCONNECT_MSG.encode("utf-8")) - await evaluator.handle_dispatch(b"dispatcher-2", DISCONNECT_MSG.encode("utf-8")) + await evaluator.handle_dispatch(b"dispatcher-1", DISCONNECT_MSG) + await evaluator.handle_dispatch(b"dispatcher-2", DISCONNECT_MSG) assert evaluator._dispatchers_empty.is_set() diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py index c3b58a6be3c..4fc083fbba0 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py @@ -37,7 +37,6 @@ async def mock_event_handler(router_socket): while True: dealer, _, frame = await router_socket.recv_multipart() await router_socket.send_multipart([dealer, b"", ACK_MSG]) - frame = frame.decode("utf-8") messages.append((dealer.decode("utf-8"), frame)) if frame == DISCONNECT_MSG: break @@ -76,14 +75,13 @@ async def mock_event_handler(router_socket): dealer, _, frame = await router_socket.recv_multipart() await router_socket.send_multipart([dealer, b"", ACK_MSG]) dealer = dealer.decode("utf-8") - frame = frame.decode("utf-8") if frame == CONNECT_MSG: connected = True elif frame == DISCONNECT_MSG: connected = False return else: - event = event_from_json(frame) + event = event_from_json(frame.decode("utf-8")) assert connected assert type(event) is EEUserDone @@ -137,15 +135,13 @@ async def mock_event_handler(router_socket): while True: dealer, _, frame = await router_socket.recv_multipart() await router_socket.send_multipart([dealer, b"", ACK_MSG]) - dealer = dealer.decode("utf-8") - frame = frame.decode("utf-8") if frame == CONNECT_MSG: connected = True elif frame == DISCONNECT_MSG: connected = False return else: - event = event_from_json(frame) + event = event_from_json(frame.decode("utf-8")) assert connected assert type(event) is EEUserCancel diff --git a/tests/ert/utils.py b/tests/ert/utils.py index bb79072f93c..1ea358f783c 100644 --- a/tests/ert/utils.py +++ b/tests/ert/utils.py @@ -115,11 +115,10 @@ async def _handler(self): while True: try: dealer, __, frame = await self.router_socket.recv_multipart() - frame = frame.decode("utf-8") if frame in {CONNECT_MSG, DISCONNECT_MSG} or self.value == 0: await self.router_socket.send_multipart([dealer, b"", ACK_MSG]) if frame not in {CONNECT_MSG, DISCONNECT_MSG} and self.value != 1: - self.messages.append(frame) + self.messages.append(frame.decode("utf-8")) except asyncio.CancelledError: break