Skip to content

Commit

Permalink
Lazy loader definition for XManager sub-package __init__.py files.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 611589603
Change-Id: I9c27e689a2b319316332fd1e7c7c119047cd768c
GitOrigin-RevId: 44632ca9f0d11ba6330bed45ba2c05222b8edd66
  • Loading branch information
fionalang authored and alpiccioni committed Dec 4, 2024
1 parent 3b8f0b4 commit 407b98c
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 0 deletions.
93 changes: 93 additions & 0 deletions xmanager/module_lazy_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom lazy loader definition for xmanager sub-package __init__.py files."""
import dataclasses
import importlib
import types
from typing import Any, Callable, Optional, Sequence


@dataclasses.dataclass
class XManagerAPI:
"""Dataclass for XManager sub-package APIs.
Attributes:
module: An XManager submodule, which will be the exposed API if no symbol is
provided.
symbol: An optional symbol to expose from the given module. If provided, the
API will be module.symbol.
alias: An optional alias to rename the API.
"""

module: str
symbol: Optional[str] = None
alias: Optional[str] = None


class XManagerLazyLoader:
"""Custom lazy loader for xmanager sub-package __init__.py files."""

def __init__(self, subpackage_name: str, apis: Sequence[XManagerAPI]):
"""Initializes the XManagerLazyLoader.
Args:
subpackage_name: The name of the current xmanager sub-package (i.e. the
__name__ attribute of the current xmanager sub-package).
apis: A list of XManagerAPIs to expose from the XManager sub-package.
"""
self.subpackage_name = subpackage_name
self.apis = apis
self._loaded_attrs = {}
self._name_to_api: dict[str, XManagerAPI] = {}
for api in self.apis:
if api.alias:
name = api.alias
elif api.symbol:
name = api.symbol
else:
name = api.module.split(".")[-1] # module name
self._name_to_api[name] = api

def get_module_all(self) -> list[str]:
"""Returns __all__ for the xmanager sub-package __init__.py file."""
return list(self._name_to_api.keys())

def get_module_dir(self) -> Callable[[], list[str]]:
"""Returns __dir__ for the xmanager sub-package __init__.py file."""
return lambda: sorted(self._name_to_api.keys())

def get_module_getattr(
self,
) -> Callable[[str], types.ModuleType | Any | None]:
"""Returns __getattr__ for the xmanager sub-package __init__.py file."""

def _module_getattr(name: str) -> types.ModuleType | Any | None:
if name in self._loaded_attrs:
return self._loaded_attrs[name]
if name in self._name_to_api:
api = self._name_to_api[name]
if api.symbol:
module = importlib.import_module(api.module)
attr = getattr(module, api.symbol)
self._loaded_attrs[name] = attr
return attr
else:
module = importlib.import_module(api.module)
self._loaded_attrs[name] = module
return module
raise AttributeError(
f"module {self.subpackage_name!r} has no attribute {name!r}"
)

return _module_getattr
66 changes: 66 additions & 0 deletions xmanager/module_lazy_loader_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for custom lazy loader definition for xmanager sub-package __init__.py files."""

import unittest
from xmanager import module_lazy_loader


class ModuleLazyLoaderTest(unittest.TestCase):
test_lazy_loader = module_lazy_loader.XManagerLazyLoader(
__name__,
apis=[
module_lazy_loader.XManagerAPI(
module="xmanager.module_lazy_loader",
symbol="XManagerAPI",
alias="boo",
),
module_lazy_loader.XManagerAPI(
module="xmanager.module_lazy_loader",
symbol="XManagerAPI",
),
module_lazy_loader.XManagerAPI(
module="xmanager.module_lazy_loader", alias="baz"
),
module_lazy_loader.XManagerAPI(
module="xmanager.module_lazy_loader",
),
],
)

def test_all(self):
self.assertCountEqual(
self.test_lazy_loader.get_module_all(),
["boo", "XManagerAPI", "baz", "module_lazy_loader"],
)

def test_dir(self):
self.assertCountEqual(
self.test_lazy_loader.get_module_dir()(),
["boo", "XManagerAPI", "baz", "module_lazy_loader"],
)

def test_getattr(self):
local_getattr = self.test_lazy_loader.get_module_getattr()
self.assertEqual(local_getattr("boo"), module_lazy_loader.XManagerAPI)
self.assertEqual(
local_getattr("XManagerAPI"), module_lazy_loader.XManagerAPI
)
self.assertEqual(local_getattr("baz"), module_lazy_loader)
self.assertEqual(local_getattr("module_lazy_loader"), module_lazy_loader)
self.assertRaises(AttributeError, local_getattr, "this_attr_does_not_exist")


if __name__ == "__main__":
unittest.main()

0 comments on commit 407b98c

Please sign in to comment.