Skip to content

Commit

Permalink
Merge pull request #354 from backtick-se/fix-network-timeout
Browse files Browse the repository at this point in the history
Avoid socket timeout if IO loop is overloaded
  • Loading branch information
Martomate authored Apr 1, 2022
2 parents a931ca0 + c2c0bcc commit 439cb7d
Show file tree
Hide file tree
Showing 4 changed files with 263 additions and 156 deletions.
3 changes: 1 addition & 2 deletions cowait/network/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ async def _connect(self, url: str, token: str) -> None:
async with session.ws_connect(
url,
headers={'Authorization': f'Bearer {token}'},
autoping=True,
heartbeat=5.0,
autoping=False,
timeout=30.0,
) as ws:
self.ws = ws
Expand Down
39 changes: 10 additions & 29 deletions cowait/network/server.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,21 @@
from asyncio import CancelledError
from asyncio import CancelledError, TimeoutError
from aiohttp import web, WSMsgType
from aiohttp.helpers import call_later
from aiohttp_middlewares import cors_middleware
from datetime import datetime
from cowait.utils import EventEmitter
from .conn import Conn
from .const import WS_PATH, ON_CONNECT, ON_CLOSE, ON_ERROR
from .const import WS_PATH, ON_CONNECT, ON_CLOSE
from .auth_middleware import AuthMiddleware
from .errors import SocketError


class FixedWebSocketResponse(web.WebSocketResponse):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def _ping(self):
try:
await self._writer.ping()
except ConnectionResetError:
pass

def _send_heartbeat(self) -> None:
if self._heartbeat is not None and not self._closed:
assert self._loop is not None
self._loop.create_task(self._ping()) # type: ignore[union-attr]

if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = call_later(
self._pong_not_received, self._pong_heartbeat, self._loop
)


class Server(EventEmitter):
def __init__(self, port, middlewares: list = []):
super().__init__()
self.conns = []
self.port = port
self.auth = AuthMiddleware()
self.timeout = 30.0

# create http app
self.app = web.Application(
Expand All @@ -57,10 +35,10 @@ def __init__(self, port, middlewares: list = []):
self.add_get(f'/{WS_PATH}', self.handle_client)

async def handle_client(self, request):
ws = FixedWebSocketResponse(
timeout=30.0,
autoping=True,
heartbeat=5.0,
ws = web.WebSocketResponse(
timeout=self.timeout,
receive_timeout=self.timeout,
autoping=False,
)
await ws.prepare(request)

Expand Down Expand Up @@ -91,6 +69,9 @@ async def handle_client(self, request):
except SocketError as e:
await self.emit(type=ON_CLOSE, conn=conn, error=str(e))

except TimeoutError as e:
await self.emit(type=ON_CLOSE, conn=conn, error=str(e))

finally:
# disconnected
self.conns.remove(conn)
Expand Down
Loading

0 comments on commit 439cb7d

Please sign in to comment.