Skip to content

Commit

Permalink
make --include-package accept paths (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh authored Jan 25, 2022
1 parent 92b8fe5 commit 4011482
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- The web UI now renders the step graph left-to-right.
- The web UI now shows runs by date, with the most recent run at the top.
- The web UI now shows steps in a color-coded way.
- The `--include-package` flag now accepts file or directory paths in addition to module names.

### Fixed

Expand Down
39 changes: 37 additions & 2 deletions tango/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import traceback
from contextlib import contextmanager
from pathlib import Path
from typing import Iterable, Optional, Set, Union
from typing import Iterable, Optional, Set, Tuple, Union

from .aliases import PathOrStr
from .exceptions import SigTermReceived
Expand Down Expand Up @@ -52,13 +52,48 @@ def import_extra_module(package_name: str) -> None:
_extra_imported_modules.add(package_name)


def resolve_module_name(package_name: str) -> Tuple[str, Path]:
    """
    Convert a filesystem path to a module or package into an importable name.

    Starting at the directory that contains the path, the search walks upward
    for as long as each directory contains an ``__init__.py`` file. The first
    directory that is *not* a package becomes the base path, i.e. the directory
    that has to be on the Python path for the returned name to be importable.

    :param package_name: Path to a ``.py`` file, an ``__init__.py`` file, or a
        package directory. An ``__init__.py`` path resolves to its parent package.
    :returns: A tuple ``(module_name, base_path)`` where ``module_name`` is the
        dotted module name relative to ``base_path``.
    :raises ValueError: If the path does not exist, or if it does not resolve to
        a non-empty module name (for example ``"."`` or a bare ``__init__.py``).
    """
    package_path = Path(package_name)
    if not package_path.exists():
        raise ValueError(f"'{package_path}' looks like a path, but the path does not exist")

    # Climb out of every enclosing package. The walk stops either at the first
    # directory without an `__init__.py`, or at the fixed point of `.parent`
    # (`.` for relative paths, the filesystem root for absolute ones), so
    # `base_path` is always an ancestor of `package_path`.
    base_path = package_path.parent
    while base_path != base_path.parent and (base_path / "__init__.py").is_file():
        base_path = base_path.parent

    # Work on path components rather than the string form so the result does
    # not depend on the platform's path separator.
    module_parts = list(package_path.relative_to(base_path).parts)

    if package_path.is_file():
        if package_path.name == "__init__.py":
            # If `__init__.py` file, resolve to the parent module.
            module_parts.pop()
        elif module_parts[-1].endswith(".py"):
            module_parts[-1] = module_parts[-1][:-3]

    resolved_name = ".".join(module_parts)
    if not resolved_name:
        raise ValueError(f"invalid package path '{package_path}'")

    return resolved_name, base_path


def import_module_and_submodules(package_name: str, exclude: Optional[Set[str]] = None) -> None:
"""
Import all submodules under the given package.
Primarily useful so that people using tango can specify their own custom packages
and have their custom classes get loaded and registered.
"""
# If `package_name` is in the form of a path, convert to the module format.
if "/" in package_name or package_name.endswith(".py"):
package_name, base_path = resolve_module_name(package_name)
else:
base_path = Path(".")

if exclude and package_name in exclude:
return

Expand All @@ -67,7 +102,7 @@ def import_module_and_submodules(package_name: str, exclude: Optional[Set[str]]
# For some reason, python doesn't always add this by default to your path, but you pretty much
# always want it when using `--include-package`. And if it's already there, adding it again at
# the end won't hurt anything.
with push_python_path("."):
with push_python_path(base_path):
# Import at top level
module = importlib.import_module(package_name)
path = getattr(module, "__path__", [])
Expand Down
25 changes: 25 additions & 0 deletions tests/common/util_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from pathlib import Path

import pytest
from flaky import flaky
Expand All @@ -7,10 +8,34 @@
could_be_class_name,
find_integrations,
find_submodules,
resolve_module_name,
threaded_generator,
)


@pytest.mark.parametrize(
    "package_name, resolved_package_name, resolved_base_path",
    [
        # `__init__.py` files resolve to the package that contains them.
        ("tango/integrations/datasets/__init__.py", "tango.integrations.datasets", Path(".")),
        ("tango/__init__.py", "tango", Path(".")),
        # A regular module inside a package keeps its full dotted name.
        ("tango/steps/dataset_remix.py", "tango.steps.dataset_remix", Path(".")),
        # A module in a directory without `__init__.py` is imported relative to that directory.
        ("examples/train_gpt2/components.py", "components", Path("examples/train_gpt2/")),
    ],
)
def test_resolve_module_name(
    package_name: str, resolved_package_name: str, resolved_base_path: Path
):
    """Check that paths resolve to the expected module name and base path."""
    module_name, base_path = resolve_module_name(package_name)
    assert module_name == resolved_package_name
    assert base_path == resolved_base_path


def test_find_submodules():
assert "tango.version" in set(find_submodules())
assert "tango.common.registrable" in set(find_submodules())
Expand Down

0 comments on commit 4011482

Please sign in to comment.