diff --git a/CHANGELOG.md b/CHANGELOG.md
index 54ff35742..ebe740cd4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - The web UI now renders the step graph left-to-right.
 - The web UI now shows runs by date, with the most recent run at the top.
 - The web UI now shows steps in a color-coded way.
+- The `--include-package` flag now also accepts paths instead of module names.
 
 ### Fixed
 
diff --git a/tango/common/util.py b/tango/common/util.py
index 61cac1c8c..78aff245b 100644
--- a/tango/common/util.py
+++ b/tango/common/util.py
@@ -6,7 +6,7 @@
 import traceback
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Iterable, Optional, Set, Union
+from typing import Iterable, Optional, Set, Tuple, Union
 
 from .aliases import PathOrStr
 from .exceptions import SigTermReceived
@@ -52,6 +52,44 @@ def import_extra_module(package_name: str) -> None:
     _extra_imported_modules.add(package_name)
 
 
+def resolve_module_name(package_name: str) -> Tuple[str, Path]:
+    """
+    Convert the path of a Python file or package directory into a module name.
+
+    Returns a tuple ``(module_name, base_path)``, where ``base_path`` is the closest
+    ancestor directory of the path that does not contain an ``__init__.py`` file.
+
+    Raises a :class:`ValueError` if the path does not exist or does not resolve
+    to a module name.
+    """
+    package_path = Path(package_name)
+    if not package_path.exists():
+        raise ValueError(f"'{package_path}' looks like a path, but the path does not exist")
+
+    # Walk up past every directory that is itself a package. Wherever we stop is
+    # the base path.
+    base_path = package_path.parent
+    while base_path != base_path.parent and (base_path / "__init__.py").is_file():
+        base_path = base_path.parent
+
+    # Build the name from the path components instead of replacing "/", so that
+    # this doesn't depend on the path separator of the platform.
+    parts = list(package_path.relative_to(base_path).parts)
+
+    if package_path.is_file():
+        if package_path.name == "__init__.py":
+            # If `__init__.py` file, resolve to the parent module.
+            parts.pop()
+        elif package_path.name.endswith(".py"):
+            parts[-1] = parts[-1][:-3]
+
+    package_name = ".".join(parts)
+    if not package_name:
+        raise ValueError(f"invalid package path '{package_path}'")
+
+    return package_name, base_path
+
+
 def import_module_and_submodules(package_name: str, exclude: Optional[Set[str]] = None) -> None:
     """
     Import all submodules under the given package.
@@ -59,6 +97,12 @@ def import_module_and_submodules(package_name: str, exclude: Optional[Set[str]]
     Primarily useful so that people using tango can specify their own custom packages
     and have their custom classes get loaded and registered.
     """
+    # If `package_name` is in the form of a path, convert to the module format.
+    if "/" in package_name or package_name.endswith(".py"):
+        package_name, base_path = resolve_module_name(package_name)
+    else:
+        base_path = Path(".")
+
     if exclude and package_name in exclude:
         return
 
@@ -67,7 +111,7 @@ def import_module_and_submodules(package_name: str, exclude: Optional[Set[str]]
     # For some reason, python doesn't always add this by default to your path, but you pretty much
     # always want it when using `--include-package`. And if it's already there, adding it again at
     # the end won't hurt anything.
-    with push_python_path("."):
+    with push_python_path(base_path):
         # Import at top level
         module = importlib.import_module(package_name)
         path = getattr(module, "__path__", [])
diff --git a/tests/common/util_test.py b/tests/common/util_test.py
index 12782fa7e..8fc463849 100644
--- a/tests/common/util_test.py
+++ b/tests/common/util_test.py
@@ -1,4 +1,5 @@
 import time
+from pathlib import Path
 
 import pytest
 from flaky import flaky
@@ -7,10 +8,34 @@
     could_be_class_name,
     find_integrations,
     find_submodules,
+    resolve_module_name,
     threaded_generator,
 )
 
 
+@pytest.mark.parametrize(
+    "package_name, resolved_package_name, resolved_base_path",
+    [
+        (
+            "tango/integrations/datasets/__init__.py",
+            "tango.integrations.datasets",
+            Path("."),
+        ),
+        (
+            "tango/__init__.py",
+            "tango",
+            Path("."),
+        ),
+        ("tango/steps/dataset_remix.py", "tango.steps.dataset_remix", Path(".")),
+        ("examples/train_gpt2/components.py", "components", Path("examples/train_gpt2/")),
+    ],
+)
+def test_resolve_module_name(
+    package_name: str, resolved_package_name: str, resolved_base_path: Path
+):
+    assert resolve_module_name(package_name) == (resolved_package_name, resolved_base_path)
+
+
 def test_find_submodules():
     assert "tango.version" in set(find_submodules())
     assert "tango.common.registrable" in set(find_submodules())