Skip to content

Commit

Permalink
Merge pull request #340 from lsst-sqre/tickets/DM-47459
Browse files Browse the repository at this point in the history
DM-47459: Added middleware for converting form post params to lowercase
  • Loading branch information
stvoutsin authored Nov 26, 2024
2 parents 8d49db2 + 69fb70d commit d2e49a7
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 33 deletions.
5 changes: 5 additions & 0 deletions changelog.d/20241126_175555_steliosvoutsinas_DM_47459.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<!-- Delete the sections that don't apply -->

### New features

- Added CaseInsensitiveFormMiddleware to lowercase handle form post params for VO services
135 changes: 135 additions & 0 deletions safir/src/safir/middleware/ivoa.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,138 @@ async def __call__(
scope["query_string"] = urlencode(params).encode()
await self._app(scope, receive, send)
return


class CaseInsensitiveFormMiddleware:
"""Make POST parameter keys all lowercase.
This middleware attempts to work around case-sensitivity issues by
lowercasing POST parameter keys before the request is processed. This
allows normal FastAPI parsing to work without regard for case, permitting
FastAPI to perform input validation on the POST parameters.
"""

def __init__(self, app: ASGIApp) -> None:
"""Initialize the middleware with the ASGI application.
Parameters
----------
app
The ASGI application to wrap.
"""
self._app = app

async def __call__(
self, scope: Scope, receive: Receive, send: Send
) -> None:
"""Process request set query parameters and POST body keys to
lowercase.
"""
if scope["type"] != "http":
await self._app(scope, receive, send)
return

scope = copy(scope)

if scope["method"] == "POST" and self.is_form_data(scope):
receive = self.wrapped_receive(receive)

await self._app(scope, receive, send)

@staticmethod
def is_form_data(scope: Scope) -> bool:
"""Check if the request contains form data.
Parameters
----------
scope
The request scope.
Returns
-------
bool
True if the request contains form data, False otherwise.
"""
headers = {
k.decode("latin-1"): v.decode("latin-1")
for k, v in scope.get("headers", [])
}
content_type = headers.get("content-type", "")
return content_type.startswith("application/x-www-form-urlencoded")

@staticmethod
async def get_body(receive: Receive) -> bytes:
"""Read the entire request body.
Parameters
----------
receive
The receive function to read messages from.
Returns
-------
bytes
The entire request body.
"""
body = b""
more_body = True
while more_body:
message = await receive()
body += message.get("body", b"")
more_body = message.get("more_body", False)
return body

@staticmethod
async def process_form_data(body: bytes) -> bytes:
"""Process the body, lowercasing keys of form data.
Parameters
----------
body
The request body.
Returns
-------
bytes
The processed request body with lowercased keys.
"""
body_str = body.decode("utf-8")
parsed = parse_qsl(body_str)
lowercased = [(key.lower(), value) for key, value in parsed]
processed = urlencode(lowercased)
return processed.encode("utf-8")

def wrapped_receive(self, receive: Receive) -> Receive:
"""Wrap the receive function to process form data.
Parameters
----------
receive
The receive function to wrap.
Returns
-------
Receive
The wrapped receive function.
"""
processed = False

async def inner() -> dict:
nonlocal processed
if processed:
return {
"type": "http.request",
"body": b"",
"more_body": False,
}

body = await self.get_body(receive)
processed_body = await self.process_form_data(body)
processed = True
return {
"type": "http.request",
"body": processed_body,
"more_body": False,
}

return inner
125 changes: 92 additions & 33 deletions safir/tests/middleware/ivoa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,25 @@

from __future__ import annotations

from collections.abc import AsyncGenerator
from typing import Annotated

import pytest
from fastapi import FastAPI, Query
import pytest_asyncio
from fastapi import FastAPI, Query, Request
from httpx import ASGITransport, AsyncClient

from safir.middleware.ivoa import CaseInsensitiveQueryMiddleware
from safir.middleware.ivoa import (
CaseInsensitiveFormMiddleware,
CaseInsensitiveQueryMiddleware,
)


def build_app() -> FastAPI:
"""Construct a test FastAPI app with the middleware registered."""
app = FastAPI()
app.add_middleware(CaseInsensitiveQueryMiddleware)
return app


@pytest.mark.asyncio
async def test_case_insensitive() -> None:
app = build_app()
app.add_middleware(CaseInsensitiveFormMiddleware)

@app.get("/")
async def handler(param: str) -> dict[str, str]:
Expand All @@ -36,31 +36,90 @@ async def list_handler(
) -> dict[str, list[str]]:
return {"param": param}

@app.post("/form-list")
async def form_handler(request: Request) -> dict[str, list[str]]:
form = await request.form()
return {
"param": [str(v) for v in form.getlist("param")],
"received_keys": list(form.keys()),
}

return app


@pytest_asyncio.fixture
async def client() -> AsyncGenerator[AsyncClient, None]:
"""Test client fixture with the IVOA middleware configured."""
app = build_app()
transport = ASGITransport(app=app)
base_url = "https://example.com"
async with AsyncClient(transport=transport, base_url=base_url) as client:
r = await client.get("/", params={"param": "foo"})
assert r.status_code == 200
assert r.json() == {"param": "foo"}

r = await client.get("/", params={"PARAM": "foo"})
assert r.status_code == 200
assert r.json() == {"param": "foo"}

r = await client.get("/", params={"pARam": "foo"})
assert r.status_code == 200
assert r.json() == {"param": "foo"}

r = await client.get("/", params={"paramX": "foo"})
assert r.status_code == 422

r = await client.get("/simple")
assert r.status_code == 200
assert r.json() == {"foo": "bar"}

r = await client.get(
"/list",
params=[("param", "foo"), ("PARAM", "BAR"), ("parAM", "baZ")],
)
assert r.status_code == 200
assert r.json() == {"param": ["foo", "BAR", "baZ"]}
yield client


@pytest.mark.asyncio
async def test_single_query_param_case_insensitive(
client: AsyncClient,
) -> None:
"""Test that single query parameters are handled case-insensitively."""
# Test normal case
r = await client.get("/", params={"param": "foo"})
assert r.status_code == 200
assert r.json() == {"param": "foo"}

# Test uppercase parameter
r = await client.get("/", params={"PARAM": "foo"})
assert r.status_code == 200
assert r.json() == {"param": "foo"}

# Test mixed case parameter
r = await client.get("/", params={"pARam": "foo"})
assert r.status_code == 200
assert r.json() == {"param": "foo"}


@pytest.mark.asyncio
async def test_query_param_error_handling(client: AsyncClient) -> None:
"""Test error handling for invalid query parameters."""
r = await client.get("/", params={"paramX": "foo"})
assert r.status_code == 422


@pytest.mark.asyncio
async def test_simple_endpoint(client: AsyncClient) -> None:
"""Test endpoint with no parameters."""
r = await client.get("/simple")
assert r.status_code == 200
assert r.json() == {"foo": "bar"}


@pytest.mark.asyncio
async def test_list_query_params_case_insensitive(client: AsyncClient) -> None:
"""Test that list query parameters are handled case-insensitively."""
r = await client.get(
"/list",
params=[("param", "foo"), ("PARAM", "BAR"), ("parAM", "baZ")],
)
assert r.status_code == 200
assert r.json() == {"param": ["foo", "BAR", "baZ"]}


@pytest.mark.asyncio
async def test_form_data_case_insensitive(client: AsyncClient) -> None:
"""Test that form data parameters are handled case-insensitively."""
form_data = {"param": "foo", "PARAM": "BAR", "parAM": "baZ"}
r = await client.post("/form-list", data=form_data)
assert r.status_code == 200
response_data = r.json()
assert response_data["param"] == ["foo", "BAR", "baZ"]
assert all(key == "param" for key in response_data["received_keys"])


@pytest.mark.asyncio
async def test_empty_form_data(client: AsyncClient) -> None:
"""Test that the endpoint handles empty form data gracefully."""
r = await client.post("/form-list", data={})
assert r.status_code == 200
response_data = r.json()
assert response_data["param"] == []
assert response_data["received_keys"] == []

0 comments on commit d2e49a7

Please sign in to comment.