Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Process Model Search to Allow Relative Imports #903

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 24 additions & 72 deletions src/lava/magma/compiler/compiler_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,91 +812,43 @@ def _find_proc_models(proc: AbstractProcess) \
A list of all ProcessModels that implement the behaviour of `proc`.
"""

# Find all ProcModel classes that implement 'proc' in same module
proc_module = sys.modules[proc.__module__]
proc_models = \
ProcGroupDiGraphs._find_proc_models_in_module(proc, proc_module)

# Search for the file of the module.
file = None
if inspect.isclass(proc.__class__):
if hasattr(proc.__class__, '__module__'):
proc_ = sys.modules.get(proc.__class__.__module__)
# check if it has file (classes in jupyter nb do not)
if hasattr(proc_, '__file__'):
file = proc_.__file__
else:
raise TypeError('Source for {!r} not found'.format(object))

if file is None:
# class is probably defined in a jupyter notebook
# lookup file name per methods
for _, m in inspect.getmembers(proc.__class__):
if inspect.isfunction(m) and \
proc.__class__.__qualname__ + '.' \
+ m.__name__ == m.__qualname__:
file = inspect.getfile(m)
break
else:
file = inspect.getfile(proc.__class__)

# Find all ProcModel classes that implement 'proc' in the same
# directory and namespace module.

dir_names = [os.path.dirname(file)]

proc_module = inspect.getmodule(proc)
module_names = [proc_module.__name__,]
if not proc_module.__name__ == "__main__":
# Get the parent module.
module_spec = importlib.util.find_spec(proc_module.__name__)
if module_spec.parent != '':
parent_module = importlib.import_module(module_spec.parent)
# get module spec for parent module
parent_spec = importlib.util.find_spec(module_spec.parent)

# Get all the modules inside the parent (namespace) module.
# This is required here, because the namespace module can span
# multiple repositories.
namespace_module_infos = list(
pkgutil.iter_modules(
parent_module.__path__,
parent_module.__name__ + "."
parent_spec.submodule_search_locations,
# add parent module name for absolute import name
prefix=parent_spec.name + "."
)
)
module_names.extend(m.name for m in namespace_module_infos)

# Extract the directory name of each module.
for _, name, _ in namespace_module_infos:
module = importlib.import_module(name)
module_dir_name = os.path.dirname(inspect.getfile(module))
dir_names.append(module_dir_name)

# Go through all directories and extract all the ProcModels.
for dir_name in dir_names:
for _, name, _ in pkgutil.iter_modules([dir_name]):
import_path = os.path.join(dir_name, name)
name_not_dir = not os.path.isdir(import_path)
if name_not_dir:
spec = import_utils.spec_from_file_location(
name, os.path.join(dir_name, name + ".py"))
module = import_utils.module_from_spec(spec)
try:
spec.loader.exec_module(module)
if module != proc_module:
pm = ProcGroupDiGraphs._find_proc_models_in_module(
proc, module)
for proc_model in pm:
proc_cls_mod = \
inspect.getmodule(proc).__package__ + \
'.' + proc_model.__module__
proc_cls_mod = importlib. \
import_module(proc_cls_mod)
class_ = getattr(proc_cls_mod,
proc_model.__name__)
if class_ not in proc_models:
proc_models.append(class_)
except Exception:
warnings.warn(
f"Cannot import module '{module}' when searching "
f"ProcessModels for Process "
f"'{proc.__class__.__name__}'."
)
# check each module for proc modules
proc_models = []
for mod_name in module_names:
try:
module = importlib.import_module(mod_name)
pm = ProcGroupDiGraphs._find_proc_models_in_module(
proc, module)
for proc_model in pm:
if proc_model not in proc_models:
proc_models.append(proc_model)
except Exception:
warnings.warn(
f"Cannot import module '{module}' when searching "
f"ProcessModels for Process "
f"'{proc.__class__.__name__}'."
)

if not proc_models:
raise ex.NoProcessModelFound(proc)
Expand Down
56 changes: 56 additions & 0 deletions tests/lava/magma/compiler/test_compiler_graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Unit tests for the `compiler_graphs` module.

Currently only includes tests for changes to the process model search process.

"""

import sys
import os
import unittest

from lava.proc.lif.process import LIF
from lava.proc.lif.models import (
PyLifModelFloat, PyLifModelBitAcc
)

# make sure parent dir is in python path
if os.path.dirname(__file__) not in sys.path:
sys.path.append(os.path.dirname(__file__))

from lava.magma.compiler.compiler_graphs import ProcGroupDiGraphs

from test_package.proc.test.process import TestProcess, TestModelSameFile
from test_package.proc.test.models_absolute import TestModelAbsolute
from test_package.proc.test.models_relative import TestModelRelative


class TestProcGroupDiGraphs(unittest.TestCase):
"""Testing ProcGroupDiGraphs.

Currently only tests `ProcGroupDiGraphs._find_proc_models()`
"""
def test_find_proc_models_custom_proc(self):
"""Test process model finding process."""
proc = TestProcess(name="test")

proc_models = ProcGroupDiGraphs._find_proc_models(proc)
expected_models = [
TestModelAbsolute,
TestModelRelative,
TestModelSameFile
]

self.assertTrue(all(pm in proc_models for pm in expected_models))

def test_find_proc_models_lava_proc(self):
"""Test process model finding for a standard lava process."""
lif = LIF(shape=(1,), name='lif')

proc_models = ProcGroupDiGraphs._find_proc_models(lif)
expected_models = [PyLifModelFloat, PyLifModelBitAcc]

self.assertTrue(all(pm in proc_models for pm in expected_models))


if __name__ == "__main__":
unittest.main()
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Process models using absolute imports.

"""

import os
import sys
sys.path.append(os.path.abspath("../../../"))

from lava.magma.core.model.model import AbstractProcessModel
from lava.magma.core.decorator import implements
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol

from test_package.proc.test.process import TestProcess


@implements(proc=TestProcess, protocol=LoihiProtocol)
class TestModelAbsolute(AbstractProcessModel):
"""Process model defined using absolute import of Process."""

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Process models using absolute imports.

"""

from lava.magma.core.model.model import AbstractProcessModel
from lava.magma.core.decorator import implements
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol

from .process import TestProcess


@implements(proc=TestProcess, protocol=LoihiProtocol)
class TestModelRelative(AbstractProcessModel):
"""Process model defined using relative import of Process."""

17 changes: 17 additions & 0 deletions tests/lava/magma/compiler/test_package/proc/test/process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Test process for testing process model search.

"""

from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.model.model import AbstractProcessModel
from lava.magma.core.decorator import implements
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol


class TestProcess(AbstractProcess):
"""Test process for proc-model search"""


@implements(proc=TestProcess, protocol=LoihiProtocol)
class TestModelSameFile(AbstractProcessModel):
"""Test process model in same file as process for proc-model search."""