Skip to content

Commit

Permalink
Make injecting more simpler (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
nightblure authored Oct 27, 2024
1 parent b41f880 commit 4d4da85
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 83 deletions.
105 changes: 49 additions & 56 deletions src/injection/inject.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,44 @@
import inspect
import sys
from functools import wraps
from typing import Any, Callable, Coroutine, Dict, TypeVar, Union, cast
from typing import Any, Callable, Coroutine, Dict, TypeVar, cast

from injection.provide import Provide
from injection.providers.base import BaseProvider

if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
else:
from typing import ParamSpec

if sys.version_info >= (3, 9):
from typing import Annotated, get_args, get_origin
else:
from typing_extensions import Annotated, get_args, get_origin

T = TypeVar("T")
P = ParamSpec("P")
Markers = Dict[str, Provide]


def _is_fastapi_depends(param: Any) -> bool:
try:
import fastapi
except ImportError:
fastapi = None # type: ignore
return fastapi is not None and isinstance(param, fastapi.params.Depends)


def _extract_marker(parameter: inspect.Parameter) -> Union[Any, Provide]:
marker = parameter.default

parameter_origin = get_origin(parameter.annotation)
is_annotated = parameter_origin is Annotated
# def _is_fastapi_depends(param: Any) -> bool:
# try:
# import fastapi
# except ImportError:
# fastapi = None # type: ignore
# return fastapi is not None and isinstance(param, fastapi.params.Depends)

if is_annotated:
marker = get_args(parameter.annotation)[1]

is_fastapi_depends = _is_fastapi_depends(marker)

if is_fastapi_depends:
marker = marker.dependency

return marker
# def _extract_marker(parameter: inspect.Parameter) -> Union[Any, Provide]:
# marker = parameter.default
#
# parameter_origin = get_origin(parameter.annotation)
# is_annotated = parameter_origin is Annotated
#
# if is_annotated:
# marker = get_args(parameter.annotation)[1]
#
# is_fastapi_depends = _is_fastapi_depends(marker)
#
# if is_fastapi_depends:
# marker = marker.dependency
#
# return marker


def _get_markers_from_function(f: Callable[P, T]) -> Markers:
Expand All @@ -52,36 +47,17 @@ def _get_markers_from_function(f: Callable[P, T]) -> Markers:
parameters = signature.parameters

for parameter_name, parameter_value in parameters.items():
marker = _extract_marker(parameter_value)

if not isinstance(marker, Provide):
continue

injections[parameter_name] = marker
if isinstance(parameter_value.default, Provide):
injections[parameter_name] = parameter_value.default

return injections


def _resolve_provide_marker(marker: Provide) -> BaseProvider[Any]:
if not isinstance(marker, Provide):
msg = f"Incorrect marker type: {type(marker)!r}. Marker must be either Provide."
raise TypeError(msg)

marker_provider = marker.provider

if not isinstance(marker_provider, BaseProvider):
msg = f"Incorrect marker type: {type(marker_provider)!r}. Marker parameter must be either BaseProvider."
raise TypeError(msg)

return marker_provider


def _extract_provider_values_from_markers(markers: Markers) -> Dict[str, Any]:
def _resolve_markers(markers: Markers) -> Dict[str, Any]:
providers = {}

for param, marker_or_str in markers.items():
provider = _resolve_provide_marker(marker_or_str)
provider_value = provider()
for param, provide in markers.items():
provider_value = provide.provider()
providers[param] = provider_value

return providers
Expand All @@ -93,18 +69,35 @@ def _get_async_injected(
) -> Callable[P, Coroutine[Any, Any, T]]:
@wraps(f)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
providers = _extract_provider_values_from_markers(markers)
kwargs.update(providers)
kwargs_for_resolve = {
k: v
for k, v in kwargs.items()
if k not in markers and isinstance(v, Provide)
}
markers.update(kwargs_for_resolve)

kwarg_values = _resolve_markers(markers)
kwargs.update(kwarg_values)
return await f(*args, **kwargs)

return wrapper


def _get_sync_injected(f: Callable[P, T], markers: Markers) -> Callable[P, T]:
def _get_sync_injected(
f: Callable[P, T],
markers: Markers,
) -> Callable[P, T]:
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
providers = _extract_provider_values_from_markers(markers)
kwargs.update(providers)
kwargs_for_resolve = {
k: v
for k, v in kwargs.items()
if k not in markers and isinstance(v, Provide)
}
markers.update(kwargs_for_resolve)

kwarg_values = _resolve_markers(markers)
kwargs.update(kwarg_values)
return f(*args, **kwargs)

return wrapper
Expand Down
9 changes: 8 additions & 1 deletion tests/integration/test_fastapi/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
router = APIRouter(prefix="/api")

RedisDependency = Annotated[Redis, Depends(Provide[Container.redis])]
RedisDependencyWithOnlyProvider = Annotated[Redis, Depends(Container.redis)]
RedisDependencyWithoutProvideMarker = Annotated[Redis, Depends(Container.redis)]


@router.get("/values")
Expand All @@ -28,3 +28,10 @@ def some_get_endpoint_handler(redis: RedisDependency):
async def some_get_async_endpoint_handler(redis: RedisDependency):
value = redis.get(399)
return {"detail": value}


@router.post("/values-without_provide")
@inject
async def async_endpoint_handler_with_dep_without_provide(
_: RedisDependencyWithoutProvideMarker,
): ...
27 changes: 24 additions & 3 deletions tests/integration/test_fastapi/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,42 @@ def test_client(app):
return client


def test_sync_endpoint_with_str_provide_marker(test_client):
def test_sync_fastapi_endpoint(test_client):
response = test_client.get("/api/values")

assert response.status_code == 200
body = response.json()
assert body["detail"] == 299


def test_async_endpoint_with_explicit_provide_marker(test_client):
def test_async_fastapi_endpoint(test_client):
response = test_client.post("/api/values")

assert response.status_code == 200
body = response.json()
assert body["detail"] == 399


def test_overriden_pro(test_client, container):
def test_async_fastapi_endpoint_expect_422_without_provide_marker(test_client):
response = test_client.post("/api/values-without_provide")

assert response.status_code == 422
body = response.json()
assert body["detail"][0] == {
"input": None,
"loc": ["query", "args"],
"msg": "Field required",
"type": "missing",
}
assert body["detail"][1] == {
"input": None,
"loc": ["query", "kwargs"],
"msg": "Field required",
"type": "missing",
}


def test_fastapi_override_provider(test_client, container):
def mock_get_method(_):
return "mock_get_method"

Expand Down
22 changes: 0 additions & 22 deletions tests/test_inject.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/test_provide.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
@pytest.mark.parametrize("value", [(1,), (object,)])
def test_provide_returns_self(value):
provide = Provide[value]
assert provide is provide()
assert isinstance(provide, Provide)
assert provide.provider is value

0 comments on commit 4d4da85

Please sign in to comment.