diff --git a/changelog.d/20240116_164958_rra_DM_42527.md b/changelog.d/20240116_164958_rra_DM_42527.md new file mode 100644 index 00000000..15cafa06 --- /dev/null +++ b/changelog.d/20240116_164958_rra_DM_42527.md @@ -0,0 +1,3 @@ +### Bug fixes + +- Rewrite `CaseInsensitiveQueryMiddleware` and `XForwardedMiddleware` as pure ASGI middleware rather than using the Starlette `BaseHTTPMiddleware` class. The latter seems to be behind some poor error reporting of application exceptions, has caused problems in the past due to its complexity, and is not used internally by Starlette middleware. diff --git a/src/safir/middleware/ivoa.py b/src/safir/middleware/ivoa.py index 6b8c2af8..9c4b93cf 100644 --- a/src/safir/middleware/ivoa.py +++ b/src/safir/middleware/ivoa.py @@ -1,23 +1,22 @@ """Middleware for IVOA services.""" -from collections.abc import Awaitable, Callable -from urllib.parse import urlencode +from copy import copy +from urllib.parse import parse_qsl, urlencode -from fastapi import Request, Response -from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp, Receive, Scope, Send __all__ = ["CaseInsensitiveQueryMiddleware"] -class CaseInsensitiveQueryMiddleware(BaseHTTPMiddleware): +class CaseInsensitiveQueryMiddleware: """Make query parameter keys all lowercase. Unfortunately, several IVOA standards require that query parameters be case-insensitive, which is not supported by modern HTTP web frameworks. This middleware attempts to work around this by lowercasing the query parameter keys before the request is processed, allowing normal FastAPI - query parsing to then work without regard for case. This, in turn, - permits FastAPI to perform input validation on GET parameters, which would + query parsing to then work without regard for case. This, in turn, permits + FastAPI to perform input validation on GET parameters, which would otherwise only happen if the case used in the request happened to match the case used in the function signature. @@ -28,11 +27,17 @@ class CaseInsensitiveQueryMiddleware(BaseHTTPMiddleware): Based on `fastapi#826 `__. """ - async def dispatch( - self, - request: Request, - call_next: Callable[[Request], Awaitable[Response]], - ) -> Response: - params = [(k.lower(), v) for k, v in request.query_params.items()] - request.scope["query_string"] = urlencode(params).encode() - return await call_next(request) + def __init__(self, app: ASGIApp) -> None: + self._app = app + + async def __call__( + self, scope: Scope, receive: Receive, send: Send + ) -> None: + if scope["type"] != "http" or not scope.get("query_string"): + await self._app(scope, receive, send) + return + scope = copy(scope) + params = [(k.lower(), v) for k, v in parse_qsl(scope["query_string"])] + scope["query_string"] = urlencode(params).encode() + await self._app(scope, receive, send) + return diff --git a/src/safir/middleware/x_forwarded.py b/src/safir/middleware/x_forwarded.py index fb6f26d8..2a3f40c1 100644 --- a/src/safir/middleware/x_forwarded.py +++ b/src/safir/middleware/x_forwarded.py @@ -2,17 +2,17 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable +from copy import copy from ipaddress import _BaseAddress, _BaseNetwork, ip_address -from fastapi import FastAPI, Request, Response -from starlette.middleware.base import BaseHTTPMiddleware +from starlette.datastructures import Headers +from starlette.types import ASGIApp, Receive, Scope, Send __all__ = ["XForwardedMiddleware"] -class XForwardedMiddleware(BaseHTTPMiddleware): - """Middleware to update the request based on ``X-Forwarded-For``. +class XForwardedMiddleware: + """ASGI middleware to update the request based on ``X-Forwarded-For``. The remote IP address will be replaced with the right-most IP address in ``X-Forwarded-For`` that is not contained within one of the trusted @@ -20,60 +20,47 @@ class XForwardedMiddleware(BaseHTTPMiddleware): If ``X-Forwarded-For`` is found and ``X-Forwarded-Proto`` is also present, the corresponding entry of ``X-Forwarded-Proto`` is used to replace the - scheme in the request scope. If ``X-Forwarded-Proto`` only has one entry + scheme in the request scope. If ``X-Forwarded-Proto`` only has one entry (ingress-nginx has this behavior), that one entry will become the new scheme in the request scope. The contents of ``X-Forwarded-Host`` will be stored as ``forwarded_host`` - in the request state if it and ``X-Forwarded-For`` are present. Normally + in the request state if it and ``X-Forwarded-For`` are present. Normally this is not needed since NGINX will pass the original ``Host`` header without modification. Parameters ---------- proxies - The networks of the trusted proxies. If not specified, defaults to - the empty list, which means only the immediately upstream proxy will - be trusted. + The networks of the trusted proxies. If not specified, defaults to the + empty list, which means only the immediately upstream proxy will be + trusted. """ def __init__( - self, app: FastAPI, *, proxies: list[_BaseNetwork] | None = None + self, app: ASGIApp, *, proxies: list[_BaseNetwork] | None = None ) -> None: - super().__init__(app) - if proxies: - self.proxies = proxies - else: - self.proxies = [] - - async def dispatch( - self, - request: Request, - call_next: Callable[[Request], Awaitable[Response]], - ) -> Response: - """Middleware to update the request based on ``X-Forwarded-For``. - - Parameters - ---------- - request - The incoming request. - call_next - The next step in the processing stack. + self._app = app + self._proxies = proxies if proxies else [] - Returns - ------- - ``fastapi.Response`` - The response with additional information about proxy headers. - """ - forwarded_for = list(reversed(self._get_forwarded_for(request))) + async def __call__( + self, scope: Scope, receive: Receive, send: Send + ) -> None: + if scope["type"] != "http": + await self._app(scope, receive, send) + return + scope = copy(scope) + scope.setdefault("state", {}) + headers = Headers(scope=scope) + forwarded_for = list(reversed(self._get_forwarded_for(headers))) if not forwarded_for: - request.state.forwarded_host = None - request.state.forwarded_proto = None - return await call_next(request) + scope["state"]["forwarded_host"] = None + await self._app(scope, receive, send) + return client = None for n, ip in enumerate(forwarded_for): - if any(ip in network for network in self.proxies): + if any(ip in network for network in self._proxies): continue client = str(ip) index = n @@ -85,45 +72,45 @@ async def dispatch( client = str(forwarded_for[-1]) index = -1 - # Update the request's understanding of the client IP. This uses an - # undocumented interface; hopefully it will keep working. - if request.client: - request.scope["client"] = (client, request.client.port) + # Update the request's understanding of the client IP. + if scope.get("client"): + scope["client"] = (client, scope["client"][1]) else: - request.scope["client"] = (client, None) + scope["client"] = (client, None) # Ideally this should take the scheme corresponding to the entry in # X-Forwarded-For that was chosen, but some proxies (the Kubernetes # NGINX ingress, for example) only retain one element in - # X-Forwarded-Proto. In that case, use what we have. - proto = list(reversed(self._get_forwarded_proto(request))) + # X-Forwarded-Proto. In that case, use what we have. + proto = list(reversed(self._get_forwarded_proto(headers))) if proto: if index >= len(proto): index = -1 - request.scope["scheme"] = proto[index] + scope["scheme"] = proto[index] - # Rather than one entry per hop, NGINX seems to add only a single - # X-Forwarded-Host header with the original hostname. - request.state.forwarded_host = self._get_forwarded_host(request) + # Record what appears to be the client host for logging purposes. + scope["state"]["forwarded_host"] = self._get_forwarded_host(headers) - return await call_next(request) + # Perform the rest of the request processing. + await self._app(scope, receive, send) + return - def _get_forwarded_for(self, request: Request) -> list[_BaseAddress]: + def _get_forwarded_for(self, headers: Headers) -> list[_BaseAddress]: """Retrieve the ``X-Forwarded-For`` entries from the request. Parameters ---------- - request - The incoming request. + scope + Request headers. Returns ------- list of ipaddress._BaseAddress - The list of addresses found in the header. If there are multiple + The list of addresses found in the header. If there are multiple ``X-Forwarded-For`` headers, we don't know which one is correct, so act as if there are no headers. """ - forwarded_for_str = request.headers.getlist("X-Forwarded-For") + forwarded_for_str = headers.getlist("X-Forwarded-For") if not forwarded_for_str or len(forwarded_for_str) > 1: return [] return [ @@ -132,43 +119,43 @@ def _get_forwarded_for(self, request: Request) -> list[_BaseAddress]: if addr ] - def _get_forwarded_host(self, request: Request) -> str | None: + def _get_forwarded_host(self, headers: Headers) -> str | None: """Retrieve the ``X-Forwarded-Host`` header. Parameters ---------- - request - The incoming request. + headers + Request headers. Returns ------- str The value of the ``X-Forwarded-Host`` header, if present and if - there is only one header. If there are multiple + there is only one header. If there are multiple ``X-Forwarded-Host`` headers, we don't know which one is correct, so act as if there are no headers. """ - forwarded_host = request.headers.getlist("X-Forwarded-Host") + forwarded_host = headers.getlist("X-Forwarded-Host") if not forwarded_host or len(forwarded_host) > 1: return None return forwarded_host[0].strip() - def _get_forwarded_proto(self, request: Request) -> list[str]: + def _get_forwarded_proto(self, headers: Headers) -> list[str]: """Retrieve the ``X-Forwarded-Proto`` entries from the request. Parameters ---------- - request - The incoming request. + headers + Request headers. Returns ------- list of str - The list of schemes found in the header. If there are multiple + The list of schemes found in the header. If there are multiple ``X-Forwarded-Proto`` headers, we don't know which one is correct, so act as if there are no headers. """ - forwarded_proto_str = request.headers.getlist("X-Forwarded-Proto") + forwarded_proto_str = headers.getlist("X-Forwarded-Proto") if not forwarded_proto_str or len(forwarded_proto_str) > 1: return [] return [p.strip() for p in forwarded_proto_str[0].split(",")] diff --git a/tests/middleware/ivoa_test.py b/tests/middleware/ivoa_test.py index e76e7c5f..f7473eae 100644 --- a/tests/middleware/ivoa_test.py +++ b/tests/middleware/ivoa_test.py @@ -2,8 +2,10 @@ from __future__ import annotations +from typing import Annotated + import pytest -from fastapi import FastAPI +from fastapi import FastAPI, Query from httpx import AsyncClient from safir.middleware.ivoa import CaseInsensitiveQueryMiddleware @@ -24,6 +26,16 @@ async def test_case_insensitive() -> None: async def handler(param: str) -> dict[str, str]: return {"param": param} + @app.get("/simple") + async def simple_handler() -> dict[str, str]: + return {"foo": "bar"} + + @app.get("/list") + async def list_handler( + param: Annotated[list[str], Query()] + ) -> dict[str, list[str]]: + return {"param": param} + async with AsyncClient(app=app, base_url="https://example.com") as client: r = await client.get("/", params={"param": "foo"}) assert r.status_code == 200 @@ -39,3 +51,14 @@ async def handler(param: str) -> dict[str, str]: r = await client.get("/", params={"paramX": "foo"}) assert r.status_code == 422 + + r = await client.get("/simple") + assert r.status_code == 200 + assert r.json() == {"foo": "bar"} + + r = await client.get( + "/list", + params=[("param", "foo"), ("PARAM", "BAR"), ("parAM", "baZ")], + ) + assert r.status_code == 200 + assert r.json() == {"param": ["foo", "BAR", "baZ"]} diff --git a/tests/middleware/x_forwarded_test.py b/tests/middleware/x_forwarded_test.py index 69432c3f..6dec36e2 100644 --- a/tests/middleware/x_forwarded_test.py +++ b/tests/middleware/x_forwarded_test.py @@ -7,6 +7,7 @@ import pytest from fastapi import FastAPI, Request from httpx import AsyncClient +from starlette.datastructures import Headers from safir.middleware.x_forwarded import XForwardedMiddleware @@ -159,7 +160,7 @@ async def test_too_many_headers() -> None: end. Instead, test by generating a mock request and then calling the underling middleware functions directly. """ - state = { + scope = { "type": "http", "headers": [ ("X-Forwarded-For", "10.10.10.10"), @@ -170,9 +171,9 @@ async def test_too_many_headers() -> None: ("X-Forwarded-Host", "example.com"), ], } - request = Request(state) + headers = Headers(scope=scope) app = FastAPI() middleware = XForwardedMiddleware(app, proxies=[ip_network("10.0.0.0/8")]) - assert middleware._get_forwarded_for(request) == [] - assert middleware._get_forwarded_proto(request) == [] - assert not middleware._get_forwarded_host(request) + assert middleware._get_forwarded_for(headers) == [] + assert middleware._get_forwarded_proto(headers) == [] + assert not middleware._get_forwarded_host(headers)