diff --git a/junkie/__init__.py b/junkie/__init__.py index 519af27..eef6006 100644 --- a/junkie/__init__.py +++ b/junkie/__init__.py @@ -1 +1,2 @@ from junkie._junkie import Junkie, JunkieError +from junkie._junkie import inject_list diff --git a/junkie/_junkie.py b/junkie/_junkie.py index 7427063..b6e54e3 100644 --- a/junkie/_junkie.py +++ b/junkie/_junkie.py @@ -22,6 +22,8 @@ def __init__(self, instances_and_factories: Mapping[str, Any] = None): self._instances_by_name = None self._instances_by_name_stack = [{}] + self._mapping["_junkie"] = self + @contextmanager def inject(self, *names_and_factories: Union[str, Callable]) -> Union[Any, Tuple[Any]]: LOGGER.debug("inject(%s)", Junkie._LogParams(*names_and_factories)) @@ -134,3 +136,16 @@ def __str__(self): arg_params = list(map(str, self.args)) kwarg_params = list(map(str, [f"{key}={repr(value)}" for key, value in self.kwargs.items()])) return ", ".join(arg_params + kwarg_params) + + +def inject_list(*factories_or_names): + """Can be used within the context to let junkie create a list of instances from a list of factories or names""" + @contextmanager + def wrapper(_junkie: Junkie): + with _junkie.inject(*factories_or_names) as instances: + if isinstance(instances, tuple): + yield list(instances) + else: + yield [instances] + + return wrapper diff --git a/test/test_junkie.py b/test/test_junkie.py index 76e14f9..a90ab4b 100644 --- a/test/test_junkie.py +++ b/test/test_junkie.py @@ -6,7 +6,7 @@ from contextlib import contextmanager from unittest import skipIf -from junkie import Junkie, JunkieError +from junkie import Junkie, JunkieError, inject_list class JunkieTest(unittest.TestCase): @@ -417,3 +417,28 @@ def __init__(self, a: A1): with Junkie().inject(C, B) as (c, b): self.assertIs(b.a, c.a) self.assertIsInstance(b.a, A1) + + def test_inject_junkie_reference(self): + my_junkie = Junkie() + + with my_junkie.inject("_junkie") as injected_junkie: + self.assertIs(injected_junkie, my_junkie) + + def test_inject_list(self): + class A: + pass + + class B: + def __init__(self, a: A, some_value: str): + self.a = a + self.some_value = some_value + + context = {"some_value": "value", "my_list_1": inject_list(A, B), "my_list_2": inject_list(B)} + + with Junkie(context).inject("my_list_1", "my_list_2") as (my_list_1, my_list_2): + self.assertIsInstance(my_list_1, list) + self.assertIsInstance(my_list_2, list) + self.assertIsInstance(my_list_1[0], A) + self.assertIsInstance(my_list_1[1].a, A) + self.assertEqual("value", my_list_1[1].some_value) + self.assertEqual("value", my_list_2[0].some_value)