From 689db10e11b3c03edaedbb759f7209019feb1d18 Mon Sep 17 00:00:00 2001 From: Ivan Belyaev Date: Sun, 17 Nov 2024 13:28:00 +0300 Subject: [PATCH] Auto injection feature (#13) * auto inject implementation * remove unused code * fix bug with call with args in Object provider * fix litestar tests --------- Co-authored-by: ivan --- README.md | 41 +++-- .../litestart.md | 2 - docs/providers/partial_callable.md | 25 --- pyproject.toml | 2 +- src/injection/__init__.py | 3 +- src/injection/base_container.py | 28 +++- src/injection/inject/__init__.py | 0 src/injection/inject/auto_inject.py | 151 ++++++++++++++++++ src/injection/inject/exceptions.py | 7 + src/injection/{ => inject}/inject.py | 58 ++----- src/injection/providers/__init__.py | 2 - src/injection/providers/base_factory.py | 20 +-- src/injection/providers/coroutine.py | 2 +- src/injection/providers/object.py | 7 +- src/injection/providers/partial_callable.py | 27 ---- src/injection/providers/singleton.py | 12 +- tests/container_objects.py | 51 ++++-- .../test_drf/drf_test_project/settings.py | 1 - .../test_drf/drf_test_project/views.py | 11 +- .../integration/test_drf/test_integration.py | 20 +++ tests/integration/test_fastapi/handlers.py | 12 +- .../test_fastapi/test_integration.py | 30 ++-- .../test_flask/test_integration.py | 50 +++++- .../test_litestar/test_integration.py | 75 ++++++--- tests/test_auto_inject.py | 66 ++++++++ tests/test_base_container.py | 19 +++ tests/test_e2e.py | 30 +++- tests/test_providers/test_partial_callable.py | 32 ---- 28 files changed, 541 insertions(+), 243 deletions(-) delete mode 100644 docs/integration-with-web-frameworks/litestart.md delete mode 100644 docs/providers/partial_callable.md create mode 100644 src/injection/inject/__init__.py create mode 100644 src/injection/inject/auto_inject.py create mode 100644 src/injection/inject/exceptions.py rename src/injection/{ => inject}/inject.py (54%) delete mode 100644 src/injection/providers/partial_callable.py create mode 100644 tests/test_auto_inject.py delete mode 100644 tests/test_providers/test_partial_callable.py diff --git a/README.md b/README.md index 9e6d83a..6f1407c 100644 --- a/README.md +++ b/README.md @@ -26,13 +26,14 @@ Easy dependency injection for all, works with Python 3.8-3.12. Main features and advantages: * support **Python 3.8-3.12**; -* works with **FastAPI, Flask, Litestar** and **Django REST Framework**; +* works with **FastAPI, Flask** and **Django REST Framework**; * support dependency injection via `Annotated` in `FastAPI`; * the code is fully typed and checked with [mypy](https://github.com/python/mypy); * **no third-party dependencies**; -* **multiple containers**; -* **overriding** dependencies for tests without wiring; -* **100%** code coverage and very simple code; +* no wiring; +* the life cycle of objects (**scope**) is implemented by providers; +* **overriding** dependencies for testing; +* **100%** code coverage; * good [documentation](https://injection.readthedocs.io/latest/); * intuitive and almost identical api with [dependency-injector](https://github.com/ets-labs/python-dependency-injector), which will allow you to easily migrate to injection @@ -45,15 +46,18 @@ which will allow you to easily migrate to injection pip install deps-injection ``` -## Using example +## Compatibility between web frameworks and injection features +| Framework | Dependency injection with @inject | Dependency injection with @autoinject (_experimental_) | Overriding providers | +|--------------------------------------------------------------------------|:---------------------------------:|:------------------------------------------------------:|:--------------------:| +| [FastAPI](https://github.com/fastapi/fastapi) | ✅ | ➖ | ✅ | +| [Flask](https://github.com/pallets/flask) | ✅ | ✅ | ✅ | +| [Django REST Framework](https://github.com/encode/django-rest-framework) | ✅ | ✅ | ✅ | +| [Litestar](https://github.com/litestar-org/litestar) | ➖ | ➖ | ➖ | -```python3 -import sys -if sys.version_info >= (3, 9): - from typing import Annotated -else: - from typing import Annotated +## Using example with FastAPI +```python3 +from typing import Annotated from unittest.mock import Mock import pytest @@ -103,16 +107,10 @@ RedisDependency = Annotated[Redis, Depends(Provide[Container.redis])] def some_get_endpoint_handler(redis: RedisDependency): value = redis.get(299) return {"detail": value} +``` - -@router.post("/values") -@inject -async def some_get_async_endpoint_handler(redis: RedisDependency): - value = redis.get(399) - return {"detail": value} - - -###################### TESTING ###################### +## Testing example with overriding providers for above FastAPI example +```python3 @pytest.fixture(scope="session") def app(): return create_app() @@ -144,7 +142,4 @@ def test_override_providers(test_client, container): assert response.status_code == 200 body = response.json() assert body["detail"] == "mock_get_method" - ``` - ---- diff --git a/docs/integration-with-web-frameworks/litestart.md b/docs/integration-with-web-frameworks/litestart.md deleted file mode 100644 index e413b59..0000000 --- a/docs/integration-with-web-frameworks/litestart.md +++ /dev/null @@ -1,2 +0,0 @@ -# Litestar -soon.. diff --git a/docs/providers/partial_callable.md b/docs/providers/partial_callable.md deleted file mode 100644 index 7c951e3..0000000 --- a/docs/providers/partial_callable.md +++ /dev/null @@ -1,25 +0,0 @@ -# Partial callable - -**Partial callable** it is a provider very similar in principle to the mechanics of the callable provider, -but works a bit differently and allows for flexibility in some specific situations. - -Let's imagine that we are in a case when we know some values of parameters of a future object right now, -but some values will be known in the future (in runtime). -Obviously, resolving such a provider will not make any sense until the values of all its parameters are known. -But this provider allows you to flexibly fix the values of previously known parameters of the future object, -and the remaining values can be passed at the moment when these values are determined. - -## Example -```python3 -from injection.providers import PartialCallable - - -def some_function(a: str, b: int, *, c: str): - return a, b, c - - -if __name__ == '__main__': - provider = PartialCallable(some_function, 1, 99) - callable_object = provider() - assert callable_object(c=5) == (1, 99, 5) -``` diff --git a/pyproject.toml b/pyproject.toml index 7fd5050..f64d84f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,7 +145,7 @@ run-cov = "coverage run -m pytest{env:HATCH_TEST_ARGS:} {args}" cov-combine = "coverage combine" cov-report = [ "coverage xml", - "coverage report" + "coverage report -m" ] [[tool.hatch.envs.hatch-test.matrix]] diff --git a/src/injection/__init__.py b/src/injection/__init__.py index 15b39df..ae76a81 100644 --- a/src/injection/__init__.py +++ b/src/injection/__init__.py @@ -1,5 +1,6 @@ from injection import providers from injection.__version__ import __version__ from injection.base_container import DeclarativeContainer -from injection.inject import inject +from injection.inject.auto_inject import auto_inject +from injection.inject.inject import inject from injection.provide import Provide diff --git a/src/injection/base_container.py b/src/injection/base_container.py index 2b2b48f..1e8848a 100644 --- a/src/injection/base_container.py +++ b/src/injection/base_container.py @@ -1,9 +1,12 @@ import inspect +from collections import defaultdict from contextlib import contextmanager -from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar +from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar, cast +from injection.inject.exceptions import DuplicatedFactoryTypeAutoInjectionError from injection.providers import Singleton from injection.providers.base import BaseProvider +from injection.providers.base_factory import BaseFactoryProvider F = TypeVar("F", bound=Callable[..., Any]) @@ -101,3 +104,26 @@ def reset_override(cls) -> None: for provider in providers.values(): provider.reset_override() + + @classmethod + def resolve_by_type(cls, type_: Type[Any]) -> Any: + provider_factory_to_providers = defaultdict(list) + + for provider in cls._get_providers_generator(): + if not issubclass(type(provider), BaseFactoryProvider): + continue + + provider_factory_to_providers[provider.factory].append(provider) # type: ignore + + if len(provider_factory_to_providers[provider.factory]) > 1: # type: ignore + raise DuplicatedFactoryTypeAutoInjectionError(str(type_)) + + for providers in provider_factory_to_providers.values(): + provider = providers[0] + provider = cast(BaseFactoryProvider[Any], provider) + + if type_ is provider.factory: + return provider() + + msg = f"Provider with type {type_!s} not found" + raise Exception(msg) diff --git a/src/injection/inject/__init__.py b/src/injection/inject/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/injection/inject/auto_inject.py b/src/injection/inject/auto_inject.py new file mode 100644 index 0000000..9a0d994 --- /dev/null +++ b/src/injection/inject/auto_inject.py @@ -0,0 +1,151 @@ +import inspect +import sys +from functools import wraps +from typing import Any, Callable, Coroutine, Dict, Optional, Type, TypeVar, Union, cast + +from injection.base_container import DeclarativeContainer +from injection.inject.exceptions import DuplicatedFactoryTypeAutoInjectionError +from injection.inject.inject import _resolve_markers +from injection.provide import Provide + +if sys.version_info < (3, 10): + from typing_extensions import ParamSpec +else: + from typing import ParamSpec + +T = TypeVar("T") +P = ParamSpec("P") +Markers = Dict[str, Provide] +_ContainerType = Union[Type[DeclarativeContainer], DeclarativeContainer] + + +def _resolve_signature_args_with_types_from_container( + *, + signature: inspect.Signature, + target_container: _ContainerType, +) -> Dict[str, Any]: + resolved_signature_typed_args = {} + + for param_name, param in signature.parameters.items(): + if not (param.annotation is not param.empty and param.default is param.empty): + continue + + try: + resolved = target_container.resolve_by_type(param.annotation) + resolved_signature_typed_args[param_name] = resolved + except DuplicatedFactoryTypeAutoInjectionError: + raise + + # Ignore exceptions for cases for example django rest framework + # endpoint may have parameter 'request' - we don't know how to handle a variety of parameters. + # But anyway, after this the runtime will fail with an error if something goes wrong + except Exception: # noqa: S112 + continue + + return resolved_signature_typed_args + + +def _get_sync_injected( + *, + f: Callable[P, T], + markers: Markers, + signature: inspect.Signature, + target_container: _ContainerType, +) -> Callable[P, T]: + @wraps(f) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + resolved_signature_typed_args = ( + _resolve_signature_args_with_types_from_container( + signature=signature, + target_container=target_container, + ) + ) + + provide_markers = { + k: v + for k, v in kwargs.items() + if k not in markers and isinstance(v, Provide) + } + provide_markers.update(markers) + resolved_values = _resolve_markers(provide_markers) + + kwargs.update(resolved_values) + kwargs.update(resolved_signature_typed_args) + return f(*args, **kwargs) + + return wrapper + + +def _get_async_injected( + *, + f: Callable[P, Coroutine[Any, Any, T]], + markers: Markers, + signature: inspect.Signature, + target_container: _ContainerType, +) -> Callable[P, Coroutine[Any, Any, T]]: + @wraps(f) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + resolved_signature_typed_args = ( + _resolve_signature_args_with_types_from_container( + signature=signature, + target_container=target_container, + ) + ) + + provide_markers = { + k: v + for k, v in kwargs.items() + if k not in markers and isinstance(v, Provide) + } + provide_markers.update(markers) + resolved_values = _resolve_markers(provide_markers) + + kwargs.update(resolved_values) + kwargs.update(resolved_signature_typed_args) + return await f(*args, **kwargs) + + return wrapper + + +def auto_inject( + f: Callable[P, T], + target_container: Optional[_ContainerType] = None, +) -> Callable[P, T]: + """Decorate callable with injecting decorator. Inject objects by types""" + + if target_container is None: + container_subclasses = DeclarativeContainer.__subclasses__() + + if len(container_subclasses) > 1: + msg = ( + f"Found {len(container_subclasses)} containers, please specify " + f"the required container explicitly in the parameter 'target_container'" + ) + raise Exception(msg) + + target_container = container_subclasses[0] + + signature = inspect.signature(f) + parameters = signature.parameters + + markers = { + parameter_name: parameter_value.default + for parameter_name, parameter_value in parameters.items() + if isinstance(parameter_value.default, Provide) + } + + if inspect.iscoroutinefunction(f): + func_with_injected_params = _get_async_injected( + f=f, + markers=markers, + signature=signature, + target_container=target_container, + ) + return cast(Callable[P, T], func_with_injected_params) + + return _get_sync_injected( + f=f, + markers=markers, + signature=signature, + target_container=target_container, + ) diff --git a/src/injection/inject/exceptions.py b/src/injection/inject/exceptions.py new file mode 100644 index 0000000..6b02e60 --- /dev/null +++ b/src/injection/inject/exceptions.py @@ -0,0 +1,7 @@ +class DuplicatedFactoryTypeAutoInjectionError(Exception): + def __init__(self, type_: str) -> None: + message = ( + f"Cannot resolve auto inject because found " + f"more than one provider for type '{type_}'" + ) + super().__init__(message) diff --git a/src/injection/inject.py b/src/injection/inject/inject.py similarity index 54% rename from src/injection/inject.py rename to src/injection/inject/inject.py index 0f6aee6..aa2b6d7 100644 --- a/src/injection/inject.py +++ b/src/injection/inject/inject.py @@ -16,51 +16,21 @@ Markers = Dict[str, Provide] -# def _is_fastapi_depends(param: Any) -> bool: -# try: -# import fastapi -# except ImportError: -# fastapi = None # type: ignore -# return fastapi is not None and isinstance(param, fastapi.params.Depends) - - -# def _extract_marker(parameter: inspect.Parameter) -> Union[Any, Provide]: -# marker = parameter.default -# -# parameter_origin = get_origin(parameter.annotation) -# is_annotated = parameter_origin is Annotated -# -# if is_annotated: -# marker = get_args(parameter.annotation)[1] -# -# is_fastapi_depends = _is_fastapi_depends(marker) -# -# if is_fastapi_depends: -# marker = marker.dependency -# -# return marker - - def _get_markers_from_function(f: Callable[P, T]) -> Markers: - injections = {} signature = inspect.signature(f) parameters = signature.parameters - for parameter_name, parameter_value in parameters.items(): - if isinstance(parameter_value.default, Provide): - injections[parameter_name] = parameter_value.default + injections = { + parameter_name: parameter_value.default + for parameter_name, parameter_value in parameters.items() + if isinstance(parameter_value.default, Provide) + } return injections def _resolve_markers(markers: Markers) -> Dict[str, Any]: - providers = {} - - for param, provide in markers.items(): - provider_value = provide.provider() - providers[param] = provider_value - - return providers + return {param: provide.provider() for param, provide in markers.items()} def _get_async_injected( @@ -69,15 +39,15 @@ def _get_async_injected( ) -> Callable[P, Coroutine[Any, Any, T]]: @wraps(f) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - kwargs_for_resolve = { + provide_markers = { k: v for k, v in kwargs.items() if k not in markers and isinstance(v, Provide) } - markers.update(kwargs_for_resolve) + provide_markers.update(markers) + resolved_values = _resolve_markers(provide_markers) - kwarg_values = _resolve_markers(markers) - kwargs.update(kwarg_values) + kwargs.update(resolved_values) return await f(*args, **kwargs) return wrapper @@ -89,15 +59,15 @@ def _get_sync_injected( ) -> Callable[P, T]: @wraps(f) def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - kwargs_for_resolve = { + provide_markers = { k: v for k, v in kwargs.items() if k not in markers and isinstance(v, Provide) } - markers.update(kwargs_for_resolve) + provide_markers.update(markers) + resolved_values = _resolve_markers(provide_markers) - kwarg_values = _resolve_markers(markers) - kwargs.update(kwarg_values) + kwargs.update(resolved_values) return f(*args, **kwargs) return wrapper diff --git a/src/injection/providers/__init__.py b/src/injection/providers/__init__.py index c009470..66d0ac1 100644 --- a/src/injection/providers/__init__.py +++ b/src/injection/providers/__init__.py @@ -1,7 +1,6 @@ from injection.providers.callable import Callable from injection.providers.coroutine import Coroutine from injection.providers.object import Object -from injection.providers.partial_callable import PartialCallable from injection.providers.singleton import Singleton from injection.providers.transient import Transient @@ -9,7 +8,6 @@ "Callable", "Coroutine", "Object", - "PartialCallable", "Singleton", "Transient", ] diff --git a/src/injection/providers/base_factory.py b/src/injection/providers/base_factory.py index 0f7e8a3..21e12d9 100644 --- a/src/injection/providers/base_factory.py +++ b/src/injection/providers/base_factory.py @@ -1,4 +1,3 @@ -import inspect from typing import ( Any, Awaitable, @@ -8,7 +7,6 @@ Tuple, TypeVar, Union, - cast, ) from typing_extensions import ParamSpec @@ -23,7 +21,7 @@ class BaseFactoryProvider(BaseProvider[T]): def __init__( self, - factory: Union[Callable[P, Awaitable[T]], Callable[P, T]], + factory: Callable[P, T], *args: P.args, **kwargs: P.kwargs, ) -> None: @@ -32,6 +30,10 @@ def __init__( self._kwargs = kwargs self._factory = factory + @property + def factory(self) -> Callable[P, T]: + return self._factory # type: ignore + def _get_final_args_and_kwargs( self, *args: Any, @@ -40,14 +42,6 @@ def _get_final_args_and_kwargs( clean_args = get_clean_args(self._args) clean_kwargs = get_clean_kwargs(self._kwargs) - # Common solution for bug when litestar try to add kwargs with name 'args' and 'kwargs' - if len(args) > 0 or len(kwargs) > 0: - type_cls_init_signature = inspect.signature(self._factory) - parameters = type_cls_init_signature.parameters - - args = tuple(arg for arg in args if arg in parameters) - kwargs = {arg: value for arg, value in kwargs.items() if arg in parameters} - final_args: List[Any] = [] final_args.extend(clean_args) final_args.extend(args) @@ -57,11 +51,11 @@ def _get_final_args_and_kwargs( final_kwargs.update(kwargs) return tuple(final_args), final_kwargs - def _resolve(self, *args: Any, **kwargs: Any) -> Union[Callable[P, T], T]: + def _resolve(self, *args: Any, **kwargs: Any) -> Union[T, Awaitable[T]]: """ Positional arguments are appended after Factory positional dependencies. Keyword arguments have the priority over the Factory keyword dependencies with the same name. """ final_args, final_kwargs = self._get_final_args_and_kwargs(*args, **kwargs) - instance = cast(Callable[P, T], self._factory(*final_args, **final_kwargs)) + instance = self._factory(*final_args, **final_kwargs) return instance diff --git a/src/injection/providers/coroutine.py b/src/injection/providers/coroutine.py index d40b332..8f1b604 100644 --- a/src/injection/providers/coroutine.py +++ b/src/injection/providers/coroutine.py @@ -15,7 +15,7 @@ def __init__( *a: P.args, **kw: P.kwargs, ) -> None: - super().__init__(coroutine, *a, **kw) + super().__init__(cast(Callable[P, T], coroutine), *a, **kw) def __call__(self, *args: Any, **kwargs: Any) -> Awaitable[T]: return cast(Awaitable[T], super().__call__(*args, **kwargs)) diff --git a/src/injection/providers/object.py b/src/injection/providers/object.py index 9959101..8387043 100644 --- a/src/injection/providers/object.py +++ b/src/injection/providers/object.py @@ -1,4 +1,4 @@ -from typing import TypeVar, cast +from typing import Any, TypeVar, Union, cast from injection.providers.base import BaseProvider from injection.resolving import resolve_value @@ -14,3 +14,8 @@ def __init__(self, obj: T) -> None: def _resolve(self) -> T: value = cast(T, resolve_value(self._obj)) return value + + def __call__(self) -> Union[T, Any]: + if self._mocks: + return self._mocks[-1] + return self._resolve() diff --git a/src/injection/providers/partial_callable.py b/src/injection/providers/partial_callable.py deleted file mode 100644 index 27d8cc1..0000000 --- a/src/injection/providers/partial_callable.py +++ /dev/null @@ -1,27 +0,0 @@ -from functools import partial -from typing import Any, Callable, TypeVar, cast - -from typing_extensions import ParamSpec - -from injection.providers.base_factory import BaseFactoryProvider - -P = ParamSpec("P") -T = TypeVar("T") - - -class PartialCallable(BaseFactoryProvider[T]): - def __init__( - self, - callable_object: Callable[P, T], - *a: P.args, - **kw: P.kwargs, - ) -> None: - super().__init__(callable_object, *a, **kw) - - def _resolve(self, *args: Any, **kwargs: Any) -> Callable[P, T]: - final_args, final_kwargs = self._get_final_args_and_kwargs(*args, **kwargs) - partial_callable = cast( - Callable[P, T], - partial(self._factory, *final_args, **final_kwargs), - ) - return partial_callable diff --git a/src/injection/providers/singleton.py b/src/injection/providers/singleton.py index 148f64c..32bd5ba 100644 --- a/src/injection/providers/singleton.py +++ b/src/injection/providers/singleton.py @@ -1,17 +1,20 @@ -from typing import Any, Optional, Type, TypeVar, cast +from typing import Any, Callable, Optional, TypeVar, cast + +from typing_extensions import ParamSpec from injection.providers.base_factory import BaseFactoryProvider +P = ParamSpec("P") T = TypeVar("T") class Singleton(BaseFactoryProvider[T]): """Global singleton object created only once""" - def __init__(self, type_cls: Type[T], *a: Any, **kw: Any) -> None: + def __init__(self, type_cls: Callable[P, T], *a: Any, **kw: Any) -> None: super().__init__(type_cls, *a, **kw) - self._instance: Optional[T] = None self.type_cls = type_cls + self._instance: Optional[T] = None def _resolve(self, *args: Any, **kwargs: Any) -> T: """https://python-dependency-injector.ets-labs.org/providers/factory.html @@ -21,7 +24,8 @@ def _resolve(self, *args: Any, **kwargs: Any) -> T: """ if self._instance is None: - self._instance = cast(T, super()._resolve(*args, **kwargs)) + instance = super()._resolve(*args, **kwargs) + self._instance = cast(T, instance) return self._instance diff --git a/tests/container_objects.py b/tests/container_objects.py index 995e0fd..4c2ae13 100644 --- a/tests/container_objects.py +++ b/tests/container_objects.py @@ -1,7 +1,7 @@ import asyncio from dataclasses import dataclass, field -from injection import DeclarativeContainer, Provide, inject, providers +from injection import DeclarativeContainer, Provide, auto_inject, inject, providers @dataclass @@ -67,13 +67,7 @@ class Container(DeclarativeContainer): some_service = providers.Singleton(SomeService, 1, redis, svc=service) num = providers.Object(settings.provided.nested_settings.some_const) num2 = providers.Object(9402) - partial_callable = providers.PartialCallable(func, 1, c="string", nums=num) callable_obj = providers.Callable(func, 1, c="string2", nums=num, d={"d": 500}) - transient_obj = providers.Transient( - Redis, - port=settings.provided.redis_port, - url=settings.provided.redis_url, - ) coroutine_provider = providers.Coroutine(coroutine, arg1=1, arg2=2) @@ -86,7 +80,6 @@ def func_with_injections( svc1=Provide[Container.service], svc2=Provide[Container.some_service], numms=Provide[Container.num], - partial_callable_param=Provide[Container.partial_callable], ): _ = sfs _ = ddd @@ -96,6 +89,44 @@ def func_with_injections( svc1.do_smth() svc2.do_smth() - partial_callable_result = partial_callable_param(d={"eparam": "eeeee"}) - _ = partial_callable_result + return redis.url + + +@auto_inject +def func_with_auto_injections( + sfs, + redis: Redis, + *, + ddd, + svc1: Service, + svc2: SomeService, +): + _ = sfs + _ = ddd + + redis.get(1) + svc1.do_smth() + svc2.do_smth() + + return redis.url + + +@auto_inject +def func_with_auto_injections_mixed( + sfs, + *, + ddd, + redis: Redis, + svc1: Service, + svc2: SomeService, + numms=Provide[Container.num], +): + _ = sfs + _ = ddd + _ = numms + + redis.get(1) + svc1.do_smth() + svc2.do_smth() + return redis.url diff --git a/tests/integration/test_drf/drf_test_project/settings.py b/tests/integration/test_drf/drf_test_project/settings.py index 73883fa..334cf5d 100644 --- a/tests/integration/test_drf/drf_test_project/settings.py +++ b/tests/integration/test_drf/drf_test_project/settings.py @@ -1,6 +1,5 @@ from pathlib import Path -# os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tests.integration.test_drf.drf_test_project.settings') BASE_DIR = Path(__file__).resolve().parent SECRET_KEY = "django-insecure-u20tyumpwc)g21=fjy5nl9u@ih!28()dfvwr4%#cigz$ktop@^" # noqa: S105 diff --git a/tests/integration/test_drf/drf_test_project/views.py b/tests/integration/test_drf/drf_test_project/views.py index f7d56dd..6ab697b 100644 --- a/tests/integration/test_drf/drf_test_project/views.py +++ b/tests/integration/test_drf/drf_test_project/views.py @@ -3,8 +3,8 @@ from rest_framework.response import Response from rest_framework.views import APIView -from injection import Provide, inject -from tests.container_objects import Container +from injection import Provide, auto_inject, inject +from tests.container_objects import Container, Redis class PostEndpointBodySerializer(serializers.Serializer): @@ -13,14 +13,15 @@ class PostEndpointBodySerializer(serializers.Serializer): class View(APIView): @inject - def get(self, _: Request, redis=Provide[Container.redis]): + def get(self, _: Request, redis: Redis = Provide[Container.redis]): response_body = {"redis_url": redis.url} return Response(response_body, status=status.HTTP_200_OK) - @inject - def post(self, request: Request, redis=Provide[Container.redis]): + @auto_inject + def post(self, request: Request, redis: Redis): body_serializer = PostEndpointBodySerializer(data=request.data) body_serializer.is_valid() + key = body_serializer.validated_data["key"] response_body = {"redis_key": redis.get(key)} return Response(response_body, status=status.HTTP_201_CREATED) diff --git a/tests/integration/test_drf/test_integration.py b/tests/integration/test_drf/test_integration.py index cad5ce1..9dc20ba 100644 --- a/tests/integration/test_drf/test_integration.py +++ b/tests/integration/test_drf/test_integration.py @@ -1,4 +1,6 @@ import os +from typing import Any +from unittest.mock import Mock import pytest @@ -20,16 +22,34 @@ def test_drf_client(): def test_drf_get_endpoint(test_drf_client): response = test_drf_client.get("http://127.0.0.1:8000/some_view_prefix") response_body = response.json() + assert response.status_code == 200 assert response_body == {"redis_url": "redis://localhost"} def test_drf_post_endpoint(test_drf_client): redis_key = 234214 + response = test_drf_client.post( "http://127.0.0.1:8000/some_view_prefix", data={"key": redis_key}, ) + assert response.status_code == 201 response_body = response.json() assert response_body == {"redis_key": redis_key} + + +@pytest.mark.parametrize( + "override_value", + ["kjgfiyrdi", "o987ytvydut", "-gfd56a`^^~Wyerjg"], +) +def test_drf_override_provider(test_drf_client, container, override_value: Any): + mock_redis = Mock(url=override_value) + + with container.override_providers_kwargs(redis=mock_redis): + response = test_drf_client.get("http://127.0.0.1:8000/some_view_prefix") + + assert response.status_code == 200 + response_body = response.json() + assert response_body == {"redis_url": override_value} diff --git a/tests/integration/test_fastapi/handlers.py b/tests/integration/test_fastapi/handlers.py index 32bed8e..fe0e031 100644 --- a/tests/integration/test_fastapi/handlers.py +++ b/tests/integration/test_fastapi/handlers.py @@ -16,17 +16,17 @@ RedisDependencyWithoutProvideMarker = Annotated[Redis, Depends(Container.redis)] -@router.get("/values") +@router.post("/values") @inject -def some_get_endpoint_handler(redis: RedisDependency): - value = redis.get(299) +async def some_get_async_endpoint_handler(redis: RedisDependency): + value = redis.get(399) return {"detail": value} -@router.post("/values") +@router.get("/values/{param}") @inject -async def some_get_async_endpoint_handler(redis: RedisDependency): - value = redis.get(399) +def some_get_endpoint_handler(redis: RedisDependency, param: int): + value = redis.get(param) return {"detail": value} diff --git a/tests/integration/test_fastapi/test_integration.py b/tests/integration/test_fastapi/test_integration.py index c81c042..6950dc7 100644 --- a/tests/integration/test_fastapi/test_integration.py +++ b/tests/integration/test_fastapi/test_integration.py @@ -1,3 +1,4 @@ +from typing import Any from unittest.mock import Mock import pytest @@ -17,12 +18,16 @@ def test_client(app): return client -def test_sync_fastapi_endpoint(test_client): - response = test_client.get("/api/values") +@pytest.mark.parametrize( + "value", + [1, 2, 5, 10, 245, -34636, 923425], +) +def test_sync_fastapi_endpoint(test_client, value: int): + response = test_client.get(f"/api/values/{value}") assert response.status_code == 200 body = response.json() - assert body["detail"] == 299 + assert body["detail"] == value def test_async_fastapi_endpoint(test_client): @@ -52,18 +57,17 @@ def test_async_fastapi_endpoint_expect_422_without_provide_marker(test_client): } -def test_fastapi_override_provider(test_client, container): - def mock_get_method(_): - return "mock_get_method" - +@pytest.mark.parametrize( + "override_value", + ["mock_get_method_110934", "blsdfmsdfsf", -345627434], +) +def test_fastapi_override_provider(test_client, container, override_value: Any): mock_redis = Mock() - mock_redis.get = mock_get_method - - providers_to_override = {"redis": mock_redis} + mock_redis.get = lambda _: override_value - with container.override_providers(providers_to_override): - response = test_client.get("/api/values") + with container.override_providers_kwargs(redis=mock_redis): + response = test_client.post("/api/values") assert response.status_code == 200 body = response.json() - assert body["detail"] == "mock_get_method" + assert body["detail"] == override_value diff --git a/tests/integration/test_flask/test_integration.py b/tests/integration/test_flask/test_integration.py index dc9c5f3..b904d24 100644 --- a/tests/integration/test_flask/test_integration.py +++ b/tests/integration/test_flask/test_integration.py @@ -1,7 +1,11 @@ +from typing import Any +from unittest.mock import Mock + +import pytest from flask import Flask -from injection import Provide, inject -from tests.container_objects import Container +from injection import Provide, auto_inject, inject +from tests.container_objects import Container, Redis app = Flask(__name__) app.config.update({"TESTING": True}) @@ -9,13 +13,47 @@ @app.route("/some_resource") @inject -def flask_endpoint(redis=Provide[Container.redis]): +def flask_endpoint(redis: Redis = Provide[Container.redis]): + value = redis.get(-900) + return {"detail": value} + + +@app.route("/auto-inject-endpoint", methods=["POST"]) +@auto_inject +def flask_endpoint_auto_inject(redis: Redis): value = redis.get(-900) return {"detail": value} -def test_flask_endpoint(): - client = app.test_client() - response = client.get("/some_resource") +@pytest.fixture +def test_client(): + return app.test_client() + + +def test_flask_endpoint(test_client): + response = test_client.get("/some_resource") + + assert response.status_code == 200 + assert response.json == {"detail": -900} + + +def test_flask_endpoint_auto_inject(test_client): + response = test_client.post("/auto-inject-endpoint") + assert response.status_code == 200 assert response.json == {"detail": -900} + + +@pytest.mark.parametrize( + "override_value", + ["mock_get_method_110934", "blsdfmsdfsf", -345627434], +) +def test_flask_override_provider(test_client, container, override_value: Any): + mock_redis = Mock() + mock_redis.get = lambda _: override_value + + with container.override_providers_kwargs(redis=mock_redis): + response = test_client.post("auto-inject-endpoint") + + assert response.status_code == 200 + assert response.json == {"detail": override_value} diff --git a/tests/integration/test_litestar/test_integration.py b/tests/integration/test_litestar/test_integration.py index 71d2fe8..baaf6f0 100644 --- a/tests/integration/test_litestar/test_integration.py +++ b/tests/integration/test_litestar/test_integration.py @@ -1,42 +1,77 @@ -from typing import Any - -from litestar import Litestar, post +import pytest +from litestar import Litestar, get from litestar.di import Provide from litestar.testing import TestClient -from tests.container_objects import Container +from injection import inject +from tests.container_objects import Container, Redis -@post( +@get( "/some_resource", status_code=200, dependencies={"redis": Provide(Container.redis)}, ) -async def some_litestar_endpoint(redis: Any) -> dict: - value = redis.get(-924) +@inject +async def litestar_endpoint_with_direct_provider_injection(redis: Redis) -> dict: + value = redis.get(800) return {"detail": value} -app = Litestar(route_handlers=[some_litestar_endpoint], pdb_on_exception=True) +@get( + "/num_endpoint", + status_code=200, + dependencies={"num": Provide(Container.num2)}, +) +async def litestar_endpoint_object_provider(num: int) -> dict: + return {"detail": num} + + +_handlers = [ + litestar_endpoint_object_provider, + litestar_endpoint_with_direct_provider_injection, +] +app_deps = { + # "redis": Provide(Container.redis), +} -def test_litestar_endpoint(): +app = Litestar(route_handlers=_handlers, debug=True, dependencies=app_deps) + + +@pytest.mark.xfail( + reason="TypeError: __init__() got an unexpected keyword argument 'args'", +) +def test_litestar_endpoint_with_direct_provider_injection(): with TestClient(app=app) as client: - response = client.post("/some_resource") - assert response.status_code == 200 - assert response.json() == {"detail": -924} + response = client.get("/some_resource") + assert response.status_code == 200 + assert response.json() == {"detail": 800} -def test_litestar_endpoint_with_overriding(): - class RedisMock: - def get(self, _): - return 1119 - mock_redis = RedisMock() +def test_litestar_object_provider(): + with TestClient(app=app) as client: + response = client.get("/num_endpoint") + + assert response.status_code == 200 + assert response.json() == {"detail": 9402} + + +class _RedisMock: + def get(self, _): + return 192342526 + + +@pytest.mark.xfail( + reason="TypeError: Unsupported type: ", +) +def test_litestar_overriding_direct_provider_endpoint(): + mock_instance = _RedisMock() with TestClient(app=app) as client: - with Container.redis.override_context(mock_redis): - response = client.post("/some_resource") + with Container.redis.override_context(mock_instance): + response = client.get("/some_resource") assert response.status_code == 200 - assert response.json() == {"detail": 1119} + assert response.json() == {"detail": 192342526} diff --git a/tests/test_auto_inject.py b/tests/test_auto_inject.py new file mode 100644 index 0000000..09c753c --- /dev/null +++ b/tests/test_auto_inject.py @@ -0,0 +1,66 @@ +from unittest import mock + +import pytest + +from injection import DeclarativeContainer, auto_inject +from injection.inject.exceptions import DuplicatedFactoryTypeAutoInjectionError +from tests.container_objects import Redis, Service, SomeService + + +@pytest.mark.parametrize( + "subclasses", + [ + [object(), object()], + [object(), object(), object()], + ], +) +def test_auto_inject_expect_error_with_more_than_one_di_container_and_empty_target_container_param( + subclasses: list, +): + match = ( + f"Found {len(subclasses)} containers, " + f"please specify the required container explicitly in the parameter 'target_container'" + ) + + with mock.patch.object( + DeclarativeContainer, + "__subclasses__", + return_value=subclasses, + ): + with pytest.raises(Exception, match=match): + auto_inject(lambda: None) + + +@auto_inject +async def _async_func( + redis: Redis, + *, + a: int, + b: str = "asdsd", + svc: Service, + another_svc: SomeService, +): + assert a == 234 + assert b == "rnd" + assert isinstance(redis, Redis) + assert isinstance(svc, Service) + assert isinstance(another_svc, SomeService) + + +async def test_auto_inject_on_async_target(): + await _async_func(a=234, b="rnd") + + +async def test_auto_inject_expect_error_on_duplicated_provider_types(container): + _mock_providers = [container.__dict__["redis"]] + _mock_providers.extend( + list(container._get_providers_generator()), + ) + + with mock.patch.object( + container, + "_get_providers_generator", + return_value=_mock_providers, + ): + with pytest.raises(DuplicatedFactoryTypeAutoInjectionError): + await _async_func(a=234, b="rnd") diff --git a/tests/test_base_container.py b/tests/test_base_container.py index 6911b3a..703a31f 100644 --- a/tests/test_base_container.py +++ b/tests/test_base_container.py @@ -1,10 +1,13 @@ from typing import Generator +from unittest import mock from unittest.mock import Mock import pytest +from injection.inject.exceptions import DuplicatedFactoryTypeAutoInjectionError from injection.providers.base import BaseProvider from injection.providers.singleton import Singleton +from tests.container_objects import Redis def test_get_providers(container): @@ -99,3 +102,19 @@ def test_reset_override(container): assert len(container.num2._mocks) == 0 assert container.num() == original_num_value assert container.num2() == original_num2_value + + +def test_resolve_by_type_expect_error_on_duplicated_provider_types(container): + # Simulate a duplicate 'redis' provider + _mock_providers = [container.__dict__["redis"]] + _mock_providers.extend( + list(container._get_providers_generator()), + ) + + with mock.patch.object( + container, + "_get_providers_generator", + return_value=_mock_providers, + ): + with pytest.raises(DuplicatedFactoryTypeAutoInjectionError): + container.resolve_by_type(Redis) diff --git a/tests/test_e2e.py b/tests/test_e2e.py index bacd5d1..5d53d23 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -2,7 +2,11 @@ from dataclasses import dataclass from unittest.mock import Mock -from tests.container_objects import Settings, func_with_injections +from tests.container_objects import ( + Settings, + func_with_auto_injections, + func_with_injections, +) def test_e2e_success(container): @@ -21,14 +25,30 @@ def test_e2e_success(container): assert container.service() is not service assert container.num() == 144 - callable_result = container.callable_obj(d="sdf") - partial_callable_result = container.partial_callable(d="sdf")(d="sdfsfwer2") - assert (5555, 144) == callable_result == partial_callable_result - coroutine_result = asyncio.run(container.coroutine_provider()) assert coroutine_result == (1, 2) +def test_e2e_auto_inject_success(container): + redis = container.redis() + assert redis.url == func_with_auto_injections(2, ddd="sfs") + + class MockRedis: + def __init__(self): + self.url = "mock_redis_url" + + def get(self, _): ... + + mock_redis = MockRedis() + + with container.override_providers_kwargs(redis=MockRedis()): + assert mock_redis.url == func_with_auto_injections(224324, ddd="sdfsdfsdf") + + redis = container.redis() + assert redis.url != mock_redis.url + assert redis.url == func_with_auto_injections(2, ddd="sfs") + + def test_e2e_override(container): assert container.redis().url == container.settings().redis_url diff --git a/tests/test_providers/test_partial_callable.py b/tests/test_providers/test_partial_callable.py deleted file mode 100644 index 752dd04..0000000 --- a/tests/test_providers/test_partial_callable.py +++ /dev/null @@ -1,32 +0,0 @@ -import functools - -import pytest - -from injection import providers - - -def some_function(a: str, b: int, *, c: str): - return a, b, c - - -def test_partial_callable_resolving_fail_without_arg(): - kwargs = {"a": "aa", "b": 22} - callable_provider = providers.PartialCallable(some_function, **kwargs) - - with pytest.raises(TypeError) as e: - callable_provider()() - - match = "some_function() missing 1 required keyword-only argument: 'c'" - assert e.value.args[0] == match - - -def test_partial_callable_resolving(): - args = ("aa", 22) - kwargs = {"c": "value"} - provider = providers.PartialCallable(some_function, *args) - resolved_callable = provider(**kwargs) - - assert isinstance(resolved_callable, functools.partial) - assert resolved_callable.args == args - assert resolved_callable.keywords == kwargs - assert tuple([*args, "value"]) == resolved_callable()