-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Lazy loader definition for XManager sub-package __init__.py files.
PiperOrigin-RevId: 611589603 Change-Id: I9c27e689a2b319316332fd1e7c7c119047cd768c GitOrigin-RevId: 44632ca9f0d11ba6330bed45ba2c05222b8edd66
- Loading branch information
1 parent
3b8f0b4
commit 407b98c
Showing
2 changed files
with
159 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright 2024 DeepMind Technologies Limited | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Custom lazy loader definition for xmanager sub-package __init__.py files.""" | ||
import dataclasses | ||
import importlib | ||
import types | ||
from typing import Any, Callable, Optional, Sequence | ||
|
||
|
||
@dataclasses.dataclass | ||
class XManagerAPI: | ||
"""Dataclass for XManager sub-package APIs. | ||
Attributes: | ||
module: An XManager submodule, which will be the exposed API if no symbol is | ||
provided. | ||
symbol: An optional symbol to expose from the given module. If provided, the | ||
API will be module.symbol. | ||
alias: An optional alias to rename the API. | ||
""" | ||
|
||
module: str | ||
symbol: Optional[str] = None | ||
alias: Optional[str] = None | ||
|
||
|
||
class XManagerLazyLoader: | ||
"""Custom lazy loader for xmanager sub-package __init__.py files.""" | ||
|
||
def __init__(self, subpackage_name: str, apis: Sequence[XManagerAPI]): | ||
"""Initializes the XManagerLazyLoader. | ||
Args: | ||
subpackage_name: The name of the current xmanager sub-package (i.e. the | ||
__name__ attribute of the current xmanager sub-package). | ||
apis: A list of XManagerAPIs to expose from the XManager sub-package. | ||
""" | ||
self.subpackage_name = subpackage_name | ||
self.apis = apis | ||
self._loaded_attrs = {} | ||
self._name_to_api: dict[str, XManagerAPI] = {} | ||
for api in self.apis: | ||
if api.alias: | ||
name = api.alias | ||
elif api.symbol: | ||
name = api.symbol | ||
else: | ||
name = api.module.split(".")[-1] # module name | ||
self._name_to_api[name] = api | ||
|
||
def get_module_all(self) -> list[str]: | ||
"""Returns __all__ for the xmanager sub-package __init__.py file.""" | ||
return list(self._name_to_api.keys()) | ||
|
||
def get_module_dir(self) -> Callable[[], list[str]]: | ||
"""Returns __dir__ for the xmanager sub-package __init__.py file.""" | ||
return lambda: sorted(self._name_to_api.keys()) | ||
|
||
def get_module_getattr( | ||
self, | ||
) -> Callable[[str], types.ModuleType | Any | None]: | ||
"""Returns __getattr__ for the xmanager sub-package __init__.py file.""" | ||
|
||
def _module_getattr(name: str) -> types.ModuleType | Any | None: | ||
if name in self._loaded_attrs: | ||
return self._loaded_attrs[name] | ||
if name in self._name_to_api: | ||
api = self._name_to_api[name] | ||
if api.symbol: | ||
module = importlib.import_module(api.module) | ||
attr = getattr(module, api.symbol) | ||
self._loaded_attrs[name] = attr | ||
return attr | ||
else: | ||
module = importlib.import_module(api.module) | ||
self._loaded_attrs[name] = module | ||
return module | ||
raise AttributeError( | ||
f"module {self.subpackage_name!r} has no attribute {name!r}" | ||
) | ||
|
||
return _module_getattr |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright 2024 DeepMind Technologies Limited | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Test for custom lazy loader definition for xmanager sub-package __init__.py files.""" | ||
|
||
import unittest | ||
from xmanager import module_lazy_loader | ||
|
||
|
||
class ModuleLazyLoaderTest(unittest.TestCase): | ||
test_lazy_loader = module_lazy_loader.XManagerLazyLoader( | ||
__name__, | ||
apis=[ | ||
module_lazy_loader.XManagerAPI( | ||
module="xmanager.module_lazy_loader", | ||
symbol="XManagerAPI", | ||
alias="boo", | ||
), | ||
module_lazy_loader.XManagerAPI( | ||
module="xmanager.module_lazy_loader", | ||
symbol="XManagerAPI", | ||
), | ||
module_lazy_loader.XManagerAPI( | ||
module="xmanager.module_lazy_loader", alias="baz" | ||
), | ||
module_lazy_loader.XManagerAPI( | ||
module="xmanager.module_lazy_loader", | ||
), | ||
], | ||
) | ||
|
||
def test_all(self): | ||
self.assertCountEqual( | ||
self.test_lazy_loader.get_module_all(), | ||
["boo", "XManagerAPI", "baz", "module_lazy_loader"], | ||
) | ||
|
||
def test_dir(self): | ||
self.assertCountEqual( | ||
self.test_lazy_loader.get_module_dir()(), | ||
["boo", "XManagerAPI", "baz", "module_lazy_loader"], | ||
) | ||
|
||
def test_getattr(self): | ||
local_getattr = self.test_lazy_loader.get_module_getattr() | ||
self.assertEqual(local_getattr("boo"), module_lazy_loader.XManagerAPI) | ||
self.assertEqual( | ||
local_getattr("XManagerAPI"), module_lazy_loader.XManagerAPI | ||
) | ||
self.assertEqual(local_getattr("baz"), module_lazy_loader) | ||
self.assertEqual(local_getattr("module_lazy_loader"), module_lazy_loader) | ||
self.assertRaises(AttributeError, local_getattr, "this_attr_does_not_exist") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |