Skip to content

Commit

Permalink
Fix importing file modules
Browse files Browse the repository at this point in the history
  • Loading branch information
maldoinc committed Aug 2, 2024
1 parent 5d14568 commit 735ed21
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
12 changes: 12 additions & 0 deletions test/unit/test_module_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from test.unit.services import no_annotations, with_annotations
from test.unit.services.no_annotations.random.random_service import RandomService
from test.unit.services.no_annotations.random.truly_random_service import TrulyRandomService
from test.unit.services.with_annotations import services
from test.unit.services.with_annotations.env import EnvService
from test.unit.services.with_annotations.services import IFoo

from wireup import DependencyContainer, ParameterBag, register_all_in_module, warmup_container
from wireup.import_util import initialize_container


class ModuleLoadingTest(unittest.TestCase):
Expand All @@ -27,3 +29,13 @@ def test_warmup_loads_all_in_module_with_annotations(self):
self.assertEqual("foo", container.get(IFoo).get_foo())
self.assertEqual(5, container.get(TrulyRandomService, qualifier="foo").get_truly_random())
self.assertEqual(4, container.get(RandomService, qualifier="foo").get_random())

def test_loads_module_is_file(self):
# Assert that loading works when the module is a file instead of the entire module
container = DependencyContainer(ParameterBag())
container.params.put("env_name", "dev")
initialize_container(container, service_modules=[services])

self.assertEqual("foo", container.get(services.IFoo).get_foo())
self.assertEqual(4, container.get(RandomService, qualifier="foo").get_random())
self.assertEqual(5, container.get(TrulyRandomService, qualifier="foo").get_truly_random())
5 changes: 4 additions & 1 deletion wireup/import_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ def _find_in_path(path: Path, parent_module_name: str) -> None:
classes.update(_module_get_objects(sub_module))

if f := module.__file__:
_find_in_path(Path(f).parent, module.__name__)
if f.endswith("__init__.py"):
_find_in_path(Path(f).parent, module.__name__)
else:
classes.update(_module_get_objects(module))

return classes

Expand Down

0 comments on commit 735ed21

Please sign in to comment.