Skip to content

Commit

Permalink
Auto injection feature (#13)
Browse files Browse the repository at this point in the history
* auto inject implementation

* remove unused code

* fix bug with call with args in Object provider

* fix litestar tests

---------

Co-authored-by: ivan
  • Loading branch information
nightblure authored Nov 17, 2024
1 parent 47ed678 commit 689db10
Show file tree
Hide file tree
Showing 28 changed files with 541 additions and 243 deletions.
41 changes: 18 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@

Easy dependency injection for all, works with Python 3.8-3.12. Main features and advantages:
* support **Python 3.8-3.12**;
* works with **FastAPI, Flask, Litestar** and **Django REST Framework**;
* works with **FastAPI, Flask** and **Django REST Framework**;
* support dependency injection via `Annotated` in `FastAPI`;
* the code is fully typed and checked with [mypy](https://github.com/python/mypy);
* **no third-party dependencies**;
* **multiple containers**;
* **overriding** dependencies for tests without wiring;
* **100%** code coverage and very simple code;
* no wiring;
* the life cycle of objects (**scope**) is implemented by providers;
* **overriding** dependencies for testing;
* **100%** code coverage;
* good [documentation](https://injection.readthedocs.io/latest/);
* intuitive and almost identical api with [dependency-injector](https://github.com/ets-labs/python-dependency-injector),
which will allow you to easily migrate to injection
Expand All @@ -45,15 +46,18 @@ which will allow you to easily migrate to injection
pip install deps-injection
```

## Using example
## Compatibility between web frameworks and injection features
| Framework | Dependency injection with @inject | Dependency injection with @autoinject (_experimental_) | Overriding providers |
|--------------------------------------------------------------------------|:---------------------------------:|:------------------------------------------------------:|:--------------------:|
| [FastAPI](https://github.com/fastapi/fastapi) ||||
| [Flask](https://github.com/pallets/flask) ||||
| [Django REST Framework](https://github.com/encode/django-rest-framework) ||||
| [Litestar](https://github.com/litestar-org/litestar) ||||

```python3
import sys

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing import Annotated
## Using example with FastAPI
```python3
from typing import Annotated
from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -103,16 +107,10 @@ RedisDependency = Annotated[Redis, Depends(Provide[Container.redis])]
def some_get_endpoint_handler(redis: RedisDependency):
value = redis.get(299)
return {"detail": value}
```


@router.post("/values")
@inject
async def some_get_async_endpoint_handler(redis: RedisDependency):
value = redis.get(399)
return {"detail": value}


###################### TESTING ######################
## Testing example with overriding providers for above FastAPI example
```python3
@pytest.fixture(scope="session")
def app():
return create_app()
Expand Down Expand Up @@ -144,7 +142,4 @@ def test_override_providers(test_client, container):
assert response.status_code == 200
body = response.json()
assert body["detail"] == "mock_get_method"

```

---
2 changes: 0 additions & 2 deletions docs/integration-with-web-frameworks/litestart.md

This file was deleted.

25 changes: 0 additions & 25 deletions docs/providers/partial_callable.md

This file was deleted.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ run-cov = "coverage run -m pytest{env:HATCH_TEST_ARGS:} {args}"
cov-combine = "coverage combine"
cov-report = [
"coverage xml",
"coverage report"
"coverage report -m"
]

[[tool.hatch.envs.hatch-test.matrix]]
Expand Down
3 changes: 2 additions & 1 deletion src/injection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from injection import providers
from injection.__version__ import __version__
from injection.base_container import DeclarativeContainer
from injection.inject import inject
from injection.inject.auto_inject import auto_inject
from injection.inject.inject import inject
from injection.provide import Provide
28 changes: 27 additions & 1 deletion src/injection/base_container.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import inspect
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar, cast

from injection.inject.exceptions import DuplicatedFactoryTypeAutoInjectionError
from injection.providers import Singleton
from injection.providers.base import BaseProvider
from injection.providers.base_factory import BaseFactoryProvider

F = TypeVar("F", bound=Callable[..., Any])

Expand Down Expand Up @@ -101,3 +104,26 @@ def reset_override(cls) -> None:

for provider in providers.values():
provider.reset_override()

@classmethod
def resolve_by_type(cls, type_: Type[Any]) -> Any:
provider_factory_to_providers = defaultdict(list)

for provider in cls._get_providers_generator():
if not issubclass(type(provider), BaseFactoryProvider):
continue

provider_factory_to_providers[provider.factory].append(provider) # type: ignore

if len(provider_factory_to_providers[provider.factory]) > 1: # type: ignore
raise DuplicatedFactoryTypeAutoInjectionError(str(type_))

for providers in provider_factory_to_providers.values():
provider = providers[0]
provider = cast(BaseFactoryProvider[Any], provider)

if type_ is provider.factory:
return provider()

msg = f"Provider with type {type_!s} not found"
raise Exception(msg)
Empty file.
151 changes: 151 additions & 0 deletions src/injection/inject/auto_inject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import inspect
import sys
from functools import wraps
from typing import Any, Callable, Coroutine, Dict, Optional, Type, TypeVar, Union, cast

from injection.base_container import DeclarativeContainer
from injection.inject.exceptions import DuplicatedFactoryTypeAutoInjectionError
from injection.inject.inject import _resolve_markers
from injection.provide import Provide

if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
else:
from typing import ParamSpec

T = TypeVar("T")
P = ParamSpec("P")
Markers = Dict[str, Provide]
_ContainerType = Union[Type[DeclarativeContainer], DeclarativeContainer]


def _resolve_signature_args_with_types_from_container(
*,
signature: inspect.Signature,
target_container: _ContainerType,
) -> Dict[str, Any]:
resolved_signature_typed_args = {}

for param_name, param in signature.parameters.items():
if not (param.annotation is not param.empty and param.default is param.empty):
continue

try:
resolved = target_container.resolve_by_type(param.annotation)
resolved_signature_typed_args[param_name] = resolved
except DuplicatedFactoryTypeAutoInjectionError:
raise

# Ignore exceptions for cases for example django rest framework
# endpoint may have parameter 'request' - we don't know how to handle a variety of parameters.
# But anyway, after this the runtime will fail with an error if something goes wrong
except Exception: # noqa: S112
continue

return resolved_signature_typed_args


def _get_sync_injected(
*,
f: Callable[P, T],
markers: Markers,
signature: inspect.Signature,
target_container: _ContainerType,
) -> Callable[P, T]:
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
resolved_signature_typed_args = (
_resolve_signature_args_with_types_from_container(
signature=signature,
target_container=target_container,
)
)

provide_markers = {
k: v
for k, v in kwargs.items()
if k not in markers and isinstance(v, Provide)
}
provide_markers.update(markers)
resolved_values = _resolve_markers(provide_markers)

kwargs.update(resolved_values)
kwargs.update(resolved_signature_typed_args)
return f(*args, **kwargs)

return wrapper


def _get_async_injected(
*,
f: Callable[P, Coroutine[Any, Any, T]],
markers: Markers,
signature: inspect.Signature,
target_container: _ContainerType,
) -> Callable[P, Coroutine[Any, Any, T]]:
@wraps(f)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
resolved_signature_typed_args = (
_resolve_signature_args_with_types_from_container(
signature=signature,
target_container=target_container,
)
)

provide_markers = {
k: v
for k, v in kwargs.items()
if k not in markers and isinstance(v, Provide)
}
provide_markers.update(markers)
resolved_values = _resolve_markers(provide_markers)

kwargs.update(resolved_values)
kwargs.update(resolved_signature_typed_args)
return await f(*args, **kwargs)

return wrapper


def auto_inject(
f: Callable[P, T],
target_container: Optional[_ContainerType] = None,
) -> Callable[P, T]:
"""Decorate callable with injecting decorator. Inject objects by types"""

if target_container is None:
container_subclasses = DeclarativeContainer.__subclasses__()

if len(container_subclasses) > 1:
msg = (
f"Found {len(container_subclasses)} containers, please specify "
f"the required container explicitly in the parameter 'target_container'"
)
raise Exception(msg)

target_container = container_subclasses[0]

signature = inspect.signature(f)
parameters = signature.parameters

markers = {
parameter_name: parameter_value.default
for parameter_name, parameter_value in parameters.items()
if isinstance(parameter_value.default, Provide)
}

if inspect.iscoroutinefunction(f):
func_with_injected_params = _get_async_injected(
f=f,
markers=markers,
signature=signature,
target_container=target_container,
)
return cast(Callable[P, T], func_with_injected_params)

return _get_sync_injected(
f=f,
markers=markers,
signature=signature,
target_container=target_container,
)
7 changes: 7 additions & 0 deletions src/injection/inject/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class DuplicatedFactoryTypeAutoInjectionError(Exception):
def __init__(self, type_: str) -> None:
message = (
f"Cannot resolve auto inject because found "
f"more than one provider for type '{type_}'"
)
super().__init__(message)
Loading

0 comments on commit 689db10

Please sign in to comment.