Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SSLProtocol.connection_lost not being called when underlying socket is closed #639

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion tests/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
skip_tests = False

import asyncio
import os
import sys
import unittest
import weakref

from uvloop import _testbase as tb


class _TestAioHTTP:
class _TestAioHTTP(tb.SSLTestCase):

def test_aiohttp_basic_1(self):

Expand Down Expand Up @@ -115,6 +116,64 @@ async def stop():

self.loop.run_until_complete(stop())

def test_aiohttp_connection_lost_when_busy(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest('bug in asyncio #118950 tests in CPython.')

cert = tb._cert_fullname(__file__, 'ssl_cert.pem')
key = tb._cert_fullname(__file__, 'ssl_key.pem')
ssl_context = self._create_server_ssl_context(cert, key)
client_ssl_context = self._create_client_ssl_context()

asyncio.set_event_loop(self.loop)
app = aiohttp.web.Application()

async def handler(request):
ws = aiohttp.web.WebSocketResponse()
await ws.prepare(request)
async for msg in ws:
print("Received:", msg.data)
return ws

app.router.add_get('/', handler)

runner = aiohttp.web.AppRunner(app)
self.loop.run_until_complete(runner.setup())
host = '0.0.0.0'
site = aiohttp.web.TCPSite(runner, host, '0', ssl_context=ssl_context)
self.loop.run_until_complete(site.start())
port = site._server.sockets[0].getsockname()[1]
session = aiohttp.ClientSession(loop=self.loop)

async def test():
async with session.ws_connect(
f"wss://{host}:{port}/",
ssl=client_ssl_context
) as ws:
transport = ws._writer.transport
s = transport.get_extra_info('socket')

if self.implementation == 'asyncio':
s._sock.close()
else:
os.close(s.fileno())

# FLOW_CONTROL_HIGH_WATER * 1024
bytes_to_send = 64 * 1024
iterations = 10
msg = b'Hello world, still there?'

# Send enough messages to trigger a socket write + one extra
for _ in range(iterations + 1):
await ws.send_bytes(
msg * ((bytes_to_send // len(msg)) // iterations))

self.assertRaises(
ConnectionResetError, self.loop.run_until_complete, test())

self.loop.run_until_complete(session.close())
self.loop.run_until_complete(runner.cleanup())


@unittest.skipIf(skip_tests, "no aiohttp module")
class Test_UV_AioHTTP(_TestAioHTTP, tb.UVTestCase):
Expand Down
52 changes: 52 additions & 0 deletions tests/test_tcp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import asyncio.sslproto
import contextlib
import gc
import os
import select
Expand Down Expand Up @@ -3192,6 +3193,57 @@ async def run_main():

self.loop.run_until_complete(run_main())

def test_connection_lost_when_busy(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest('bug in asyncio #118950 tests in CPython.')

ssl_context = self._create_server_ssl_context(
self.ONLYCERT, self.ONLYKEY)
client_ssl_context = self._create_client_ssl_context()
port = tb.find_free_port()

@contextlib.asynccontextmanager
async def server():
async def client_handler(reader, writer):
...

srv = await asyncio.start_server(
client_handler, '0.0.0.0',
port, ssl=ssl_context, reuse_port=True)

try:
yield
finally:
srv.close()

async def client():
reader, writer = await asyncio.open_connection(
'0.0.0.0', port, ssl=client_ssl_context)
transport = writer.transport
s = transport.get_extra_info('socket')

if self.implementation == 'asyncio':
s._sock.close()
else:
os.close(s.fileno())

# FLOW_CONTROL_HIGH_WATER * 1024
bytes_to_send = 64 * 1024
iterations = 10
msg = b'An really important message :)'

# Busy drain loop
for _ in range(iterations + 1):
writer.write(msg * ((bytes_to_send // len(msg)) // iterations))
await writer.drain()

async def test():
async with server():
await client()

self.assertRaises(
ConnectionResetError, self.loop.run_until_complete, test())


class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase):
pass
Expand Down
5 changes: 4 additions & 1 deletion uvloop/sslproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ cdef class _SSLProtocolTransport:
return self._ssl_protocol._app_protocol

def is_closing(self):
return self._closed
return self._closed or self._ssl_protocol._is_transport_closing()

def close(self):
"""Close the transport.
Expand Down Expand Up @@ -316,6 +316,9 @@ cdef class SSLProtocol:
self._app_transport_created = True
return self._app_transport

def _is_transport_closing(self):
return self._transport is not None and self._transport.is_closing()

def connection_made(self, transport):
"""Called when the low-level connection is made.

Expand Down
Loading