Skip to content

Commit

Permalink
Implementing router-dealer pattern with custom acknowledgments with zmq
Browse files Browse the repository at this point in the history
 - dealers always wait for acknowledgment from the evaluator
 - removing websockets and wait_for_evaluator
 - Settup encryption with curve
 - each dealer (client, dispatcher) will get a unique name
 - Monitor is an advanced version Client
 - _server_started.wait() is to signal that zmq router socket is bound
 - Use TCP protocol only when using LSF, SLURM or TORQUE queues
 -- Use ipc_protocol when using LOCAL driver
 - Remove certificate
 - Remove synced _send from Client
 - Remove cert generator
 - Remove ClientConnectionClosedOK
 - Add test for new connection while closing down evaluator
 - Add test for handle dispatcher and dispatcher messages in evaluator
 - Add tests for ipc and tcp ee config
 - Add test for clear connect and disconnect of Monitor
 - Set a a correct protocol for everestserver
  • Loading branch information
xjules committed Dec 20, 2024
1 parent 2a4b6be commit 0f1bb30
Show file tree
Hide file tree
Showing 32 changed files with 762 additions and 1,192 deletions.
1 change: 0 additions & 1 deletion docs/ert/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
("py:class", "pydantic.types.PositiveInt"),
("py:class", "LibresFacade"),
("py:class", "pandas.core.frame.DataFrame"),
("py:class", "websockets.server.WebSocketServerProtocol"),
("py:class", "EnsembleReader"),
]
nitpick_ignore_regex = [
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ dependencies = [
"python-dateutil",
"python-multipart", # extra dependency for fastapi
"pyyaml",
"pyzmq",
"qtpy",
"requests",
"resfo",
Expand All @@ -68,7 +69,6 @@ dependencies = [
"tqdm>=4.62.0",
"typing_extensions>=4.5",
"uvicorn >= 0.17.0",
"websockets",
"xarray",
"xtgeo >= 3.3.0",
]
Expand Down
9 changes: 1 addition & 8 deletions src/_ert/forward_model_runner/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,14 @@ def _setup_reporters(
ens_id,
dispatch_url,
ee_token=None,
ee_cert_path=None,
experiment_id=None,
) -> list[reporting.Reporter]:
reporters: list[reporting.Reporter] = []
if is_interactive_run:
reporters.append(reporting.Interactive())
elif ens_id and experiment_id is None:
reporters.append(reporting.File())
reporters.append(
reporting.Event(
evaluator_url=dispatch_url, token=ee_token, cert_path=ee_cert_path
)
)
reporters.append(reporting.Event(evaluator_url=dispatch_url, token=ee_token))
else:
reporters.append(reporting.File())
return reporters
Expand Down Expand Up @@ -123,7 +118,6 @@ def main(args):
experiment_id = jobs_data.get("experiment_id")
ens_id = jobs_data.get("ens_id")
ee_token = jobs_data.get("ee_token")
ee_cert_path = jobs_data.get("ee_cert_path")
dispatch_url = jobs_data.get("dispatch_url")

is_interactive_run = len(parsed_args.job) > 0
Expand All @@ -132,7 +126,6 @@ def main(args):
ens_id,
dispatch_url,
ee_token,
ee_cert_path,
experiment_id,
)

Expand Down
229 changes: 126 additions & 103 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
from __future__ import annotations

import asyncio
import logging
import ssl
from typing import Any, AnyStr, Self

from websockets.asyncio.client import ClientConnection, connect
from websockets.datastructures import Headers
from websockets.exceptions import (
ConnectionClosedError,
ConnectionClosedOK,
InvalidHandshake,
InvalidURI,
)
import uuid
from abc import abstractmethod
from typing import Any, Self

from _ert.async_utils import new_event_loop
import zmq
import zmq.asyncio

logger = logging.getLogger(__name__)

Expand All @@ -21,112 +16,140 @@ class ClientConnectionError(Exception):
pass


class ClientConnectionClosedOK(Exception):
pass
CONNECT_MSG = b"CONNECT"
DISCONNECT_MSG = b"DISCONNECT"
ACK_MSG = b"ACK"


class Client:
DEFAULT_MAX_RETRIES = 10
DEFAULT_TIMEOUT_MULTIPLIER = 5
CONNECTION_TIMEOUT = 60

def __enter__(self) -> Self:
return self

def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
if self.websocket is not None:
self.loop.run_until_complete(self.websocket.close())
self.loop.close()

async def __aenter__(self) -> "Client":
return self

async def __aexit__(
self, exc_type: Any, exc_value: Any, exc_traceback: Any
) -> None:
if self.websocket is not None:
await self.websocket.close()
DEFAULT_ACK_TIMEOUT = 5

def __init__(
self,
url: str,
token: str | None = None,
cert: str | bytes | None = None,
max_retries: int | None = None,
timeout_multiplier: int | None = None,
dealer_name: str | None = None,
ack_timeout: float | None = None,
) -> None:
if max_retries is None:
max_retries = self.DEFAULT_MAX_RETRIES
if timeout_multiplier is None:
timeout_multiplier = self.DEFAULT_TIMEOUT_MULTIPLIER
if url is None:
raise ValueError("url was None")
self._ack_timeout = ack_timeout or self.DEFAULT_ACK_TIMEOUT
self.url = url
self.token = token
self._additional_headers = Headers()

self._ack_event: asyncio.Event = asyncio.Event()
self.context = zmq.asyncio.Context()
self.socket = self.context.socket(zmq.DEALER)
# this is to avoid blocking the event loop when closing the socket
# wherein the linger is set to 0 to discard all messages in the queue
self.socket.setsockopt(zmq.LINGER, 0)
self.dealer_id = dealer_name or f"dispatch-{uuid.uuid4().hex[:8]}"
self.socket.setsockopt_string(zmq.IDENTITY, self.dealer_id)

if token is not None:
self._additional_headers["token"] = token

# Mimics the behavior of the ssl argument when connection to
# websockets. If none is specified it will deduce based on the url,
# if True it will enforce TLS, and if you want to use self signed
# certificates you need to pass an ssl_context with the certificate
# loaded.
self._ssl_context: bool | ssl.SSLContext | None = None
if cert is not None:
self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self._ssl_context.load_verify_locations(cadata=cert)
elif url.startswith("wss"):
self._ssl_context = True

self._max_retries = max_retries
self._timeout_multiplier = timeout_multiplier
self.websocket: ClientConnection | None = None
self.loop = new_event_loop()

async def get_websocket(self) -> ClientConnection:
return await connect(
self.url,
ssl=self._ssl_context,
additional_headers=self._additional_headers,
open_timeout=self.CONNECTION_TIMEOUT,
ping_timeout=self.CONNECTION_TIMEOUT,
ping_interval=self.CONNECTION_TIMEOUT,
close_timeout=self.CONNECTION_TIMEOUT,
)
client_public, client_secret = zmq.curve_keypair()
self.socket.curve_secretkey = client_secret
self.socket.curve_publickey = client_public
self.socket.curve_serverkey = token.encode("utf-8")

self._receiver_task: asyncio.Task[None] | None = None

async def __aenter__(self) -> Self:
await self.connect()
return self

async def _send(self, msg: AnyStr) -> None:
for retry in range(self._max_retries + 1):
async def __aexit__(
self, exc_type: Any, exc_value: Any, exc_traceback: Any
) -> None:
try:
await self.send(DISCONNECT_MSG)
except ClientConnectionError:
logger.error("No ack for dealer disconnection. Connection is down!")
finally:
self.socket.disconnect(self.url)
await self._term_receiver_task()
self.term()

def term(self) -> None:
self.socket.close()
self.context.term()

async def _term_receiver_task(self) -> None:
if self._receiver_task and not self._receiver_task.done():
self._receiver_task.cancel()
await asyncio.gather(self._receiver_task, return_exceptions=True)
self._receiver_task = None

async def connect(self) -> None:
self.socket.connect(self.url)
await self._term_receiver_task()
self._receiver_task = asyncio.create_task(self._receiver())
try:
await self.send(CONNECT_MSG, retries=1)
except ClientConnectionError:
await self._term_receiver_task()
self.term()
raise

@abstractmethod
async def process_message(self, msg: str) -> None:
"""
This method is implemented in the Monitor, which stores the messages in a queue.
Args:
msg (str): Message (event) to be processed
"""

async def _receiver(self) -> None:
while True:
try:
if self.websocket is None:
self.websocket = await self.get_websocket()
await self.websocket.send(msg)
return
except ConnectionClosedOK as exception:
error_msg = (
f"Connection closed received from the server {self.url}! "
f" Exception from {type(exception)}: {exception!s}"
_, raw_msg = await self.socket.recv_multipart()
if raw_msg == ACK_MSG:
self._ack_event.set()
else:
await self.process_message(raw_msg.decode("utf-8"))
except zmq.ZMQError as exc:
logger.debug(
f"{self.dealer_id} connection to evaluator went down, reconnecting: {exc}"
)
raise ClientConnectionClosedOK(error_msg) from exception
except (TimeoutError, InvalidHandshake, InvalidURI, OSError) as exception:
if retry == self._max_retries:
error_msg = (
f"Not able to establish the "
f"websocket connection {self.url}! Max retries reached!"
" Check for firewall issues."
f" Exception from {type(exception)}: {exception!s}"
await asyncio.sleep(0)
self.socket.connect(self.url)

async def send(self, message: str | bytes, retries: int | None = None) -> None:
self._ack_event.clear()

if isinstance(message, str):
message = message.encode("utf-8")

backoff = 1
if retries is None:
retries = self.DEFAULT_MAX_RETRIES
while retries >= 0:
try:
await self.socket.send_multipart([b"", message])
try:
await asyncio.wait_for(
self._ack_event.wait(), timeout=self._ack_timeout
)
raise ClientConnectionError(error_msg) from exception
except ConnectionClosedError as exception:
if retry == self._max_retries:
error_msg = (
f"Not been able to send the event"
f" to {self.url}! Max retries reached!"
f" Exception from {type(exception)}: {exception!s}"
return
except TimeoutError:
logger.warning(
f"{self.dealer_id} failed to get acknowledgment on the {message!r}. Resending."
)
raise ClientConnectionError(error_msg) from exception
await asyncio.sleep(0.2 + self._timeout_multiplier * retry)
self.websocket = None

def send(self, msg: AnyStr) -> None:
self.loop.run_until_complete(self._send(msg))
except zmq.ZMQError as exc:
logger.debug(
f"{self.dealer_id} connection to evaluator went down, reconnecting: {exc}"
)
except asyncio.CancelledError:
self.term()
raise

retries -= 1
if retries > 0:
logger.info(f"Retrying... ({retries} attempts left)")
await asyncio.sleep(backoff)
# this call is idempotent
self.socket.connect(self.url)
backoff = min(backoff * 2, 10) # Exponential backoff
raise ClientConnectionError(
f"{self.dealer_id} Failed to send {message!r} after retries!"
)
Loading

0 comments on commit 0f1bb30

Please sign in to comment.