Skip to content

Commit

Permalink
Reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Dec 20, 2024
1 parent dec2341 commit 9a41bf8
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 26 deletions.
14 changes: 8 additions & 6 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@ class ClientConnectionError(Exception):
pass


CONNECT_MSG = "CONNECT"
DISCONNECT_MSG = "DISCONNECT"
CONNECT_MSG = b"CONNECT"
DISCONNECT_MSG = b"DISCONNECT"
ACK_MSG = b"ACK"


class Client:
DEFAULT_MAX_RETRIES = 5
DEFAULT_ACK_TIMEOUT = 5
_receiver_task: asyncio.Task[None] | None

def __init__(
self,
Expand Down Expand Up @@ -51,7 +50,7 @@ def __init__(
self.socket.curve_publickey = client_public
self.socket.curve_serverkey = token.encode("utf-8")

self._receiver_task = None
self._receiver_task: asyncio.Task[None] | None = None

async def __aenter__(self) -> Self:
await self.connect()
Expand Down Expand Up @@ -108,15 +107,18 @@ async def _receiver(self) -> None:
await asyncio.sleep(1)
self.socket.connect(self.url)

async def send(self, message: str, retries: int | None = None) -> None:
async def send(self, message: str | bytes, retries: int | None = None) -> None:
self._ack_event.clear()

if isinstance(message, str):
message = message.encode("utf-8")

backoff = 1
if retries is None:
retries = self.DEFAULT_MAX_RETRIES
while retries >= 0:
try:
await self.socket.send_multipart([b"", message.encode("utf-8")])
await self.socket.send_multipart([b"", message])
try:
await asyncio.wait_for(
self._ack_event.wait(), timeout=self._ack_timeout
Expand Down
14 changes: 6 additions & 8 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ def ensemble(self) -> Ensemble:
return self._ensemble

async def handle_client(self, dealer: bytes, frame: bytes) -> None:
raw_msg = frame.decode("utf-8")
if raw_msg == CONNECT_MSG:
if frame == CONNECT_MSG:
self._clients_connected.add(dealer)
self._clients_empty.clear()
current_snapshot_dict = self._ensemble.snapshot.to_dict()
Expand All @@ -209,12 +208,12 @@ async def handle_client(self, dealer: bytes, frame: bytes) -> None:
await self._router_socket.send_multipart(
[dealer, b"", event_to_json(event).encode("utf-8")]
)
elif raw_msg == DISCONNECT_MSG:
elif frame == DISCONNECT_MSG:
self._clients_connected.discard(dealer)
if not self._clients_connected:
self._clients_empty.set()
else:
event = event_from_json(raw_msg)
event = event_from_json(frame.decode("utf-8"))
if type(event) is EEUserCancel:
logger.debug("Client asked to cancel.")
self._signal_cancel()
Expand All @@ -223,16 +222,15 @@ async def handle_client(self, dealer: bytes, frame: bytes) -> None:
self.stop()

async def handle_dispatch(self, dealer: bytes, frame: bytes) -> None:
raw_msg = frame.decode("utf-8")
if raw_msg == CONNECT_MSG:
if frame == CONNECT_MSG:
self._dispatchers_connected.add(dealer)
self._dispatchers_empty.clear()
elif raw_msg == DISCONNECT_MSG:
elif frame == DISCONNECT_MSG:
self._dispatchers_connected.discard(dealer)
if not self._dispatchers_connected:
self._dispatchers_empty.set()
else:
event = dispatch_event_from_json(raw_msg)
event = dispatch_event_from_json(frame.decode("utf-8"))
if event.ensemble != self.ensemble.id_:
logger.info(
"Got event from evaluator "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ async def test_evaluator_handles_dispatchers_connected(
):
evaluator = EnsembleEvaluator(TestEnsemble(0, 2, 2, id_="0"), make_ee_config())

await evaluator.handle_dispatch(b"dispatcher-1", CONNECT_MSG.encode("utf-8"))
await evaluator.handle_dispatch(b"dispatcher-2", CONNECT_MSG.encode("utf-8"))
await evaluator.handle_dispatch(b"dispatcher-1", CONNECT_MSG)
await evaluator.handle_dispatch(b"dispatcher-2", CONNECT_MSG)
assert not evaluator._dispatchers_empty.is_set()
assert evaluator._dispatchers_connected == {b"dispatcher-1", b"dispatcher-2"}
await evaluator.handle_dispatch(b"dispatcher-1", DISCONNECT_MSG.encode("utf-8"))
await evaluator.handle_dispatch(b"dispatcher-2", DISCONNECT_MSG.encode("utf-8"))
await evaluator.handle_dispatch(b"dispatcher-1", DISCONNECT_MSG)
await evaluator.handle_dispatch(b"dispatcher-2", DISCONNECT_MSG)
assert evaluator._dispatchers_empty.is_set()


Expand Down
8 changes: 2 additions & 6 deletions tests/ert/unit_tests/ensemble_evaluator/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ async def mock_event_handler(router_socket):
while True:
dealer, _, frame = await router_socket.recv_multipart()
await router_socket.send_multipart([dealer, b"", ACK_MSG])
frame = frame.decode("utf-8")
messages.append((dealer.decode("utf-8"), frame))
if frame == DISCONNECT_MSG:
break
Expand Down Expand Up @@ -76,14 +75,13 @@ async def mock_event_handler(router_socket):
dealer, _, frame = await router_socket.recv_multipart()
await router_socket.send_multipart([dealer, b"", ACK_MSG])
dealer = dealer.decode("utf-8")
frame = frame.decode("utf-8")
if frame == CONNECT_MSG:
connected = True
elif frame == DISCONNECT_MSG:
connected = False
return
else:
event = event_from_json(frame)
event = event_from_json(frame.decode("utf-8"))
assert connected
assert type(event) is EEUserDone

Expand Down Expand Up @@ -137,15 +135,13 @@ async def mock_event_handler(router_socket):
while True:
dealer, _, frame = await router_socket.recv_multipart()
await router_socket.send_multipart([dealer, b"", ACK_MSG])
dealer = dealer.decode("utf-8")
frame = frame.decode("utf-8")
if frame == CONNECT_MSG:
connected = True
elif frame == DISCONNECT_MSG:
connected = False
return
else:
event = event_from_json(frame)
event = event_from_json(frame.decode("utf-8"))
assert connected
assert type(event) is EEUserCancel

Expand Down
3 changes: 1 addition & 2 deletions tests/ert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,10 @@ async def _handler(self):
while True:
try:
dealer, __, frame = await self.router_socket.recv_multipart()
frame = frame.decode("utf-8")
if frame in {CONNECT_MSG, DISCONNECT_MSG} or self.value == 0:
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
if frame not in {CONNECT_MSG, DISCONNECT_MSG} and self.value != 1:
self.messages.append(frame)
self.messages.append(frame.decode("utf-8"))
except asyncio.CancelledError:
break

Expand Down

0 comments on commit 9a41bf8

Please sign in to comment.