Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace websockets with zmq #9173

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/ert/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
("py:class", "pydantic.types.PositiveInt"),
("py:class", "LibresFacade"),
("py:class", "pandas.core.frame.DataFrame"),
("py:class", "websockets.server.WebSocketServerProtocol"),
("py:class", "EnsembleReader"),
]
nitpick_ignore_regex = [
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ dependencies = [
"python-dateutil",
"python-multipart", # extra dependency for fastapi
"pyyaml",
"pyzmq",
"qtpy",
"requests",
"resfo",
Expand All @@ -68,7 +69,6 @@ dependencies = [
"tqdm>=4.62.0",
"typing_extensions>=4.5",
"uvicorn >= 0.17.0",
"websockets",
"xarray",
"xtgeo >= 3.3.0",
]
Expand Down
9 changes: 1 addition & 8 deletions src/_ert/forward_model_runner/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,14 @@ def _setup_reporters(
ens_id,
dispatch_url,
ee_token=None,
ee_cert_path=None,
experiment_id=None,
) -> list[reporting.Reporter]:
reporters: list[reporting.Reporter] = []
if is_interactive_run:
reporters.append(reporting.Interactive())
elif ens_id and experiment_id is None:
reporters.append(reporting.File())
reporters.append(
reporting.Event(
evaluator_url=dispatch_url, token=ee_token, cert_path=ee_cert_path
)
)
reporters.append(reporting.Event(evaluator_url=dispatch_url, token=ee_token))
else:
reporters.append(reporting.File())
return reporters
Expand Down Expand Up @@ -123,7 +118,6 @@ def main(args):
experiment_id = jobs_data.get("experiment_id")
ens_id = jobs_data.get("ens_id")
ee_token = jobs_data.get("ee_token")
ee_cert_path = jobs_data.get("ee_cert_path")
dispatch_url = jobs_data.get("dispatch_url")

is_interactive_run = len(parsed_args.job) > 0
Expand All @@ -132,7 +126,6 @@ def main(args):
ens_id,
dispatch_url,
ee_token,
ee_cert_path,
experiment_id,
)

Expand Down
222 changes: 119 additions & 103 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
from __future__ import annotations
jonathan-eq marked this conversation as resolved.
Show resolved Hide resolved

import asyncio
import logging
import ssl
from typing import Any, AnyStr, Self

from websockets.asyncio.client import ClientConnection, connect
from websockets.datastructures import Headers
from websockets.exceptions import (
ConnectionClosedError,
ConnectionClosedOK,
InvalidHandshake,
InvalidURI,
)
import uuid
from typing import Any, Self

from _ert.async_utils import new_event_loop
import zmq
import zmq.asyncio

logger = logging.getLogger(__name__)

Expand All @@ -21,112 +15,134 @@ class ClientConnectionError(Exception):
pass


class ClientConnectionClosedOK(Exception):
pass
CONNECT_MSG = b"CONNECT"
DISCONNECT_MSG = b"DISCONNECT"
ACK_MSG = b"ACK"


class Client:
DEFAULT_MAX_RETRIES = 10
DEFAULT_TIMEOUT_MULTIPLIER = 5
CONNECTION_TIMEOUT = 60

def __enter__(self) -> Self:
return self

def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
if self.websocket is not None:
self.loop.run_until_complete(self.websocket.close())
self.loop.close()

async def __aenter__(self) -> "Client":
return self

async def __aexit__(
self, exc_type: Any, exc_value: Any, exc_traceback: Any
) -> None:
if self.websocket is not None:
await self.websocket.close()
DEFAULT_ACK_TIMEOUT = 5

def __init__(
self,
url: str,
token: str | None = None,
cert: str | bytes | None = None,
max_retries: int | None = None,
timeout_multiplier: int | None = None,
dealer_name: str | None = None,
xjules marked this conversation as resolved.
Show resolved Hide resolved
ack_timeout: float | None = None,
) -> None:
if max_retries is None:
max_retries = self.DEFAULT_MAX_RETRIES
if timeout_multiplier is None:
timeout_multiplier = self.DEFAULT_TIMEOUT_MULTIPLIER
if url is None:
raise ValueError("url was None")
self._ack_timeout = ack_timeout or self.DEFAULT_ACK_TIMEOUT
self.url = url
self.token = token
self._additional_headers = Headers()

self._ack_event: asyncio.Event = asyncio.Event()
self.context = zmq.asyncio.Context()
self.socket = self.context.socket(zmq.DEALER)
# this is to avoid blocking the event loop when closing the socket
# wherein the linger is set to 0 to discard all messages in the queue
self.socket.setsockopt(zmq.LINGER, 0)
jonathan-eq marked this conversation as resolved.
Show resolved Hide resolved
self.dealer_id = dealer_name or f"dispatch-{uuid.uuid4().hex[:8]}"
self.socket.setsockopt_string(zmq.IDENTITY, self.dealer_id)

if token is not None:
self._additional_headers["token"] = token

# Mimics the behavior of the ssl argument when connection to
# websockets. If none is specified it will deduce based on the url,
# if True it will enforce TLS, and if you want to use self signed
# certificates you need to pass an ssl_context with the certificate
# loaded.
self._ssl_context: bool | ssl.SSLContext | None = None
if cert is not None:
self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self._ssl_context.load_verify_locations(cadata=cert)
elif url.startswith("wss"):
self._ssl_context = True

self._max_retries = max_retries
self._timeout_multiplier = timeout_multiplier
self.websocket: ClientConnection | None = None
self.loop = new_event_loop()

async def get_websocket(self) -> ClientConnection:
return await connect(
self.url,
ssl=self._ssl_context,
additional_headers=self._additional_headers,
open_timeout=self.CONNECTION_TIMEOUT,
ping_timeout=self.CONNECTION_TIMEOUT,
ping_interval=self.CONNECTION_TIMEOUT,
close_timeout=self.CONNECTION_TIMEOUT,
)
client_public, client_secret = zmq.curve_keypair()
self.socket.curve_secretkey = client_secret
self.socket.curve_publickey = client_public
self.socket.curve_serverkey = token.encode("utf-8")

self._receiver_task: asyncio.Task[None] | None = None

async def __aenter__(self) -> Self:
await self.connect()
return self

async def _send(self, msg: AnyStr) -> None:
for retry in range(self._max_retries + 1):
async def __aexit__(
self, exc_type: Any, exc_value: Any, exc_traceback: Any
xjules marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
try:
await self.send(DISCONNECT_MSG)
except ClientConnectionError:
logger.error("No ack for dealer disconnection. Connection is down!")
finally:
self.socket.disconnect(self.url)
xjules marked this conversation as resolved.
Show resolved Hide resolved
await self._term_receiver_task()
self.term()

def term(self) -> None:
self.socket.close()
self.context.term()

async def _term_receiver_task(self) -> None:
if self._receiver_task and not self._receiver_task.done():
self._receiver_task.cancel()
await asyncio.gather(self._receiver_task, return_exceptions=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to gather this cancelled task, or can we await it (and suppress asyncio.cancellationerror if it is raised)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we capture asyncio.CancelledError we can just do await. :+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the CancelledError is not catched there, so I would keep it.

self._receiver_task = None

async def connect(self) -> None:
self.socket.connect(self.url)
await self._term_receiver_task()
self._receiver_task = asyncio.create_task(self._receiver())
try:
await self.send(CONNECT_MSG, retries=1)
except ClientConnectionError:
await self._term_receiver_task()
self.term()
raise

async def process_message(self, msg: str) -> None:
raise NotImplementedError("Only monitor can receive messages!")

async def _receiver(self) -> None:
while True:
jonathan-eq marked this conversation as resolved.
Show resolved Hide resolved
try:
if self.websocket is None:
self.websocket = await self.get_websocket()
await self.websocket.send(msg)
return
except ConnectionClosedOK as exception:
error_msg = (
f"Connection closed received from the server {self.url}! "
f" Exception from {type(exception)}: {exception!s}"
_, raw_msg = await self.socket.recv_multipart()
if raw_msg == ACK_MSG:
self._ack_event.set()
jonathan-eq marked this conversation as resolved.
Show resolved Hide resolved
else:
await self.process_message(raw_msg.decode("utf-8"))
except zmq.ZMQError as exc:
logger.debug(
f"{self.dealer_id} connection to evaluator went down, reconnecting: {exc}"
)
raise ClientConnectionClosedOK(error_msg) from exception
except (TimeoutError, InvalidHandshake, InvalidURI, OSError) as exception:
if retry == self._max_retries:
error_msg = (
f"Not able to establish the "
f"websocket connection {self.url}! Max retries reached!"
" Check for firewall issues."
f" Exception from {type(exception)}: {exception!s}"
await asyncio.sleep(0)
self.socket.connect(self.url)

async def send(self, message: str | bytes, retries: int | None = None) -> None:
self._ack_event.clear()
jonathan-eq marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(message, str):
message = message.encode("utf-8")

backoff = 1
if retries is None:
retries = self.DEFAULT_MAX_RETRIES
while retries >= 0:
try:
await self.socket.send_multipart([b"", message])
try:
await asyncio.wait_for(
self._ack_event.wait(), timeout=self._ack_timeout
)
raise ClientConnectionError(error_msg) from exception
except ConnectionClosedError as exception:
if retry == self._max_retries:
error_msg = (
f"Not been able to send the event"
f" to {self.url}! Max retries reached!"
f" Exception from {type(exception)}: {exception!s}"
return
except TimeoutError:
logger.warning(
f"{self.dealer_id} failed to get acknowledgment on the {message!r}. Resending."
)
raise ClientConnectionError(error_msg) from exception
await asyncio.sleep(0.2 + self._timeout_multiplier * retry)
self.websocket = None

def send(self, msg: AnyStr) -> None:
self.loop.run_until_complete(self._send(msg))
except zmq.ZMQError as exc:
logger.debug(
f"{self.dealer_id} connection to evaluator went down, reconnecting: {exc}"
)
except asyncio.CancelledError:
self.term()
raise

retries -= 1
if retries > 0:
logger.info(f"Retrying... ({retries} attempts left)")
xjules marked this conversation as resolved.
Show resolved Hide resolved
await asyncio.sleep(backoff)
# this call is idempotent
self.socket.connect(self.url)
backoff = min(backoff * 2, 10) # Exponential backoff
jonathan-eq marked this conversation as resolved.
Show resolved Hide resolved
jonathan-eq marked this conversation as resolved.
Show resolved Hide resolved
raise ClientConnectionError(
f"{self.dealer_id} Failed to send {message!r} after retries!"
)
Loading
Loading