diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index 18444c3f..927afe50 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -122,6 +122,25 @@ async def _connect(self, request: Request) -> AsyncNetworkStream: **kwargs ) trace.return_value = stream + + if self._origin.scheme == b"https": + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": self._origin.host.decode("ascii"), + "timeout": timeout, + } + async with Trace("connection.start_tls", request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + return stream except (ConnectError, ConnectTimeout): if retries_left <= 0: raise @@ -129,27 +148,6 @@ async def _connect(self, request: Request) -> AsyncNetworkStream: delay = next(delays) # TRACE 'retry' await self._network_backend.sleep(delay) - else: - break - - if self._origin.scheme == b"https": - ssl_context = ( - default_ssl_context() - if self._ssl_context is None - else self._ssl_context - ) - alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] - ssl_context.set_alpn_protocols(alpn_protocols) - - kwargs = { - "ssl_context": ssl_context, - "server_hostname": self._origin.host.decode("ascii"), - "timeout": timeout, - } - async with Trace("connection.start_tls", request, kwargs) as trace: - stream = await stream.start_tls(**kwargs) - trace.return_value = stream - return stream def can_handle_request(self, origin: Origin) -> bool: return origin == self._origin diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index 9c703498..4a03462d 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -122,6 +122,25 @@ def _connect(self, request: Request) -> NetworkStream: **kwargs ) trace.return_value = stream + + if self._origin.scheme == b"https": + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": self._origin.host.decode("ascii"), + "timeout": timeout, + } + with Trace("connection.start_tls", request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + return stream except (ConnectError, ConnectTimeout): if retries_left <= 0: raise @@ -129,27 +148,6 @@ def _connect(self, request: Request) -> NetworkStream: delay = next(delays) # TRACE 'retry' self._network_backend.sleep(delay) - else: - break - - if self._origin.scheme == b"https": - ssl_context = ( - default_ssl_context() - if self._ssl_context is None - else self._ssl_context - ) - alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] - ssl_context.set_alpn_protocols(alpn_protocols) - - kwargs = { - "ssl_context": ssl_context, - "server_hostname": self._origin.host.decode("ascii"), - "timeout": timeout, - } - with Trace("connection.start_tls", request, kwargs) as trace: - stream = stream.start_tls(**kwargs) - trace.return_value = stream - return stream def can_handle_request(self, origin: Origin) -> bool: return origin == self._origin diff --git a/tests/_async/test_connection.py b/tests/_async/test_connection.py index 8f1f19e5..90e69f57 100644 --- a/tests/_async/test_connection.py +++ b/tests/_async/test_connection.py @@ -1,3 +1,5 @@ +import ssl +import typing from typing import List, Optional import hpack @@ -124,8 +126,15 @@ async def test_request_to_incorrect_origin(): class NeedsRetryBackend(AsyncMockBackend): - def __init__(self, buffer: List[bytes], http2: bool = False) -> None: - self._retry = 2 + def __init__( + self, + buffer: List[bytes], + http2: bool = False, + connect_tcp_failures: int = 2, + start_tls_failures: int = 0, + ) -> None: + self._connect_tcp_failures = connect_tcp_failures + self._start_tls_failures = start_tls_failures super().__init__(buffer, http2) async def connect_tcp( @@ -135,13 +144,50 @@ async def connect_tcp( timeout: Optional[float] = None, local_address: Optional[str] = None, ) -> AsyncNetworkStream: - if self._retry > 0: - self._retry -= 1 + if self._connect_tcp_failures > 0: + self._connect_tcp_failures -= 1 raise ConnectError() - return await super().connect_tcp( + stream = await super().connect_tcp( host, port, timeout=timeout, local_address=local_address ) + return self._NeedsRetryAsyncNetworkStream(self, stream) + + class _NeedsRetryAsyncNetworkStream(AsyncNetworkStream): + def __init__( + self, backend: "NeedsRetryBackend", stream: AsyncNetworkStream + ) -> None: + self._backend = backend + self._stream = stream + + async def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + return await self._stream.read(max_bytes, timeout) + + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + await self._stream.write(buffer, timeout) + + async def aclose(self) -> None: + await self._stream.aclose() + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> "AsyncNetworkStream": + if self._backend._start_tls_failures > 0: + self._backend._start_tls_failures -= 1 + raise ConnectError() + + stream = await self._stream.start_tls(ssl_context, server_hostname, timeout) + return self._backend._NeedsRetryAsyncNetworkStream(self._backend, stream) + + def get_extra_info(self, info: str) -> typing.Any: + return self._stream.get_extra_info(info) @pytest.mark.anyio @@ -171,6 +217,37 @@ async def test_connection_retries(): await conn.request("GET", "https://example.com/") +@pytest.mark.anyio +async def test_connection_retries_tls(): + origin = Origin(b"https", b"example.com", 443) + content = [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + + network_backend = NeedsRetryBackend( + content, connect_tcp_failures=0, start_tls_failures=2 + ) + async with AsyncHTTPConnection( + origin=origin, network_backend=network_backend, retries=3 + ) as conn: + response = await conn.request("GET", "https://example.com/") + assert response.status == 200 + + network_backend = NeedsRetryBackend( + content, connect_tcp_failures=0, start_tls_failures=2 + ) + async with AsyncHTTPConnection( + origin=origin, + network_backend=network_backend, + ) as conn: + with pytest.raises(ConnectError): + await conn.request("GET", "https://example.com/") + + @pytest.mark.anyio async def test_uds_connections(): # We're not actually testing Unix Domain Sockets here, because we're just diff --git a/tests/_sync/test_connection.py b/tests/_sync/test_connection.py index d5c97856..9c18cc04 100644 --- a/tests/_sync/test_connection.py +++ b/tests/_sync/test_connection.py @@ -1,3 +1,5 @@ +import ssl +import typing from typing import List, Optional import hpack @@ -124,8 +126,15 @@ def test_request_to_incorrect_origin(): class NeedsRetryBackend(MockBackend): - def __init__(self, buffer: List[bytes], http2: bool = False) -> None: - self._retry = 2 + def __init__( + self, + buffer: List[bytes], + http2: bool = False, + connect_tcp_failures: int = 2, + start_tls_failures: int = 0, + ) -> None: + self._connect_tcp_failures = connect_tcp_failures + self._start_tls_failures = start_tls_failures super().__init__(buffer, http2) def connect_tcp( @@ -135,13 +144,50 @@ def connect_tcp( timeout: Optional[float] = None, local_address: Optional[str] = None, ) -> NetworkStream: - if self._retry > 0: - self._retry -= 1 + if self._connect_tcp_failures > 0: + self._connect_tcp_failures -= 1 raise ConnectError() - return super().connect_tcp( + stream = super().connect_tcp( host, port, timeout=timeout, local_address=local_address ) + return self._NeedsRetryAsyncNetworkStream(self, stream) + + class _NeedsRetryAsyncNetworkStream(NetworkStream): + def __init__( + self, backend: "NeedsRetryBackend", stream: NetworkStream + ) -> None: + self._backend = backend + self._stream = stream + + def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + return self._stream.read(max_bytes, timeout) + + def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + self._stream.write(buffer, timeout) + + def close(self) -> None: + self._stream.close() + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> "NetworkStream": + if self._backend._start_tls_failures > 0: + self._backend._start_tls_failures -= 1 + raise ConnectError() + + stream = self._stream.start_tls(ssl_context, server_hostname, timeout) + return self._backend._NeedsRetryAsyncNetworkStream(self._backend, stream) + + def get_extra_info(self, info: str) -> typing.Any: + return self._stream.get_extra_info(info) @@ -172,6 +218,37 @@ def test_connection_retries(): +def test_connection_retries_tls(): + origin = Origin(b"https", b"example.com", 443) + content = [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + + network_backend = NeedsRetryBackend( + content, connect_tcp_failures=0, start_tls_failures=2 + ) + with HTTPConnection( + origin=origin, network_backend=network_backend, retries=3 + ) as conn: + response = conn.request("GET", "https://example.com/") + assert response.status == 200 + + network_backend = NeedsRetryBackend( + content, connect_tcp_failures=0, start_tls_failures=2 + ) + with HTTPConnection( + origin=origin, + network_backend=network_backend, + ) as conn: + with pytest.raises(ConnectError): + conn.request("GET", "https://example.com/") + + + def test_uds_connections(): # We're not actually testing Unix Domain Sockets here, because we're just # using a mock backend, but at least we're covering the UDS codepath