Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for typing.Annotated #721

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion examples/wiring/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject
from typing import Annotated


class Service:
Expand All @@ -12,12 +13,18 @@ class Container(containers.DeclarativeContainer):

service = providers.Factory(Service)


# You can place marker on parameter default value
@inject
def main(service: Service = Provide[Container.service]) -> None:
...


# Also, you can place marker with typing.Annotated
@inject
def main_with_annotated(service: Annotated[Service, Provide[Container.service]]) -> None:
...


if __name__ == "__main__":
container = Container()
container.wire(modules=[__name__])
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ numpy
scipy
boto3
mypy_boto3_s3
typing_extensions

-r requirements-ext.txt
45 changes: 35 additions & 10 deletions src/dependency_injector/wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ class GenericMeta(type):
else:
GenericAlias = None

if sys.version_info >= (3, 9):
from typing import Annotated, get_args, get_origin
else:
try:
from typing_extensions import Annotated, get_args, get_origin
except ImportError:
Annotated = object()

# For preventing NameError. Never executes
def get_args(hint):
return ()

def get_origin(tp):
return None

try:
import fastapi.params
Expand Down Expand Up @@ -548,6 +562,24 @@ def _unpatch_attribute(patched: PatchedAttribute) -> None:
setattr(patched.member, patched.name, patched.marker)


def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]:
if get_origin(parameter.annotation) is Annotated:
marker = get_args(parameter.annotation)[1]
else:
marker = parameter.default

if not isinstance(marker, _Marker) and not _is_fastapi_depends(marker):
return None

if _is_fastapi_depends(marker):
marker = marker.dependency

if not isinstance(marker, _Marker):
return None

return marker


def _fetch_reference_injections( # noqa: C901
fn: Callable[..., Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
Expand All @@ -573,17 +605,10 @@ def _fetch_reference_injections( # noqa: C901
injections = {}
closing = {}
for parameter_name, parameter in signature.parameters.items():
if not isinstance(parameter.default, _Marker) \
and not _is_fastapi_depends(parameter.default):
continue
marker = _extract_marker(parameter)

marker = parameter.default

if _is_fastapi_depends(marker):
marker = marker.dependency

if not isinstance(marker, _Marker):
continue
if marker is None:
continue

if isinstance(marker, Closing):
marker = marker.provider
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/samples/wiringfastapi/web.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sys

from typing_extensions import Annotated

from fastapi import FastAPI, Depends
from fastapi import Request # See: https://github.com/ets-labs/python-dependency-injector/issues/398
from fastapi.security import HTTPBasic, HTTPBasicCredentials
Expand Down Expand Up @@ -27,6 +29,11 @@ async def index(service: Service = Depends(Provide[Container.service])):
result = await service.process()
return {"result": result}

@app.api_route('/annotated')
@inject
async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]):
result = await service.process()
return {'result': result}

@app.get("/auth")
@inject
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/samples/wiringflask/web.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing_extensions import Annotated

from flask import Flask, jsonify, request, current_app, session, g
from flask import _request_ctx_stack, _app_ctx_stack
from dependency_injector import containers, providers
Expand Down Expand Up @@ -28,5 +30,12 @@ def index(service: Service = Provide[Container.service]):
return jsonify({"result": result})


@app.route("/annotated")
@inject
def annotated(service: Annotated[Service, Provide[Container.service]]):
result = service.process()
return jsonify({'result': result})


container = Container()
container.wire(modules=[__name__])
14 changes: 14 additions & 0 deletions tests/unit/wiring/test_fastapi_py36.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ async def process(self):
assert response.json() == {"result": "Foo"}


@mark.asyncio
async def test_depends_with_annotated(async_client: AsyncClient):
class ServiceMock:
async def process(self):
return "Foo"

with web.container.service.override(ServiceMock()):
response = await async_client.get("/")

assert response.status_code == 200
assert response.json() == {"result": "Foo"}



@mark.asyncio
async def test_depends_injection(async_client: AsyncClient):
response = await async_client.get("/auth", auth=("john_smith", "secret"))
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/wiring/test_flask_py36.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,13 @@ def test_wiring_with_flask():

assert response.status_code == 200
assert json.loads(response.data) == {"result": "OK"}


def test_wiring_with_annotated():
client = web.app.test_client()

with web.app.app_context():
response = client.get("/annotated")

assert response.status_code == 200
assert json.loads(response.data) == {"result": "OK"}