Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617983897
Change-Id: I7ea6997fa4294e407fb6323bd41cadc3a61a586f
GitOrigin-RevId: 4d98a43ed847ff4940e2644144c7feecf60741ee
  • Loading branch information
DeepMind Team authored and alpiccioni committed Dec 4, 2024
1 parent 1553365 commit 1a72048
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,28 @@
"""Test for custom lazy loader definition for xmanager sub-package __init__.py files."""

import unittest
from xmanager import module_lazy_loader
from xmanager.module_lazy_loader import module_lazy_loader


class ModuleLazyLoaderTest(unittest.TestCase):
class LazyLoaderModuleAttrsTest(unittest.TestCase):
test_lazy_loader = module_lazy_loader.XManagerLazyLoader(
__name__,
apis=[
module_lazy_loader.XManagerAPI(
module="xmanager.module_lazy_loader",
module="xmanager.module_lazy_loader.module_lazy_loader",
symbol="XManagerAPI",
alias="boo",
),
module_lazy_loader.XManagerAPI(
module="xmanager.module_lazy_loader",
module="xmanager.module_lazy_loader.module_lazy_loader",
symbol="XManagerAPI",
),
module_lazy_loader.XManagerAPI(
module="xmanager.module_lazy_loader", alias="baz"
module="xmanager.module_lazy_loader.module_lazy_loader",
alias="baz",
),
module_lazy_loader.XManagerAPI(
module="xmanager.module_lazy_loader",
module="xmanager.module_lazy_loader.module_lazy_loader",
),
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom lazy loader definition for xmanager sub-package __init__.py files."""

import dataclasses
import importlib
import sys
import types
from typing import Any, Callable, Optional, Sequence

Expand Down Expand Up @@ -50,6 +52,7 @@ def __init__(self, subpackage_name: str, apis: Sequence[XManagerAPI]):
self.apis = apis
self._loaded_attrs = {}
self._name_to_api: dict[str, XManagerAPI] = {}

for api in self.apis:
if api.alias:
name = api.alias
Expand All @@ -72,18 +75,35 @@ def get_module_getattr(
) -> Callable[[str], types.ModuleType | Any | None]:
"""Returns __getattr__ for the xmanager sub-package __init__.py file."""

def _import_module_with_reloaded_parent(
module_name: str, e: ModuleNotFoundError
):
# reload module's parent as a last resort (likely in the case that a
# module was imported outside adhoc import context but later
# used within it). Assuming the parent package has a lazy-loaded
# / empty __init__.py file, this should be quick.
parent = e.name.rsplit(".", 1)[0]
parent_module = importlib.import_module(parent)
importlib.reload(parent_module)
return importlib.import_module(module_name)

def _import_module(module_name: str):
try:
return importlib.import_module(module_name)
except ModuleNotFoundError as e:
return _import_module_with_reloaded_parent(e)

def _module_getattr(name: str) -> types.ModuleType | Any | None:
if name in self._loaded_attrs:
return self._loaded_attrs[name]
if name in self._name_to_api:
api = self._name_to_api[name]
module = _import_module(api.module)
if api.symbol:
module = importlib.import_module(api.module)
attr = getattr(module, api.symbol)
self._loaded_attrs[name] = attr
return attr
else:
module = importlib.import_module(api.module)
self._loaded_attrs[name] = module
return module
raise AttributeError(
Expand Down

0 comments on commit 1a72048

Please sign in to comment.