From 59f2290df9a939fa43e1ccaca46f3975e7f79b7d Mon Sep 17 00:00:00 2001 From: Aldo Mateli Date: Tue, 2 Apr 2024 23:06:38 +0100 Subject: [PATCH] Add support for wireup params in fastapi views --- test/integration/test_fastapi_integration.py | 12 +++++++ test/unit/test_container.py | 11 ------ wireup/annotation/__init__.py | 36 ++++++++++++-------- wireup/ioc/util.py | 22 +++++------- 4 files changed, 42 insertions(+), 39 deletions(-) diff --git a/test/integration/test_fastapi_integration.py b/test/integration/test_fastapi_integration.py index 4afe5a9..7a20a3b 100644 --- a/test/integration/test_fastapi_integration.py +++ b/test/integration/test_fastapi_integration.py @@ -33,6 +33,18 @@ async def target( self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), {"number": 4, "lucky_number": 42}) + def test_injects_parameters(self): + self.container.params.put("foo", "bar") + + @self.app.get("/") + async def target(foo: Annotated[str, Wire(param="foo")], foo_foo: Annotated[str, Wire(expr="${foo}-${foo}")]): + return {"foo": foo, "foo_foo": foo_foo} + + wireup_init_fastapi_integration(self.app, dependency_container=self.container, service_modules=[]) + response = self.client.get("/") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"foo": "bar", "foo_foo": "bar-bar"}) + def test_raises_on_unknown_service(self): @self.app.get("/") async def target(_unknown_service: Annotated[unittest.TestCase, Wire()]): diff --git a/test/unit/test_container.py b/test/unit/test_container.py index af6c923..a54917f 100644 --- a/test/unit/test_container.py +++ b/test/unit/test_container.py @@ -72,17 +72,6 @@ def __init__( self.assertEqual(svc.connection_str, "sqlite://memory") self.assertEqual(svc.cache_dir, "/var/cache/etc") - def test_inject_param(self): - result = wire(param="value") - self.assertIsInstance(result, ParameterWrapper) - self.assertEqual(result.param, "value") - - def test_inject_expr(self): - result = wire(expr="some ${param}") - self.assertIsInstance(result, ParameterWrapper) - self.assertIsInstance(result.param, TemplatedString) - self.assertEqual(result.param.value, "some ${param}") - @patch("importlib.import_module") def test_inject_fastapi_dep(self, mock_import_module): mock_import_module.return_value = Mock(Depends=Mock()) diff --git a/wireup/annotation/__init__.py b/wireup/annotation/__init__.py index ac89dcf..659e0c8 100644 --- a/wireup/annotation/__init__.py +++ b/wireup/annotation/__init__.py @@ -1,24 +1,29 @@ from __future__ import annotations +import contextlib import importlib from enum import Enum -from typing import Any +from typing import TYPE_CHECKING from wireup.ioc.types import ( ContainerProxyQualifier, ContainerProxyQualifierValue, EmptyContainerInjectionRequest, + InjectableType, ParameterWrapper, TemplatedString, ) +if TYPE_CHECKING: + from collections.abc import Callable + def wire( *, param: str | None = None, expr: str | None = None, qualifier: ContainerProxyQualifierValue = None, -) -> Any: +) -> InjectableType | Callable[[], InjectableType]: """Inject resources from the container to autowired method arguments. Arguments are exclusive and only one of them must be used at any time. @@ -32,21 +37,22 @@ def wire( :param qualifier: Qualify which implementation to bind when there are multiple components implementing an interface that is registered in the container via `@abstract`. """ - if param: - return ParameterWrapper(param) + res: InjectableType | None = None - if expr: - return ParameterWrapper(TemplatedString(expr)) + if param: + res = ParameterWrapper(param) + elif expr: + res = ParameterWrapper(TemplatedString(expr)) + elif qualifier: + res = ContainerProxyQualifier(qualifier) + else: + res = EmptyContainerInjectionRequest() - if qualifier: - return ContainerProxyQualifier(qualifier) + # Fastapi needs all dependencies to be wrapped with Depends. + with contextlib.suppress(ModuleNotFoundError): + return importlib.import_module("fastapi").Depends(lambda: res) # type: ignore[no-any-return] - try: - # Allow fastapi users to do .get() without any params - # It is meant to be used as a default value in where Depends() is expected - return importlib.import_module("fastapi").Depends(EmptyContainerInjectionRequest) - except ModuleNotFoundError: - return EmptyContainerInjectionRequest() + return res class ParameterEnum(Enum): @@ -59,7 +65,7 @@ class ParameterEnum(Enum): This will inject a parameter by name and won't work with expressions. """ - def wire(self) -> Any: + def wire(self) -> InjectableType | Callable[[], InjectableType]: """Inject the parameter this enumeration member represents. Equivalent of `wire(param=EnumParam.enum_member.value)` diff --git a/wireup/ioc/util.py b/wireup/ioc/util.py index 0518c57..6ec8d19 100644 --- a/wireup/ioc/util.py +++ b/wireup/ioc/util.py @@ -4,7 +4,7 @@ from inspect import Parameter from typing import Any -from wireup.ioc.types import AnnotatedParameter, EmptyContainerInjectionRequest, InjectableType +from wireup.ioc.types import AnnotatedParameter, InjectableType def parameter_get_type_and_annotation(parameter: Parameter) -> AnnotatedParameter: @@ -13,27 +13,23 @@ def parameter_get_type_and_annotation(parameter: Parameter) -> AnnotatedParamete Returns either the first annotation for an Annotated type or the default value. """ - def map_to_injectable_type(metadata: Any) -> InjectableType | None: - if isinstance(metadata, InjectableType): - return metadata + def get_injectable_type(metadata: Any) -> InjectableType | None: + # When using fastapi the injectable type will be wrapped with Depends. + # As such, it needs to be unwrapped in order to get the actual metadata + if str(metadata.__class__) == "": + metadata = metadata.dependency() - if ( - str(metadata.__class__) == "" - and metadata.dependency == EmptyContainerInjectionRequest - ): - return EmptyContainerInjectionRequest() - - return None + return metadata if isinstance(metadata, InjectableType) else None if hasattr(parameter.annotation, "__metadata__") and hasattr(parameter.annotation, "__args__"): klass = parameter.annotation.__args__[0] annotation = next( - (map_to_injectable_type(ann) for ann in parameter.annotation.__metadata__ if map_to_injectable_type(ann)), + (get_injectable_type(ann) for ann in parameter.annotation.__metadata__ if get_injectable_type(ann)), None, ) else: klass = None if parameter.annotation is Parameter.empty else parameter.annotation - annotation = None if parameter.default is Parameter.empty else parameter.default + annotation = None if parameter.default is Parameter.empty else get_injectable_type(parameter.default) return AnnotatedParameter(klass=klass, annotation=annotation)