Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added resolve_all to container #114

Open
wants to merge 18 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ where = ["src"]

[project]
name = "dishka"
version = "0.1"
version = "0.8.0"
readme = "README.md"
authors = [
{ name = "Andrey Tikhonov", email = "[email protected]" },
Expand Down
43 changes: 42 additions & 1 deletion src/dishka/async_container.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from asyncio import Lock
from collections.abc import Callable
from typing import Any, Optional, TypeVar
from contextlib import suppress
from typing import Any, Iterable, Literal, Optional, TypeVar, overload

from dishka.entities.component import DEFAULT_COMPONENT, Component
from dishka.entities.key import DependencyKey
Expand All @@ -9,6 +10,7 @@
from .dependency_source import FactoryType
from .exceptions import (
ExitError,
NoContextValueError,
NoFactoryError,
)
from .provider import BaseProvider
Expand Down Expand Up @@ -109,6 +111,45 @@ async def get(
async with lock:
return await self._get_unlocked(key)

@overload
async def resolve_all(self, components: None = None) -> None: ...
@overload
async def resolve_all(self, components: Literal[True]) -> None: ...
@overload
async def resolve_all(self, components: Iterable[Component]) -> None: ...

async def resolve_all(self, components: Any = None) -> None:
"""
Resolve all container dependencies in the current scope for the given
components.

Examples:
>>> container.resolve_all()
Resolve all dependencies for the default component.

>>> container.resolve_all(True)
Resolve all dependencies for all components.

>>> container.resolve_all(['component1', 'component2'])
Resolve dependencies for 'component1' and 'component2'.
"""
if not components:

def component_check(k: DependencyKey) -> bool:
return k.component == DEFAULT_COMPONENT
elif components is True:

def component_check(k: DependencyKey) -> bool:
return True
else:

def component_check(k: DependencyKey) -> bool:
return k.component in components

for key in filter(component_check, self.registry.factories):
with suppress(NoContextValueError):
await self._get_unlocked(key)

async def _get_unlocked(self, key: DependencyKey) -> Any:
if key in self.context:
return self.context[key]
Expand Down
43 changes: 42 additions & 1 deletion src/dishka/container.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Callable
from contextlib import suppress
from threading import Lock
from typing import Any, Optional, TypeVar
from typing import Any, Iterable, Literal, Optional, TypeVar, overload

from dishka.entities.component import DEFAULT_COMPONENT, Component
from dishka.entities.key import DependencyKey
Expand All @@ -9,6 +10,7 @@
from .dependency_source import FactoryType
from .exceptions import (
ExitError,
NoContextValueError,
NoFactoryError,
)
from .provider import BaseProvider
Expand Down Expand Up @@ -107,6 +109,45 @@ def get(
with lock:
return self._get_unlocked(key)

@overload
def resolve_all(self, components: None = None) -> None: ...
@overload
def resolve_all(self, components: Literal[True]) -> None: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the purpose of that?

@overload
def resolve_all(self, components: Iterable[Component]) -> None: ...

def resolve_all(self, components: Any = None) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need specify components? What's the use case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need to call parent container method as well?

"""
Resolve all container dependencies in the current scope for the given
components.

Examples:
>>> container.resolve_all()
Resolve all dependencies for the default component.

>>> container.resolve_all(True)
Resolve all dependencies for all components.

>>> container.resolve_all(['component1', 'component2'])
Resolve dependencies for 'component1' and 'component2'.
"""
if not components:

def component_check(k: DependencyKey) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Get rid of nested functions

return k.component == DEFAULT_COMPONENT
elif components is True:

def component_check(k: DependencyKey) -> bool:
return True
else:

def component_check(k: DependencyKey) -> bool:
return k.component in components

for key in filter(component_check, self.registry.factories):
with suppress(NoContextValueError):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

self._get_unlocked(key)

def _get_unlocked(self, key: DependencyKey) -> Any:
if key in self.context:
return self.context[key]
Expand Down
71 changes: 70 additions & 1 deletion tests/unit/container/test_components.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated
from typing import Annotated, Literal

import pytest

Expand Down Expand Up @@ -55,6 +55,24 @@ def foo(self, a: Annotated[int, FromComponent()]) -> float:
return a + 1


class YProvider(Provider):
scope = Scope.APP
component = "Y"

@provide
def foo(self) -> float:
return 42


class ZProvider(Provider):
scope = Scope.APP
component = "Z"

@provide
def foo(self) -> bool:
return True


def test_from_component():
container = make_container(MainProvider(20), XProvider())
assert container.get(complex) == 210
Expand All @@ -63,6 +81,31 @@ def test_from_component():
container.get(float)


@pytest.mark.parametrize(
("component", "expected_count"),
[
(None, 4),
(("",), 4),
(True, 6),
(("X",), 3),
(("X", ""), 4),
(("X", "Y"), 4),
(("X", "Y", ""), 5),
(("X", "Y", "Z"), 5),
(("X", "Y", "Z", ""), 6),
],
)
def test_from_component_resolve_all(
component: Literal[True] | tuple[Component] | None, expected_count: int
):
container = make_container(
MainProvider(20), XProvider(), YProvider(), ZProvider()
)
assert len(container.context) == 1
container.resolve_all(component)
assert len(container.context) == expected_count


@pytest.mark.asyncio()
async def test_from_component_async():
container = make_async_container(MainProvider(20), XProvider())
Expand All @@ -72,6 +115,32 @@ async def test_from_component_async():
await container.get(float)


@pytest.mark.parametrize(
("component", "expected_count"),
[
(None, 4),
(("",), 4),
(True, 6),
(("X",), 3),
(("X", ""), 4),
(("X", "Y"), 4),
(("X", "Y", ""), 5),
(("X", "Y", "Z"), 5),
(("X", "Y", "Z", ""), 6),
],
)
@pytest.mark.asyncio
async def test_from_component_resolve_all_async(
component: Literal[True] | tuple[Component] | None, expected_count: int
):
container = make_async_container(
MainProvider(20), XProvider(), YProvider(), ZProvider()
)
assert len(container.context) == 1
await container.resolve_all(component)
assert len(container.context) == expected_count


class SingleProvider(Provider):
scope = Scope.APP

Expand Down
50 changes: 50 additions & 0 deletions tests/unit/container/test_context_vars.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import pytest

from dishka import (
Expand All @@ -9,6 +11,7 @@
)
from dishka.dependency_source import from_context
from dishka.exceptions import NoContextValueError
from ..sample_providers import ClassA


def test_simple():
Expand All @@ -18,6 +21,35 @@ def test_simple():
assert container.get(int) == 1


class AProvider(Provider):
scope = Scope.APP
a = from_context(provides=int)
b = from_context(provides=str)

@provide
def foo(self, a: int) -> ClassA:
return ClassA(a)

@provide
def bar(self, a: str) -> bool:
return bool(a)


@pytest.mark.parametrize(
("context", "expected_count"),
[
({}, 1),
({int: 1}, 3),
({int: 1, str: "1"}, 5),
],
)
def test_simple_resolve_all(context: dict[type, Any], expected_count: int):
provider = AProvider()
container = make_container(provider, context=context)
container.resolve_all()
assert len(container.context) == expected_count


@pytest.mark.asyncio
async def test_simple_async():
provider = Provider()
Expand All @@ -26,6 +58,24 @@ async def test_simple_async():
assert await container.get(int) == 1


@pytest.mark.parametrize(
("context", "expected_count"),
[
({}, 1),
({int: 1}, 3),
({int: 1, str: "1"}, 5),
],
)
@pytest.mark.asyncio
async def test_simple_resolve_all_async(
context: dict[type, Any], expected_count: int
):
provider = AProvider()
container = make_async_container(provider, context=context)
await container.resolve_all()
assert len(container.context) == expected_count


def test_not_found():
provider = Provider()
provider.from_context(provides=int, scope=Scope.APP)
Expand Down
65 changes: 65 additions & 0 deletions tests/unit/container/test_resolve.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Callable

import pytest

from dishka import (
Expand Down Expand Up @@ -129,3 +131,66 @@ def test_external_method(method):

container = make_container(provider)
assert container.get(ClassA) is A_VALUE


@pytest.mark.parametrize(
("factory", "cache", "expected_count"),
[
(ClassA, True, 3),
(ClassA, False, 2),
(sync_func_a, True, 3),
(sync_func_a, False, 2),
(sync_iter_a, True, 3),
(sync_iter_a, False, 2),
(sync_gen_a, True, 3),
(sync_gen_a, False, 2),
],
)
def test_sync_resolve_all(
factory: Callable[..., Any], cache: bool, expected_count: int
):
class MyProvider(Provider):
a = provide(factory, scope=Scope.APP, cache=cache)

@provide(scope=Scope.APP)
def get_int(self) -> int:
return 100

container = make_container(MyProvider())
assert container.registry.scope is Scope.APP
assert len(container.context) == 1
container.resolve_all()
assert len(container.context) == expected_count
container.close()


@pytest.mark.parametrize(
("factory", "cache", "expected_count"),
[
(ClassA, True, 3),
(ClassA, False, 2),
(async_func_a, True, 3),
(async_func_a, False, 2),
(async_iter_a, True, 3),
(async_iter_a, False, 2),
(async_gen_a, True, 3),
(async_gen_a, False, 2),
],
)
@pytest.mark.asyncio
async def test_async_resolve_all(
factory: Callable[..., Any], cache: bool, expected_count: int
):
class MyProvider(Provider):
a = provide(factory, scope=Scope.APP, cache=cache)

@provide(scope=Scope.APP)
def get_int(self) -> int:
return 100

container = make_async_container(MyProvider())
assert container.registry.scope is Scope.APP
assert len(container.context) == 1
await container.resolve_all()
assert len(container.context) == expected_count
await container.close()