Skip to content

Commit

Permalink
Override will now raise if target is unknown.
Browse files Browse the repository at this point in the history
  • Loading branch information
maldoinc committed Apr 6, 2024
1 parent b31a191 commit 2a7b8c5
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 8 deletions.
21 changes: 16 additions & 5 deletions test/unit/test_container_override.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from test.unit.services.no_annotations.random.random_service import RandomService
from wireup import DependencyContainer, ParameterBag, Wire
from wireup.errors import UnknownOverrideRequestedError
from wireup.ioc.override_manager import OverrideManager
from wireup.ioc.types import ServiceOverride

Expand All @@ -30,13 +31,14 @@ def test_container_overrides_deps_service_locator(self):

def test_container_overrides_deps_service_locator_interface(self):
self.container.abstract(FooBase)
self.container.register(FooBar)

foo_mock = MagicMock()
foo_mock.foo = "mock"

with patch.object(foo_mock, "foo", new="mock"):
with self.container.override.service(target=FooBase, new=foo_mock):
svc = self.container.get(FooBase)
self.assertEqual(svc.foo, "mock")
with self.container.override.service(target=FooBase, new=foo_mock):
svc = self.container.get(FooBase)
self.assertEqual(svc.foo, "mock")

def test_container_override_many_with_qualifier(self):
self.container.register(RandomService, qualifier="Rand1")
Expand Down Expand Up @@ -90,9 +92,18 @@ def target(foo: FooBase):
def test_clear_services_removes_all(self):
overrides = {}
mock1 = MagicMock()
override_mgr = OverrideManager(overrides)
override_mgr = OverrideManager(overrides, lambda _klass, _qualifier: True)
override_mgr.set(RandomService, new=mock1)
self.assertEqual(overrides, {(RandomService, None): mock1})

override_mgr.clear()
self.assertEqual(overrides, {})

def test_raises_on_unknown_override(self):
with self.assertRaises(UnknownOverrideRequestedError) as e:
with self.container.override.service(target=unittest.TestCase, qualifier="foo", new=MagicMock()):
pass

self.assertEqual(
str(e.exception), "Cannot override unknown <class 'unittest.case.TestCase'> with qualifier 'foo'."
)
7 changes: 7 additions & 0 deletions wireup/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,10 @@ class InvalidRegistrationTypeError(WireupError):

def __init__(self, attempted: Any) -> None:
super().__init__(f"Cannot register {attempted} with the container. Allowed types are callables and types")


class UnknownOverrideRequestedError(WireupError):
"""Raised when attempting to override a service which does not exist."""

def __init__(self, klass: type, qualifier: ContainerProxyQualifierValue) -> None:
super().__init__(f"Cannot override unknown {klass} with qualifier '{qualifier}'.")
5 changes: 3 additions & 2 deletions wireup/ioc/dependency_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from .initialization_context import InitializationContext
from .parameter import ParameterBag


__T = TypeVar("__T")
__ObjectIdentifier = Tuple[type, ContainerProxyQualifierValue]

Expand Down Expand Up @@ -73,7 +72,9 @@ def __init__(self, parameter_bag: ParameterBag) -> None:
self.__initialized_proxies: dict[__ObjectIdentifier, ContainerProxy[Any]] = {}
self.__buildable_types: set[type] = set()
self.__params: ParameterBag = parameter_bag
self.__override_manager: OverrideManager = OverrideManager(self.__active_overrides)
self.__override_manager: OverrideManager = OverrideManager(
self.__active_overrides, self.__service_registry.is_type_with_qualifier_known
)

def get(self, klass: type[__T], qualifier: ContainerProxyQualifierValue = None) -> __T:
"""Get an instance of the requested type.
Expand Down
14 changes: 13 additions & 1 deletion wireup/ioc/override_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,23 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Iterator

from wireup.errors import UnknownOverrideRequestedError

if TYPE_CHECKING:
from collections.abc import Callable

from wireup.ioc.types import ContainerProxyQualifierValue, ServiceOverride


class OverrideManager:
"""Enables overriding of services registered with the container."""

def __init__(self, active_overrides: dict[tuple[type, ContainerProxyQualifierValue], Any]) -> None:
def __init__(
self,
active_overrides: dict[tuple[type, ContainerProxyQualifierValue], Any],
is_valid_override: Callable[[type, ContainerProxyQualifierValue], bool],
) -> None:
self.__is_valid_override = is_valid_override
self.__active_overrides = active_overrides

def set(self, target: type, new: Any, qualifier: ContainerProxyQualifierValue = None) -> None:
Expand All @@ -23,6 +32,9 @@ def set(self, target: type, new: Any, qualifier: ContainerProxyQualifierValue =
with the qualifier parameter set to a value.
:param new: The new object to be injected instead of `target`.
"""
if not self.__is_valid_override(target, qualifier):
raise UnknownOverrideRequestedError(klass=target, qualifier=qualifier)

self.__active_overrides[target, qualifier] = new

def delete(self, target: type, qualifier: ContainerProxyQualifierValue = None) -> None:
Expand Down

0 comments on commit 2a7b8c5

Please sign in to comment.