From 4d4da85a36e7b3d8fbe9ea52aecec0d8d7c1a968 Mon Sep 17 00:00:00 2001 From: Ivan Belyaev Date: Mon, 28 Oct 2024 00:00:32 +0300 Subject: [PATCH] Make injecting more simpler (#12) --- src/injection/inject.py | 105 ++++++++---------- tests/integration/test_fastapi/handlers.py | 9 +- .../test_fastapi/test_integration.py | 27 ++++- tests/test_inject.py | 22 ---- tests/test_provide.py | 2 +- 5 files changed, 82 insertions(+), 83 deletions(-) delete mode 100644 tests/test_inject.py diff --git a/src/injection/inject.py b/src/injection/inject.py index 39f009a..0f6aee6 100644 --- a/src/injection/inject.py +++ b/src/injection/inject.py @@ -1,49 +1,44 @@ import inspect import sys from functools import wraps -from typing import Any, Callable, Coroutine, Dict, TypeVar, Union, cast +from typing import Any, Callable, Coroutine, Dict, TypeVar, cast from injection.provide import Provide -from injection.providers.base import BaseProvider if sys.version_info < (3, 10): from typing_extensions import ParamSpec else: from typing import ParamSpec -if sys.version_info >= (3, 9): - from typing import Annotated, get_args, get_origin -else: - from typing_extensions import Annotated, get_args, get_origin T = TypeVar("T") P = ParamSpec("P") Markers = Dict[str, Provide] -def _is_fastapi_depends(param: Any) -> bool: - try: - import fastapi - except ImportError: - fastapi = None # type: ignore - return fastapi is not None and isinstance(param, fastapi.params.Depends) - - -def _extract_marker(parameter: inspect.Parameter) -> Union[Any, Provide]: - marker = parameter.default - - parameter_origin = get_origin(parameter.annotation) - is_annotated = parameter_origin is Annotated +# def _is_fastapi_depends(param: Any) -> bool: +# try: +# import fastapi +# except ImportError: +# fastapi = None # type: ignore +# return fastapi is not None and isinstance(param, fastapi.params.Depends) - if is_annotated: - marker = get_args(parameter.annotation)[1] - is_fastapi_depends = _is_fastapi_depends(marker) - - if is_fastapi_depends: - marker = marker.dependency - - return marker +# def _extract_marker(parameter: inspect.Parameter) -> Union[Any, Provide]: +# marker = parameter.default +# +# parameter_origin = get_origin(parameter.annotation) +# is_annotated = parameter_origin is Annotated +# +# if is_annotated: +# marker = get_args(parameter.annotation)[1] +# +# is_fastapi_depends = _is_fastapi_depends(marker) +# +# if is_fastapi_depends: +# marker = marker.dependency +# +# return marker def _get_markers_from_function(f: Callable[P, T]) -> Markers: @@ -52,36 +47,17 @@ def _get_markers_from_function(f: Callable[P, T]) -> Markers: parameters = signature.parameters for parameter_name, parameter_value in parameters.items(): - marker = _extract_marker(parameter_value) - - if not isinstance(marker, Provide): - continue - - injections[parameter_name] = marker + if isinstance(parameter_value.default, Provide): + injections[parameter_name] = parameter_value.default return injections -def _resolve_provide_marker(marker: Provide) -> BaseProvider[Any]: - if not isinstance(marker, Provide): - msg = f"Incorrect marker type: {type(marker)!r}. Marker must be either Provide." - raise TypeError(msg) - - marker_provider = marker.provider - - if not isinstance(marker_provider, BaseProvider): - msg = f"Incorrect marker type: {type(marker_provider)!r}. Marker parameter must be either BaseProvider." - raise TypeError(msg) - - return marker_provider - - -def _extract_provider_values_from_markers(markers: Markers) -> Dict[str, Any]: +def _resolve_markers(markers: Markers) -> Dict[str, Any]: providers = {} - for param, marker_or_str in markers.items(): - provider = _resolve_provide_marker(marker_or_str) - provider_value = provider() + for param, provide in markers.items(): + provider_value = provide.provider() providers[param] = provider_value return providers @@ -93,18 +69,35 @@ def _get_async_injected( ) -> Callable[P, Coroutine[Any, Any, T]]: @wraps(f) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - providers = _extract_provider_values_from_markers(markers) - kwargs.update(providers) + kwargs_for_resolve = { + k: v + for k, v in kwargs.items() + if k not in markers and isinstance(v, Provide) + } + markers.update(kwargs_for_resolve) + + kwarg_values = _resolve_markers(markers) + kwargs.update(kwarg_values) return await f(*args, **kwargs) return wrapper -def _get_sync_injected(f: Callable[P, T], markers: Markers) -> Callable[P, T]: +def _get_sync_injected( + f: Callable[P, T], + markers: Markers, +) -> Callable[P, T]: @wraps(f) def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - providers = _extract_provider_values_from_markers(markers) - kwargs.update(providers) + kwargs_for_resolve = { + k: v + for k, v in kwargs.items() + if k not in markers and isinstance(v, Provide) + } + markers.update(kwargs_for_resolve) + + kwarg_values = _resolve_markers(markers) + kwargs.update(kwarg_values) return f(*args, **kwargs) return wrapper diff --git a/tests/integration/test_fastapi/handlers.py b/tests/integration/test_fastapi/handlers.py index b2558d9..32bed8e 100644 --- a/tests/integration/test_fastapi/handlers.py +++ b/tests/integration/test_fastapi/handlers.py @@ -13,7 +13,7 @@ router = APIRouter(prefix="/api") RedisDependency = Annotated[Redis, Depends(Provide[Container.redis])] -RedisDependencyWithOnlyProvider = Annotated[Redis, Depends(Container.redis)] +RedisDependencyWithoutProvideMarker = Annotated[Redis, Depends(Container.redis)] @router.get("/values") @@ -28,3 +28,10 @@ def some_get_endpoint_handler(redis: RedisDependency): async def some_get_async_endpoint_handler(redis: RedisDependency): value = redis.get(399) return {"detail": value} + + +@router.post("/values-without_provide") +@inject +async def async_endpoint_handler_with_dep_without_provide( + _: RedisDependencyWithoutProvideMarker, +): ... diff --git a/tests/integration/test_fastapi/test_integration.py b/tests/integration/test_fastapi/test_integration.py index ecdf6e6..c81c042 100644 --- a/tests/integration/test_fastapi/test_integration.py +++ b/tests/integration/test_fastapi/test_integration.py @@ -17,21 +17,42 @@ def test_client(app): return client -def test_sync_endpoint_with_str_provide_marker(test_client): +def test_sync_fastapi_endpoint(test_client): response = test_client.get("/api/values") + assert response.status_code == 200 body = response.json() assert body["detail"] == 299 -def test_async_endpoint_with_explicit_provide_marker(test_client): +def test_async_fastapi_endpoint(test_client): response = test_client.post("/api/values") + assert response.status_code == 200 body = response.json() assert body["detail"] == 399 -def test_overriden_pro(test_client, container): +def test_async_fastapi_endpoint_expect_422_without_provide_marker(test_client): + response = test_client.post("/api/values-without_provide") + + assert response.status_code == 422 + body = response.json() + assert body["detail"][0] == { + "input": None, + "loc": ["query", "args"], + "msg": "Field required", + "type": "missing", + } + assert body["detail"][1] == { + "input": None, + "loc": ["query", "kwargs"], + "msg": "Field required", + "type": "missing", + } + + +def test_fastapi_override_provider(test_client, container): def mock_get_method(_): return "mock_get_method" diff --git a/tests/test_inject.py b/tests/test_inject.py deleted file mode 100644 index 125f480..0000000 --- a/tests/test_inject.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - -from injection import Provide -from injection.inject import _resolve_provide_marker - - -def test_resolve_provide_marker_fail_when_marker_is_not_provide_type(container): - with pytest.raises(Exception) as e: - _resolve_provide_marker(container.redis) - - assert ( - e.value.args[0] - == f"Incorrect marker type: {type(container.redis)!r}. Marker must be either Provide." - ) - - -def test_resolve_provide_marker_fail_when_marker_parameter_has_incorrect_type(): - with pytest.raises(Exception) as e: - _resolve_provide_marker(Provide[object]) - - error_msg = f"Incorrect marker type: {type(object)!r}. Marker parameter must be either BaseProvider." - assert e.value.args[0] == error_msg diff --git a/tests/test_provide.py b/tests/test_provide.py index 83630ae..7c433de 100644 --- a/tests/test_provide.py +++ b/tests/test_provide.py @@ -6,5 +6,5 @@ @pytest.mark.parametrize("value", [(1,), (object,)]) def test_provide_returns_self(value): provide = Provide[value] - assert provide is provide() + assert isinstance(provide, Provide) assert provide.provider is value