From d1e5c072acd1e94f1a8fa5ad98ab389c0adaaf86 Mon Sep 17 00:00:00 2001 From: Aldo Mateli Date: Sat, 16 Nov 2024 13:11:12 +0000 Subject: [PATCH] Expose fastapi request object as a transient object (#38) * Expose fastapi request object as a transient object * Reset token * Remove unused import * Raise when being accessed outside a request * Update fastapi integration docs * Add test case --- docs/pages/integrations/fastapi.md | 71 ++++++++++++++------ test/integration/test_fastapi_integration.py | 38 ++++++++++- wireup/integration/fastapi.py | 33 ++++++++- wireup/ioc/initialization_context.py | 9 +++ 4 files changed, 127 insertions(+), 24 deletions(-) diff --git a/docs/pages/integrations/fastapi.md b/docs/pages/integrations/fastapi.md index 7f9016b..2978622 100644 --- a/docs/pages/integrations/fastapi.md +++ b/docs/pages/integrations/fastapi.md @@ -1,21 +1,15 @@ Dependency injection for FastAPI is available in the `wireup.integration.fastapi_integration` module. - **Features:** -* Automatically decorate Flask views and blueprints where the container is being used. - * Eliminates the need for `@container.autowire` in views. - * Views without container references will not be decorated. - * Services **must** be annotated with `Inject()`. -* Can: Mix FastAPI dependencies and Wireup in views -* Can: Autowire FastAPI target with `@container.autowire`. -* Cannot: Use FastAPI dependencies in Wireup service objects. - -!!! tip - As FastAPI does not have a fixed configuration mechanism, you need to expose - configuration to the container. See [configuration docs](../configuration.md) for more details. +- [x] Inject dependencies in FastAPI routes. + * Eliminates the need for `@container.autowire`. +- [x] Expose `fastapi.Request` as a wireup dependency. + * Available as a `TRANSIENT` scoped dependency, your services can ask for a fastapi request object. +- [x] Can: Mix Wireup and FastAPI dependencies in routes. +- [ ] Cannot: Use FastAPI dependencies in Wireup service objects. -## Examples +## Getting Started ```python title="main.py" app = FastAPI() @@ -23,20 +17,16 @@ app = FastAPI() @app.get("/random") async def target( # Inject annotation tells wireup that this argument should be injected. + # Inject() annotation is required otherwise fastapi will think it's a pydantic model. random_service: Annotated[RandomService, Inject()], is_debug: Annotated[bool, Inject(param="env.debug")], # This is a regular FastAPI dependency. lucky_number: Annotated[int, Depends(get_lucky_number)] -): - return { - "number": random_service.get_random(), - "lucky_number": lucky_number, - "is_debug": is_debug, - } +): ... # Initialize the integration. -# Must be called after views have been registered. +# Must be called after all routers have been added. # service_modules is a list of top-level modules with service registrations. container = wireup.create_container( service_modules=[services], @@ -45,6 +35,47 @@ container = wireup.create_container( wireup.integration.fastapi.setup(container, app) ``` +Wireup integration performs injection only in fastapi routes. If you're not storing the container in a global variable, +you can always get a reference to it wherever you have a fastapi application reference +by using `wireup.integration.fastapi.get_container`. + +```python title="example_middleware.py" +from wireup.integration.fastapi import get_container + +async def example_middleware(request: Request, call_next) -> Response: + container = get_container(request.app) + ... + + return await call_next(request) +``` + + +In the same way, you can get a reference to it in a fastapi dependency. +```python +from wireup.integration.fastapi import get_container + +async def example_dependency(request: Request, other_dependency: Depends(...)): + container = get_container(request.app) + ... +``` + +### FastAPI request + +A key feature of the integration is to expose `fastapi.Request` and `starlette.requests.Request` objects in wireup. + +Services depending on it should be transient, so that you get a fresh copy +every time with the current request being processed. + +```python +@service(lifetime=ServiceLifetime.TRANSIENT) +class HttpAuthenticationService: + def __init__(self, request: fastapi.Request) -> None: ... + + +@service(lifetime=ServiceLifetime.TRANSIENT) +def example_factory(request: fastapi.Request) -> ExampleService: ... +``` + ### Testing For general testing tips with Wireup refer to the [test docs](../testing.md). diff --git a/test/integration/test_fastapi_integration.py b/test/integration/test_fastapi_integration.py index 99f84e1..e459f4a 100644 --- a/test/integration/test_fastapi_integration.py +++ b/test/integration/test_fastapi_integration.py @@ -1,17 +1,29 @@ +import asyncio +import uuid +from dataclasses import dataclass +from typing import Any, Dict + +import anyio.to_thread import pytest import wireup import wireup.integration import wireup.integration.fastapi -from fastapi import Depends, FastAPI +from fastapi import Depends, FastAPI, Request from fastapi.testclient import TestClient from typing_extensions import Annotated from wireup import Inject -from wireup.errors import UnknownServiceRequestedError +from wireup.errors import UnknownServiceRequestedError, WireupError from wireup.integration.fastapi import get_container +from wireup.ioc.types import ServiceLifetime from test.unit.services.no_annotations.random.random_service import RandomService +@dataclass +class ServiceUsingFastapiRequest: + req: Request + + def get_lucky_number() -> int: # Raise if this will be invoked more than once # That would be the case if wireup also "unwraps" and tries @@ -44,8 +56,13 @@ async def _(foo: Annotated[str, Inject(param="foo")], foo_foo: Annotated[str, In async def _(_unknown_service: Annotated[None, Inject()]): return {"msg": "Hello World"} + @app.get("/current-request") + async def _(_request: Request, req: Annotated[ServiceUsingFastapiRequest, Inject()]) -> Dict[str, Any]: + return {"foo": req.req.query_params["foo"], "request_id": req.req.headers["X-Request-Id"]} + container = wireup.create_container(service_modules=[], parameters={"foo": "bar"}) container.register(RandomService) + container.register(ServiceUsingFastapiRequest, lifetime=ServiceLifetime.TRANSIENT) wireup.integration.fastapi.setup(container, app) return app @@ -84,9 +101,26 @@ def test_injects_parameters(client: TestClient): assert response.json() == {"foo": "bar", "foo_foo": "bar-bar"} +async def test_current_request_service(client: TestClient): + async def _make_request(): + request_id = uuid.uuid4().hex + response = await anyio.to_thread.run_sync( + lambda: client.get("/current-request", params={"foo": request_id}, headers={"X-Request-Id": request_id}) + ) + assert response.status_code == 200 + assert response.json() == {"foo": request_id, "request_id": request_id} + + await asyncio.gather(*(_make_request() for _ in range(100))) + + def test_raises_on_unknown_service(client: TestClient): with pytest.raises( UnknownServiceRequestedError, match="Cannot wire unknown class . Use '@service' or '@abstract' to enable autowiring.", ): client.get("/raise-unknown") + + +def test_raises_request_outside_of_scope(app: FastAPI) -> None: + with pytest.raises(WireupError, match="fastapi.Request in wireup is only available during a request."): + get_container(app).get(Request) diff --git a/wireup/integration/fastapi.py b/wireup/integration/fastapi.py index 3f8a979..5a505ab 100644 --- a/wireup/integration/fastapi.py +++ b/wireup/integration/fastapi.py @@ -1,8 +1,31 @@ -from fastapi import FastAPI +from contextvars import ContextVar +from typing import Awaitable, Callable + +from fastapi import FastAPI, Request, Response from fastapi.routing import APIRoute from wireup import DependencyContainer +from wireup.errors import WireupError from wireup.integration.util import is_view_using_container +from wireup.ioc.types import ServiceLifetime + +current_request: ContextVar[Request] = ContextVar("wireup_fastapi_request") + + +async def _wireup_request_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + token = current_request.set(request) + try: + return await call_next(request) + finally: + current_request.reset(token) + + +def _fastapi_request_factory() -> Request: + try: + return current_request.get() + except LookupError as e: + msg = "fastapi.Request in wireup is only available during a request." + raise WireupError(msg) from e def _autowire_views(container: DependencyContainer, app: FastAPI) -> None: @@ -12,7 +35,11 @@ def _autowire_views(container: DependencyContainer, app: FastAPI) -> None: and route.dependant.call and is_view_using_container(container, route.dependant.call) ): - route.dependant.call = container.autowire(route.dependant.call) + target = route.dependant.call + route.dependant.call = container.autowire(target) + # Remove Request as a dependency from this target. + # Let fastapi inject it instead and avoid duplicated work. + container._registry.context.remove_dependency_type(target, Request) # type: ignore[reportPrivateUsage] # noqa: SLF001 def setup(container: DependencyContainer, app: FastAPI) -> None: @@ -20,6 +47,8 @@ def setup(container: DependencyContainer, app: FastAPI) -> None: This will automatically inject dependencies on FastAPI routers. """ + container.register(_fastapi_request_factory, lifetime=ServiceLifetime.TRANSIENT) + app.middleware("http")(_wireup_request_middleware) _autowire_views(container, app) app.state.wireup_container = container diff --git a/wireup/ioc/initialization_context.py b/wireup/ioc/initialization_context.py index 3c1cddc..8f7bc49 100644 --- a/wireup/ioc/initialization_context.py +++ b/wireup/ioc/initialization_context.py @@ -64,3 +64,12 @@ def remove_dependencies(self, target: AutowireTarget, names_to_remove: set[str]) Target must have been already initialized prior to calling this. """ self.__dependencies[target] = {k: v for k, v in self.__dependencies[target].items() if k not in names_to_remove} + + def remove_dependency_type(self, target: AutowireTarget, type_to_remove: type) -> None: + """Remove dependencies with the given type from the target. + + Target must have been already initialized prior to calling this. + """ + self.__dependencies[target] = { + k: v for k, v in self.__dependencies[target].items() if v.klass != type_to_remove + }