diff --git a/pyproject.toml b/pyproject.toml index 2027f0562d..d63d0919c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,6 +157,9 @@ ECM_Example = "pybamm.input.parameters.ecm.example_set:get_parameter_values" MSMR_Example = "pybamm.input.parameters.lithium_ion.MSMR_example_set:get_parameter_values" Chayambuka2022 = "pybamm.input.parameters.sodium_ion.Chayambuka2022:get_parameter_values" +[project.entry-points."pybamm_models"] +SPM = "pybamm.models.full_battery_models.lithium_ion.spm:SPM" + [tool.setuptools] include-package-data = true diff --git a/src/pybamm/__init__.py b/src/pybamm/__init__.py index b466c3896b..efd7561670 100644 --- a/src/pybamm/__init__.py +++ b/src/pybamm/__init__.py @@ -124,7 +124,6 @@ from .parameters.lead_acid_parameters import LeadAcidParameters from .parameters.ecm_parameters import EcmParameters from .parameters.size_distribution_parameters import * -from .parameters.parameter_sets import parameter_sets # Mesh and Discretisation classes from .discretisations.discretisation import Discretisation @@ -201,6 +200,9 @@ # Pybamm Data manager using pooch from .pybamm_data import DataLoader +# Pybamm entry point API for parameter_sets and models +from .dispatch import parameter_sets, Model + # Fix Casadi import import os import pathlib @@ -233,6 +235,7 @@ "util", "version", "pybamm_data", + "dispatch", ] config.generate() diff --git a/src/pybamm/dispatch/__init__.py b/src/pybamm/dispatch/__init__.py new file mode 100644 index 0000000000..4890c0cf60 --- /dev/null +++ b/src/pybamm/dispatch/__init__.py @@ -0,0 +1,7 @@ +from .entry_points import parameter_sets, models, Model + +__all__ = [ + "parameter_sets", + "models", + "Model", +] diff --git a/src/pybamm/dispatch/entry_points.py b/src/pybamm/dispatch/entry_points.py new file mode 100644 index 0000000000..a843d3e896 --- /dev/null +++ b/src/pybamm/dispatch/entry_points.py @@ -0,0 +1,143 @@ +import sys +import warnings +import importlib.metadata +import textwrap +from collections.abc import Mapping +from typing import Callable + + +class EntryPoint(Mapping): + """ + Access via :py:data:`pybamm.parameter_sets` for parameter_sets + Access via :py:data:`pybamm.Model` for Models + + Examples + -------- + Listing available parameter sets: + >>> import pybamm + >>> list(pybamm.parameter_sets) + ['Ai2020', 'Chayambuka2022', ...] + >>> list(pybamm.dispatch.models) + ['SPM'] + + Get the docstring for a parameter set/model: + + + >>> print(pybamm.parameter_sets.get_docstring("Ai2020")) + + Parameters for the Enertech cell (Ai2020), from the papers :footcite:t:`Ai2019`, + :footcite:t:`Rieger2016` and references therein. + ... + + See also: :ref:`adding-parameter-sets` + + >>> print(pybamm.dispatch.models.get_docstring("SPM")) + + Single Particle Model (SPM) of a lithium-ion battery, from + :footcite:t:`Marquis2019`. + See :class:`pybamm.lithium_ion.BaseModel` for more details. + ... + """ + + _instances = 0 + + def __init__(self, group): + """Dict of entry points for parameter sets or models, lazily load entry points as""" + if not hasattr( + self, "initialized" + ): # Ensure __init__ is called once per instance + self.initialized = True + EntryPoint._instances += 1 + self._all_entries = dict() + self.group = group + for entry_point in self.get_entries(self.group): + self._all_entries[entry_point.name] = entry_point + + @staticmethod + def get_entries(group_name): + """Wrapper for the importlib version logic""" + if sys.version_info < (3, 10): # pragma: no cover + return importlib.metadata.entry_points()[group_name] + else: + return importlib.metadata.entry_points(group=group_name) + + def __new__(cls, group): + """Ensure only two instances of entry points exist, one for parameter sets and the other for models""" + if EntryPoint._instances < 2: + cls.instance = super().__new__(cls) + return cls.instance + + def __getitem__(self, key) -> dict: + return self._load_entry_point(key)() + + def _load_entry_point(self, key) -> Callable: + """Check that ``key`` is a registered ``parameter_sets`` or ``models` , + and return the entry point for the parameter set/model, loading it needed.""" + if key not in self._all_entries: + raise KeyError(f"Unknown parameter set or model: {key}") + ps = self._all_entries[key] + try: + ps = self._all_entries[key] = ps.load() + except AttributeError: + pass + return ps + + def __iter__(self): + return self._all_entries.__iter__() + + def __len__(self) -> int: + return len(self._all_entries) + + def get_docstring(self, key): + """Return the docstring for the ``key`` parameter set or model""" + return textwrap.dedent(self._load_entry_point(key).__doc__) + + def __getattribute__(self, name): + try: + # For backwards compatibility, parameter sets that used to be defined in + # this file now return the name as a string, which will load the same + # parameter set as before when passed to `ParameterValues` + # Bypass the overloaded __getitem__ and __iter__ to avoid recursion + _all_entries = super().__getattribute__("_all_entries") + if name in _all_entries: + msg = ( + f"Parameter sets should be called directly by their name ({name}), " + f"instead of via pybamm.parameter_sets (pybamm.parameter_sets.{name})." + ) + warnings.warn(msg, DeprecationWarning, stacklevel=2) + return name + except AttributeError: + pass # Handle the attribute error normally + + return super().__getattribute__(name) + + +#: Singleton Instance of :class:EntryPoint initialised with pybamm_parameter_sets""" +parameter_sets = EntryPoint(group="pybamm_parameter_sets") + +#: Singleton Instance of :class:EntryPoint initialised with pybamm_models""" +models = EntryPoint(group="pybamm_models") + + +def Model(model: str): # doctest: +SKIP + """ + Returns the loaded model object + + Parameters + ---------- + model : str + The model name or author name of the model mentioned at the model entry point. + Returns + ------- + pybamm.model + Model object of the initialised model. + Examples + -------- + Listing available models: + >>> import pybamm + >>> list(pybamm.dispatch.models) + ['SPM'] + >>> pybamm.Model('SPM') # doctest: +SKIP + + """ + return models[model] diff --git a/src/pybamm/parameters/parameter_sets.py b/src/pybamm/parameters/parameter_sets.py deleted file mode 100644 index 22b476f4e0..0000000000 --- a/src/pybamm/parameters/parameter_sets.py +++ /dev/null @@ -1,100 +0,0 @@ -import sys -import warnings -import importlib.metadata -import textwrap -from collections.abc import Mapping -from typing import Callable - - -class ParameterSets(Mapping): - """ - Dict-like interface for accessing registered pybamm parameter sets. - Access via :py:data:`pybamm.parameter_sets` - - Examples - -------- - Listing available parameter sets: - - - >>> import pybamm - >>> list(pybamm.parameter_sets) - ['Ai2020', 'Chayambuka2022', ...] - - Get the docstring for a parameter set: - - - >>> print(pybamm.parameter_sets.get_docstring("Ai2020")) - - Parameters for the Enertech cell (Ai2020), from the papers :footcite:t:`Ai2019`, - :footcite:t:`Rieger2016` and references therein. - ... - - See also: :ref:`adding-parameter-sets` - - """ - - def __init__(self): - # Dict of entry points for parameter sets, lazily load entry points as - self.__all_parameter_sets = dict() - for entry_point in self.get_entries("pybamm_parameter_sets"): - self.__all_parameter_sets[entry_point.name] = entry_point - - @staticmethod - def get_entries(group_name): - # Wrapper for the importlib version logic - if sys.version_info < (3, 10): # pragma: no cover - return importlib.metadata.entry_points()[group_name] - else: - return importlib.metadata.entry_points(group=group_name) - - def __new__(cls): - """Ensure only one instance of ParameterSets exists""" - if not hasattr(cls, "instance"): - cls.instance = super().__new__(cls) - return cls.instance - - def __getitem__(self, key) -> dict: - return self.__load_entry_point__(key)() - - def __load_entry_point__(self, key) -> Callable: - """Check that ``key`` is a registered ``pybamm_parameter_sets``, - and return the entry point for the parameter set, loading it needed. - """ - if key not in self.__all_parameter_sets: - raise KeyError(f"Unknown parameter set: {key}") - ps = self.__all_parameter_sets[key] - try: - ps = self.__all_parameter_sets[key] = ps.load() - except AttributeError: - pass - return ps - - def __iter__(self): - return self.__all_parameter_sets.__iter__() - - def __len__(self) -> int: - return len(self.__all_parameter_sets) - - def get_docstring(self, key): - """Return the docstring for the ``key`` parameter set""" - return textwrap.dedent(self.__load_entry_point__(key).__doc__) - - def __getattribute__(self, name): - try: - return super().__getattribute__(name) - except AttributeError as error: - # For backwards compatibility, parameter sets that used to be defined in - # this file now return the name as a string, which will load the same - # parameter set as before when passed to `ParameterValues` - if name in self: - msg = ( - f"Parameter sets should be called directly by their name ({name}), " - f"instead of via pybamm.parameter_sets (pybamm.parameter_sets.{name})." - ) - warnings.warn(msg, DeprecationWarning, stacklevel=2) - return name - raise error - - -#: Singleton Instance of :class:ParameterSets """ -parameter_sets = ParameterSets() diff --git a/tests/unit/test_parameters/test_parameter_sets_class.py b/tests/unit/test_parameters/test_entry_points.py similarity index 61% rename from tests/unit/test_parameters/test_parameter_sets_class.py rename to tests/unit/test_parameters/test_entry_points.py index 342cf127aa..f79aa39d7b 100644 --- a/tests/unit/test_parameters/test_parameter_sets_class.py +++ b/tests/unit/test_parameters/test_entry_points.py @@ -31,9 +31,35 @@ def test_all_registered(self): def test_get_docstring(self): """Test that :meth:`pybamm.parameter_sets.get_doctstring` works""" docstring = pybamm.parameter_sets.get_docstring("Marquis2019") + print(docstring) assert re.search("Parameters for a Kokam SLPB78205130H cell", docstring) def test_iter(self): """Test that iterating `pybamm.parameter_sets` iterates over keys""" for k in pybamm.parameter_sets: assert isinstance(k, str) + + +class TestModelEntryPoints: + def test_all_registered(self): + """Check that all models have been registered with the + ``pybamm_models`` entry point""" + known_entry_points = set( + ep.name for ep in pybamm.dispatch.models.get_entries("pybamm_models") + ) + assert set(pybamm.dispatch.models.keys()) == known_entry_points + assert len(known_entry_points) == len(pybamm.dispatch.models) + + def test_get_docstring(self): + """Test that :meth:`pybamm.dispatch.models.get_doctstring` works""" + docstring = pybamm.dispatch.models.get_docstring("SPM") + print(docstring) + assert re.search( + "Single Particle Model", + docstring, + ) + + def test_iter(self): + """Test that iterating `pybamm.models` iterates over keys""" + for k in pybamm.dispatch.models: + assert isinstance(k, str)