Skip to content

Commit

Permalink
Implement dependency overriding
Browse files Browse the repository at this point in the history
  • Loading branch information
maldoinc committed Dec 21, 2023
1 parent cbc7b4a commit 371cedc
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 34 deletions.
74 changes: 69 additions & 5 deletions test/test_container_override.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import unittest
from unittest.mock import MagicMock

from test.fixtures import FooBase, FooBar
from test.services.no_annotations.random.random_service import RandomService
from wireup import DependencyContainer, ParameterBag
from unittest.mock import MagicMock, patch

from typing_extensions import Annotated
from wireup import DependencyContainer, ParameterBag, Wire
from wireup.ioc.types import ServiceOverride


class TestContainerOverride(unittest.TestCase):
Expand All @@ -14,9 +18,69 @@ def test_container_overrides_deps_service_locator(self):

random_mock = MagicMock()
random_mock.get_random.return_value = 5
self.assertEqual(random_mock.get_random(), 5)

with self.container.override(target=RandomService, new=random_mock):
with self.container.override.service(target=RandomService, new=random_mock):
svc = self.container.get(RandomService)

self.assertEqual(svc.get_random(), 5)

random_mock.get_random.assert_called_once()
self.assertEqual(self.container.get(RandomService).get_random(), 4)

def test_container_overrides_deps_service_locator_interface(self):
self.container.abstract(FooBase)

foo_mock = MagicMock()

with patch.object(foo_mock, "foo", new="mock"):
with self.container.override.service(target=FooBase, new=foo_mock):
svc = self.container.get(FooBase)
self.assertEqual(svc.foo, "mock")

def test_container_override_many_with_qualifier(self):
self.container.register(RandomService, qualifier="Rand1")
self.container.register(RandomService, qualifier="Rand2")

@self.container.autowire
def target(
rand1: Annotated[RandomService, Wire(qualifier="Rand1")],
rand2: Annotated[RandomService, Wire(qualifier="Rand2")],
):
self.assertEqual(rand1.get_random(), 5)
self.assertEqual(rand2.get_random(), 6)

self.assertIsInstance(rand1, MagicMock)
self.assertIsInstance(rand2, MagicMock)

rand1_mock = MagicMock()
rand1_mock.get_random.return_value = 5

rand2_mock = MagicMock()
rand2_mock.get_random.return_value = 6

overrides = [
ServiceOverride(target=RandomService, qualifier="Rand1", new=rand1_mock),
ServiceOverride(target=RandomService, qualifier="Rand2", new=rand2_mock),
]
with self.container.override.services(overrides=overrides):
target()

rand1_mock.get_random.assert_called_once()
rand2_mock.get_random.assert_called_once()

def test_container_override_with_interface(self):
self.container.abstract(FooBase)
self.container.register(FooBar)

@self.container.autowire
def target(foo: FooBase):
self.assertEqual(foo.foo, "mock")
self.assertIsInstance(foo, MagicMock)

foo_mock = MagicMock()

with patch.object(foo_mock, "foo", new="mock"):
with self.container.override.service(target=FooBase, new=foo_mock):
svc = self.container.get(FooBase)
self.assertEqual(svc.foo, "mock")

target()
3 changes: 2 additions & 1 deletion wireup/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from wireup.import_util import register_all_in_module, warmup_container
from wireup.ioc.dependency_container import DependencyContainer
from wireup.ioc.parameter import ParameterBag
from wireup.ioc.types import ParameterReference, ServiceLifetime
from wireup.ioc.types import ParameterReference, ServiceLifetime, ServiceOverride

container = DependencyContainer(ParameterBag())
"""Singleton DI container instance.
Expand All @@ -17,6 +17,7 @@
"ParameterEnum",
"ParameterReference",
"ServiceLifetime",
"ServiceOverride",
"Wire",
"container",
"register_all_in_module",
Expand Down
52 changes: 24 additions & 28 deletions wireup/ioc/dependency_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import functools
import sys
from contextlib import contextmanager

Check failure on line 6 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.8)

Ruff (F401)

wireup/ioc/dependency_container.py:6:24: F401 `contextlib.contextmanager` imported but unused

Check failure on line 6 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (F401)

wireup/ioc/dependency_container.py:6:24: F401 `contextlib.contextmanager` imported but unused

Check failure on line 6 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (F401)

wireup/ioc/dependency_container.py:6:24: F401 `contextlib.contextmanager` imported but unused

Check failure on line 6 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F401)

wireup/ioc/dependency_container.py:6:24: F401 `contextlib.contextmanager` imported but unused
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, Iterator, Tuple, TypeVar, overload

Check failure on line 7 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.8)

Ruff (F401)

wireup/ioc/dependency_container.py:7:50: F401 `typing.Iterator` imported but unused

Check failure on line 7 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (F401)

wireup/ioc/dependency_container.py:7:50: F401 `typing.Iterator` imported but unused

Check failure on line 7 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (F401)

wireup/ioc/dependency_container.py:7:50: F401 `typing.Iterator` imported but unused

Check failure on line 7 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F401)

wireup/ioc/dependency_container.py:7:50: F401 `typing.Iterator` imported but unused

from .override_manager import OverrideManager

if sys.version_info[:2] > (3, 8):
from graphlib import TopologicalSorter
Expand Down Expand Up @@ -36,7 +38,7 @@


__T = TypeVar("__T")
__ObjectIdentifier = tuple[type, ContainerProxyQualifierValue]
__ObjectIdentifier = Tuple[type, ContainerProxyQualifierValue]


class DependencyContainer:
Expand All @@ -61,6 +63,7 @@ class DependencyContainer:
"__initialized_proxies",
"__buildable_types",
"__active_overrides",
"__override_manager",
"__params",
)

Expand All @@ -72,6 +75,7 @@ def __init__(self, parameter_bag: ParameterBag) -> None:
self.__initialized_proxies: dict[__ObjectIdentifier, ContainerProxy[Any]] = {}
self.__buildable_types: set[type] = set()
self.__params: ParameterBag = parameter_bag
self.__override_manager: OverrideManager = OverrideManager(self.__active_overrides)

def get(self, klass: type[__T], qualifier: ContainerProxyQualifierValue = None) -> __T:
"""Get an instance of the requested type.
Expand All @@ -82,7 +86,11 @@ def get(self, klass: type[__T], qualifier: ContainerProxyQualifierValue = None)
:param klass: Class of the dependency already registered in the container.
:return: An instance of the requested object. Always returns an existing instance when one is available.
"""
if res := self.__active_overrides.get((klass, qualifier)):
return res # type: ignore[no-any-return]

self.__assert_dependency_exists(klass, qualifier)

if self.__service_registry.is_interface_known(klass):
klass = self.__resolve_impl(klass, qualifier)

Expand Down Expand Up @@ -202,24 +210,30 @@ def warmup(self) -> None:
if (klass, qualifier) not in self.__initialized_objects:
self.__create_instance(klass, qualifier)

@property
def override(self) -> OverrideManager:
"""Override container services.
Injection requests to overriden services will instead return the new values while the override is active."""

Check failure on line 216 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.8)

Ruff (D205)

wireup/ioc/dependency_container.py:215:9: D205 1 blank line required between summary line and description

Check failure on line 216 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.8)

Ruff (D209)

wireup/ioc/dependency_container.py:215:9: D209 Multi-line docstring closing quotes should be on a separate line

Check failure on line 216 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (D205)

wireup/ioc/dependency_container.py:215:9: D205 1 blank line required between summary line and description

Check failure on line 216 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (D209)

wireup/ioc/dependency_container.py:215:9: D209 Multi-line docstring closing quotes should be on a separate line

Check failure on line 216 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (D205)

wireup/ioc/dependency_container.py:215:9: D205 1 blank line required between summary line and description

Check failure on line 216 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (D209)

wireup/ioc/dependency_container.py:215:9: D209 Multi-line docstring closing quotes should be on a separate line

Check failure on line 216 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (D205)

wireup/ioc/dependency_container.py:215:9: D205 1 blank line required between summary line and description

Check failure on line 216 in wireup/ioc/dependency_container.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (D209)

wireup/ioc/dependency_container.py:215:9: D209 Multi-line docstring closing quotes should be on a separate line
return self.__override_manager

def __callable_get_params_to_inject(self, fn: AnyCallable) -> dict[str, Any]:
values_from_parameters: dict[str, Any] = {}
params = self.__service_registry.context.dependencies[fn]
names_to_remove: set[str] = set()

for name, annotated_parameter in params.items():
for name, param in params.items():
# This block is particularly crucial for performance and has to be written to be as fast as possible.

# Check if there's already an instantiated object with this id which can be directly injected
obj_id = annotated_parameter.klass, annotated_parameter.qualifier_value
obj_id = param.klass, param.qualifier_value

if obj := self.__initialized_objects.get(obj_id): # type: ignore[arg-type]
if param.klass and (obj := self.__active_overrides.get(obj_id, self.__initialized_objects.get(obj_id))): # type: ignore[arg-type]
values_from_parameters[name] = obj
# Dealing with parameter, return the value as we cannot proxy int str etc.
# We don't want to check here for none because as long as it exists in the bag, the value is good.
elif isinstance(annotated_parameter.annotation, ParameterWrapper):
values_from_parameters[name] = self.params.get(annotated_parameter.annotation.param)
elif annotated_parameter.klass and (
obj := self.__initialize_container_proxy_object_from_parameter(annotated_parameter)
):
elif isinstance(param.annotation, ParameterWrapper):
values_from_parameters[name] = self.params.get(param.annotation.param)
elif param.klass and (obj := self.__initialize_container_proxy_object_from_parameter(param)):
values_from_parameters[name] = obj
else:
names_to_remove.add(name)
Expand Down Expand Up @@ -331,21 +345,3 @@ def __assert_dependency_exists(self, klass: type, qualifier: ContainerProxyQuali
"""Assert that there exists an impl with that qualifier or an interface with an impl and the same qualifier."""
if not self.__service_registry.is_type_with_qualifier_known(klass, qualifier):
raise UnknownServiceRequestedError(klass)

@contextmanager
def override(self, target: type, new: Any, qualifier: ContainerProxyQualifierValue = None) -> None:
try:
self.__active_overrides[target, qualifier] = new
yield
finally:
del self.__active_overrides[target, qualifier]

@contextmanager
def override_many(self, overrides: list[ServiceOverride]) -> None:
try:
for override in overrides:
self.__active_overrides[(override.target, override.qualifier)] = override.new
yield
finally:
for override in overrides:
del self.__active_overrides[(override.target, override.qualifier)]
38 changes: 38 additions & 0 deletions wireup/ioc/override_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Any, Iterator

from wireup.ioc.types import ContainerProxyQualifierValue, ServiceOverride

Check failure on line 6 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.8)

Ruff (TCH001)

wireup/ioc/override_manager.py:6:30: TCH001 Move application import `wireup.ioc.types.ContainerProxyQualifierValue` into a type-checking block

Check failure on line 6 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.8)

Ruff (TCH001)

wireup/ioc/override_manager.py:6:60: TCH001 Move application import `wireup.ioc.types.ServiceOverride` into a type-checking block

Check failure on line 6 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (TCH001)

wireup/ioc/override_manager.py:6:30: TCH001 Move application import `wireup.ioc.types.ContainerProxyQualifierValue` into a type-checking block

Check failure on line 6 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (TCH001)

wireup/ioc/override_manager.py:6:60: TCH001 Move application import `wireup.ioc.types.ServiceOverride` into a type-checking block

Check failure on line 6 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (TCH001)

wireup/ioc/override_manager.py:6:30: TCH001 Move application import `wireup.ioc.types.ContainerProxyQualifierValue` into a type-checking block

Check failure on line 6 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (TCH001)

wireup/ioc/override_manager.py:6:60: TCH001 Move application import `wireup.ioc.types.ServiceOverride` into a type-checking block

Check failure on line 6 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (TCH001)

wireup/ioc/override_manager.py:6:30: TCH001 Move application import `wireup.ioc.types.ContainerProxyQualifierValue` into a type-checking block

Check failure on line 6 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (TCH001)

wireup/ioc/override_manager.py:6:60: TCH001 Move application import `wireup.ioc.types.ServiceOverride` into a type-checking block


class OverrideManager:

Check failure on line 9 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.8)

Ruff (D101)

wireup/ioc/override_manager.py:9:7: D101 Missing docstring in public class

Check failure on line 9 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (D101)

wireup/ioc/override_manager.py:9:7: D101 Missing docstring in public class

Check failure on line 9 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (D101)

wireup/ioc/override_manager.py:9:7: D101 Missing docstring in public class

Check failure on line 9 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (D101)

wireup/ioc/override_manager.py:9:7: D101 Missing docstring in public class
def __init__(self, active_overrides: dict[(type, ContainerProxyQualifierValue), Any]):

Check failure on line 10 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.8)

Ruff (ANN204)

wireup/ioc/override_manager.py:10:9: ANN204 Missing return type annotation for special method `__init__`

Check failure on line 10 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (ANN204)

wireup/ioc/override_manager.py:10:9: ANN204 Missing return type annotation for special method `__init__`

Check failure on line 10 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (ANN204)

wireup/ioc/override_manager.py:10:9: ANN204 Missing return type annotation for special method `__init__`

Check failure on line 10 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (ANN204)

wireup/ioc/override_manager.py:10:9: ANN204 Missing return type annotation for special method `__init__`
self.__active_overrides = active_overrides

def set(self, target: type, new: Any, qualifier: ContainerProxyQualifierValue = None):

Check failure on line 13 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.8)

Ruff (A003)

wireup/ioc/override_manager.py:13:9: A003 Class attribute `set` is shadowing a Python builtin

Check failure on line 13 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (A003)

wireup/ioc/override_manager.py:13:9: A003 Class attribute `set` is shadowing a Python builtin

Check failure on line 13 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (A003)

wireup/ioc/override_manager.py:13:9: A003 Class attribute `set` is shadowing a Python builtin

Check failure on line 13 in wireup/ioc/override_manager.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (A003)

wireup/ioc/override_manager.py:13:9: A003 Class attribute `set` is shadowing a Python builtin
self.__active_overrides[target, qualifier] = new

def delete(self, target: type, qualifier: ContainerProxyQualifierValue = None):
if (target, qualifier) in self.__active_overrides:
del self.__active_overrides[target, qualifier]

@contextmanager
def service(self, target: type, new: Any, qualifier: ContainerProxyQualifierValue = None) -> Iterator[None]:
"""Override the target service with new for the duration of the context manager."""
try:
self.set(target, new, qualifier)
yield
finally:
self.delete(target, qualifier)

@contextmanager
def services(self, overrides: list[ServiceOverride]) -> Iterator[None]:
"""Override the target service with new for the duration of the context manager."""
try:
for override in overrides:
self.set(override.target, override.new, override.qualifier)
yield
finally:
for override in overrides:
self.delete(override.target, override.qualifier)
2 changes: 2 additions & 0 deletions wireup/ioc/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def __hash__(self) -> int:

@dataclass(frozen=True, eq=True)
class ServiceOverride:
"""Data class to represent a service override. Target type will be replaced with the new type by the container."""

target: type
qualifier: ContainerProxyQualifierValue
new: Any

0 comments on commit 371cedc

Please sign in to comment.