Skip to content

Commit

Permalink
Merge branch 'main' into bugfix/fix_leaking_clients_in_actor_caller
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Nov 26, 2024
2 parents 8f5cc09 + 7b9f181 commit 7535d67
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 60 deletions.
10 changes: 8 additions & 2 deletions python/xoscar/backends/communication/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from urllib.parse import urlparse

from ..._utils import to_binary
from ...constants import XOSCAR_UNIX_SOCKET_DIR
from ...constants import XOSCAR_CONNECT_TIMEOUT, XOSCAR_UNIX_SOCKET_DIR
from ...serialization import AioDeserializer, AioSerializer, deserialize
from ...utils import classproperty, implements, is_py_312, is_v6_ip
from .base import Channel, ChannelType, Client, Server
Expand Down Expand Up @@ -291,7 +291,13 @@ async def connect(
) -> "Client":
host, port_str = dest_address.rsplit(":", 1)
port = int(port_str)
(reader, writer) = await asyncio.open_connection(host=host, port=port, **kwargs)
config = kwargs.get("config", {})
connect_timeout = config.get("connect_timeout", XOSCAR_CONNECT_TIMEOUT)
fut = asyncio.open_connection(host=host, port=port)
try:
reader, writer = await asyncio.wait_for(fut, timeout=connect_timeout)
except asyncio.TimeoutError:
raise ConnectionError("connect timeout")
channel = SocketChannel(
reader, writer, local_address=local_address, dest_address=dest_address
)
Expand Down
91 changes: 51 additions & 40 deletions python/xoscar/backends/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,50 +72,61 @@ async def get_client(self, router: Router, dest_address: str) -> Client:
return client

async def _listen(self, client: Client):
while not client.closed:
try:
try:
while not client.closed:
try:
message: _MessageBase = await client.recv()
except (EOFError, ConnectionError, BrokenPipeError):
# remote server closed, close client and raise ServerClosed
try:
await client.close()
except (ConnectionError, BrokenPipeError):
# close failed, ignore it
message: _MessageBase = await client.recv()
except (EOFError, ConnectionError, BrokenPipeError) as e:
# AssertionError is from get_header
# remote server closed, close client and raise ServerClosed
logger.debug(f"{client.dest_address} close due to {e}")
try:
await client.close()
except (ConnectionError, BrokenPipeError):
# close failed, ignore it
pass
raise ServerClosed(
f"Remote server {client.dest_address} closed: {e}"
) from None
future = self._client_to_message_futures[client].pop(
message.message_id
)
if not future.done():
future.set_result(message)
except DeserializeMessageFailed as e:
message_id = e.message_id
future = self._client_to_message_futures[client].pop(message_id)
future.set_exception(e.__cause__) # type: ignore
except Exception as e: # noqa: E722 # pylint: disable=bare-except
message_futures = self._client_to_message_futures[client]
self._client_to_message_futures[client] = dict()
for future in message_futures.values():
future.set_exception(copy.copy(e))
finally:
# message may have Ray ObjectRef, delete it early in case next loop doesn't run
# as soon as expected.
try:
del message
except NameError:
pass
raise ServerClosed(
f"Remote server {client.dest_address} closed"
) from None
future = self._client_to_message_futures[client].pop(message.message_id)
if not future.done():
future.set_result(message)
except DeserializeMessageFailed as e:
message_id = e.message_id
future = self._client_to_message_futures[client].pop(message_id)
future.set_exception(e.__cause__) # type: ignore
except Exception as e: # noqa: E722 # pylint: disable=bare-except
message_futures = self._client_to_message_futures[client]
self._client_to_message_futures[client] = dict()
for future in message_futures.values():
future.set_exception(copy.copy(e))
finally:
# message may have Ray ObjectRef, delete it early in case next loop doesn't run
# as soon as expected.
try:
del message
except NameError:
pass
try:
del future
except NameError:
pass
await asyncio.sleep(0)
try:
del future
except NameError:
pass
await asyncio.sleep(0)

message_futures = self._client_to_message_futures[client]
self._client_to_message_futures[client] = dict()
error = ServerClosed(f"Remote server {client.dest_address} closed")
for future in message_futures.values():
future.set_exception(copy.copy(error))
message_futures = self._client_to_message_futures[client]
self._client_to_message_futures[client] = dict()
error = ServerClosed(f"Remote server {client.dest_address} closed")
for future in message_futures.values():
future.set_exception(copy.copy(error))
finally:
try:
await client.close()
except: # noqa: E722 # nosec # pylint: disable=bare-except
# ignore all error if fail to close at last
pass

async def call_with_client(
self, client: Client, message: _MessageBase, wait: bool = True
Expand Down
2 changes: 1 addition & 1 deletion python/xoscar/backends/indigen/tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ async def test_server_closed():

# check if error raised normally when subprocess killed
task = asyncio.create_task(actor_ref.sleep(10))
await asyncio.sleep(0)
await asyncio.sleep(0.1)

# kill subprocess 1
process = list(pool._sub_processes.values())[0]
Expand Down
40 changes: 24 additions & 16 deletions python/xoscar/backends/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,23 +551,31 @@ async def _handle_ucx_meta_message(
return False

async def on_new_channel(self, channel: Channel):
while not self._stopped.is_set():
try:
message = await channel.recv()
except EOFError:
# no data to read, check channel
try:
while not self._stopped.is_set():
try:
await channel.close()
except (ConnectionError, EOFError):
# close failed, ignore
pass
return
if await self._handle_ucx_meta_message(message, channel):
continue
asyncio.create_task(self.process_message(message, channel))
# delete to release the reference of message
del message
await asyncio.sleep(0)
message = await channel.recv()
except (EOFError, ConnectionError, BrokenPipeError) as e:
logger.debug(f"pool: close connection due to {e}")
# no data to read, check channel
try:
await channel.close()
except (ConnectionError, EOFError):
# close failed, ignore
pass
return
if await self._handle_ucx_meta_message(message, channel):
continue
asyncio.create_task(self.process_message(message, channel))
# delete to release the reference of message
del message
await asyncio.sleep(0)
finally:
try:
await channel.close()
except: # noqa: E722 # nosec # pylint: disable=bare-except
# ignore all error if fail to close at last
pass

async def __aenter__(self):
await self.start()
Expand Down
2 changes: 2 additions & 0 deletions python/xoscar/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@

# unix socket.
XOSCAR_UNIX_SOCKET_DIR = XOSCAR_TEMP_DIR / "socket"

XOSCAR_CONNECT_TIMEOUT = 8
6 changes: 5 additions & 1 deletion python/xoscar/serialization/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ async def run(self):
def get_header_length(header_bytes: bytes):
version = struct.unpack("B", header_bytes[:1])[0]
# now we only have default version
assert version == DEFAULT_SERIALIZATION_VERSION, MALFORMED_MSG
if version != DEFAULT_SERIALIZATION_VERSION:
# when version not matched,
# we will immediately abort the connection
# EOFError will be captured by channel
raise EOFError(MALFORMED_MSG)
# header length
header_length = struct.unpack("<Q", header_bytes[1:9])[0]
# compress
Expand Down

0 comments on commit 7535d67

Please sign in to comment.