Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retry ConnectError/ConnectTimeout happened in stream.start_tls #669

Merged
merged 1 commit into from
Apr 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 19 additions & 21 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,34 +122,32 @@ async def _connect(self, request: Request) -> AsyncNetworkStream:
**kwargs
)
trace.return_value = stream

if self._origin.scheme == b"https":
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
ssl_context.set_alpn_protocols(alpn_protocols)

kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._origin.host.decode("ascii"),
"timeout": timeout,
}
async with Trace("connection.start_tls", request, kwargs) as trace:
stream = await stream.start_tls(**kwargs)
trace.return_value = stream
return stream
except (ConnectError, ConnectTimeout):
if retries_left <= 0:
raise
retries_left -= 1
delay = next(delays)
# TRACE 'retry'
await self._network_backend.sleep(delay)
else:
break

if self._origin.scheme == b"https":
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
ssl_context.set_alpn_protocols(alpn_protocols)

kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._origin.host.decode("ascii"),
"timeout": timeout,
}
async with Trace("connection.start_tls", request, kwargs) as trace:
stream = await stream.start_tls(**kwargs)
trace.return_value = stream
return stream

def can_handle_request(self, origin: Origin) -> bool:
return origin == self._origin
Expand Down
40 changes: 19 additions & 21 deletions httpcore/_sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,34 +122,32 @@ def _connect(self, request: Request) -> NetworkStream:
**kwargs
)
trace.return_value = stream

if self._origin.scheme == b"https":
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
ssl_context.set_alpn_protocols(alpn_protocols)

kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._origin.host.decode("ascii"),
"timeout": timeout,
}
with Trace("connection.start_tls", request, kwargs) as trace:
stream = stream.start_tls(**kwargs)
trace.return_value = stream
return stream
except (ConnectError, ConnectTimeout):
if retries_left <= 0:
raise
retries_left -= 1
delay = next(delays)
# TRACE 'retry'
self._network_backend.sleep(delay)
else:
break

if self._origin.scheme == b"https":
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
ssl_context.set_alpn_protocols(alpn_protocols)

kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._origin.host.decode("ascii"),
"timeout": timeout,
}
with Trace("connection.start_tls", request, kwargs) as trace:
stream = stream.start_tls(**kwargs)
trace.return_value = stream
return stream

def can_handle_request(self, origin: Origin) -> bool:
return origin == self._origin
Expand Down
87 changes: 82 additions & 5 deletions tests/_async/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import ssl
import typing
from typing import List, Optional

import hpack
Expand Down Expand Up @@ -124,8 +126,15 @@ async def test_request_to_incorrect_origin():


class NeedsRetryBackend(AsyncMockBackend):
def __init__(self, buffer: List[bytes], http2: bool = False) -> None:
self._retry = 2
def __init__(
self,
buffer: List[bytes],
http2: bool = False,
connect_tcp_failures: int = 2,
start_tls_failures: int = 0,
) -> None:
self._connect_tcp_failures = connect_tcp_failures
self._start_tls_failures = start_tls_failures
super().__init__(buffer, http2)

async def connect_tcp(
Expand All @@ -135,13 +144,50 @@ async def connect_tcp(
timeout: Optional[float] = None,
local_address: Optional[str] = None,
) -> AsyncNetworkStream:
if self._retry > 0:
self._retry -= 1
if self._connect_tcp_failures > 0:
self._connect_tcp_failures -= 1
raise ConnectError()

return await super().connect_tcp(
stream = await super().connect_tcp(
host, port, timeout=timeout, local_address=local_address
)
return self._NeedsRetryAsyncNetworkStream(self, stream)

class _NeedsRetryAsyncNetworkStream(AsyncNetworkStream):
def __init__(
self, backend: "NeedsRetryBackend", stream: AsyncNetworkStream
) -> None:
self._backend = backend
self._stream = stream

async def read(
self, max_bytes: int, timeout: typing.Optional[float] = None
) -> bytes:
return await self._stream.read(max_bytes, timeout)

async def write(
self, buffer: bytes, timeout: typing.Optional[float] = None
) -> None:
await self._stream.write(buffer, timeout)

async def aclose(self) -> None:
await self._stream.aclose()

async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> "AsyncNetworkStream":
if self._backend._start_tls_failures > 0:
self._backend._start_tls_failures -= 1
raise ConnectError()

stream = await self._stream.start_tls(ssl_context, server_hostname, timeout)
return self._backend._NeedsRetryAsyncNetworkStream(self._backend, stream)

def get_extra_info(self, info: str) -> typing.Any:
return self._stream.get_extra_info(info)


@pytest.mark.anyio
Expand Down Expand Up @@ -171,6 +217,37 @@ async def test_connection_retries():
await conn.request("GET", "https://example.com/")


@pytest.mark.anyio
async def test_connection_retries_tls():
origin = Origin(b"https", b"example.com", 443)
content = [
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Content-Length: 13\r\n",
b"\r\n",
b"Hello, world!",
]

network_backend = NeedsRetryBackend(
content, connect_tcp_failures=0, start_tls_failures=2
)
async with AsyncHTTPConnection(
origin=origin, network_backend=network_backend, retries=3
) as conn:
response = await conn.request("GET", "https://example.com/")
assert response.status == 200

network_backend = NeedsRetryBackend(
content, connect_tcp_failures=0, start_tls_failures=2
)
async with AsyncHTTPConnection(
origin=origin,
network_backend=network_backend,
) as conn:
with pytest.raises(ConnectError):
await conn.request("GET", "https://example.com/")


@pytest.mark.anyio
async def test_uds_connections():
# We're not actually testing Unix Domain Sockets here, because we're just
Expand Down
87 changes: 82 additions & 5 deletions tests/_sync/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import ssl
import typing
from typing import List, Optional

import hpack
Expand Down Expand Up @@ -124,8 +126,15 @@ def test_request_to_incorrect_origin():


class NeedsRetryBackend(MockBackend):
def __init__(self, buffer: List[bytes], http2: bool = False) -> None:
self._retry = 2
def __init__(
self,
buffer: List[bytes],
http2: bool = False,
connect_tcp_failures: int = 2,
start_tls_failures: int = 0,
) -> None:
self._connect_tcp_failures = connect_tcp_failures
self._start_tls_failures = start_tls_failures
super().__init__(buffer, http2)

def connect_tcp(
Expand All @@ -135,13 +144,50 @@ def connect_tcp(
timeout: Optional[float] = None,
local_address: Optional[str] = None,
) -> NetworkStream:
if self._retry > 0:
self._retry -= 1
if self._connect_tcp_failures > 0:
self._connect_tcp_failures -= 1
raise ConnectError()

return super().connect_tcp(
stream = super().connect_tcp(
host, port, timeout=timeout, local_address=local_address
)
return self._NeedsRetryAsyncNetworkStream(self, stream)

class _NeedsRetryAsyncNetworkStream(NetworkStream):
def __init__(
self, backend: "NeedsRetryBackend", stream: NetworkStream
) -> None:
self._backend = backend
self._stream = stream

def read(
self, max_bytes: int, timeout: typing.Optional[float] = None
) -> bytes:
return self._stream.read(max_bytes, timeout)

def write(
self, buffer: bytes, timeout: typing.Optional[float] = None
) -> None:
self._stream.write(buffer, timeout)

def close(self) -> None:
self._stream.close()

def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> "NetworkStream":
if self._backend._start_tls_failures > 0:
self._backend._start_tls_failures -= 1
raise ConnectError()

stream = self._stream.start_tls(ssl_context, server_hostname, timeout)
return self._backend._NeedsRetryAsyncNetworkStream(self._backend, stream)

def get_extra_info(self, info: str) -> typing.Any:
return self._stream.get_extra_info(info)



Expand Down Expand Up @@ -172,6 +218,37 @@ def test_connection_retries():



def test_connection_retries_tls():
origin = Origin(b"https", b"example.com", 443)
content = [
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Content-Length: 13\r\n",
b"\r\n",
b"Hello, world!",
]

network_backend = NeedsRetryBackend(
content, connect_tcp_failures=0, start_tls_failures=2
)
with HTTPConnection(
origin=origin, network_backend=network_backend, retries=3
) as conn:
response = conn.request("GET", "https://example.com/")
assert response.status == 200

network_backend = NeedsRetryBackend(
content, connect_tcp_failures=0, start_tls_failures=2
)
with HTTPConnection(
origin=origin,
network_backend=network_backend,
) as conn:
with pytest.raises(ConnectError):
conn.request("GET", "https://example.com/")



def test_uds_connections():
# We're not actually testing Unix Domain Sockets here, because we're just
# using a mock backend, but at least we're covering the UDS codepath
Expand Down