Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unified entry point for models and parameter sets #4490

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ ECM_Example = "pybamm.input.parameters.ecm.example_set:get_parameter_values"
MSMR_Example = "pybamm.input.parameters.lithium_ion.MSMR_example_set:get_parameter_values"
Chayambuka2022 = "pybamm.input.parameters.sodium_ion.Chayambuka2022:get_parameter_values"

[project.entry-points."pybamm_models"]
SPM = "pybamm.models.full_battery_models.lithium_ion.spm:SPM"

[tool.setuptools]
include-package-data = true

Expand Down
5 changes: 4 additions & 1 deletion src/pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@
from .parameters.lead_acid_parameters import LeadAcidParameters
from .parameters.ecm_parameters import EcmParameters
from .parameters.size_distribution_parameters import *
from .parameters.parameter_sets import parameter_sets

# Mesh and Discretisation classes
from .discretisations.discretisation import Discretisation
Expand Down Expand Up @@ -201,6 +200,9 @@
# Pybamm Data manager using pooch
from .pybamm_data import DataLoader

# Pybamm entry point API for parameter_sets and models
from .dispatch import parameter_sets, Model

# Fix Casadi import
import os
import pathlib
Expand Down Expand Up @@ -233,6 +235,7 @@
"util",
"version",
"pybamm_data",
"dispatch",
]

config.generate()
7 changes: 7 additions & 0 deletions src/pybamm/dispatch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .entry_points import parameter_sets, models, Model

__all__ = [
"parameter_sets",
"models",
"Model",
]
143 changes: 143 additions & 0 deletions src/pybamm/dispatch/entry_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import sys
import warnings
import importlib.metadata
import textwrap
from collections.abc import Mapping
from typing import Callable


class EntryPoint(Mapping):
"""
Access via :py:data:`pybamm.parameter_sets` for parameter_sets
Access via :py:data:`pybamm.Model` for Models

Examples
--------
Listing available parameter sets:
>>> import pybamm
>>> list(pybamm.parameter_sets)
['Ai2020', 'Chayambuka2022', ...]
>>> list(pybamm.dispatch.models)
['SPM']

Get the docstring for a parameter set/model:


>>> print(pybamm.parameter_sets.get_docstring("Ai2020"))
<BLANKLINE>
Parameters for the Enertech cell (Ai2020), from the papers :footcite:t:`Ai2019`,
:footcite:t:`Rieger2016` and references therein.
...

See also: :ref:`adding-parameter-sets`

>>> print(pybamm.dispatch.models.get_docstring("SPM"))
<BLANKLINE>
Single Particle Model (SPM) of a lithium-ion battery, from
:footcite:t:`Marquis2019`.
See :class:`pybamm.lithium_ion.BaseModel` for more details.
...
"""

_instances = 0

def __init__(self, group):
"""Dict of entry points for parameter sets or models, lazily load entry points as"""
if not hasattr(
self, "initialized"
): # Ensure __init__ is called once per instance
self.initialized = True
EntryPoint._instances += 1
self._all_entries = dict()
self.group = group
for entry_point in self.get_entries(self.group):
self._all_entries[entry_point.name] = entry_point

@staticmethod
def get_entries(group_name):
"""Wrapper for the importlib version logic"""
if sys.version_info < (3, 10): # pragma: no cover
return importlib.metadata.entry_points()[group_name]
else:
return importlib.metadata.entry_points(group=group_name)

def __new__(cls, group):
"""Ensure only two instances of entry points exist, one for parameter sets and the other for models"""
if EntryPoint._instances < 2:
cls.instance = super().__new__(cls)
return cls.instance

def __getitem__(self, key) -> dict:
return self._load_entry_point(key)()

def _load_entry_point(self, key) -> Callable:
"""Check that ``key`` is a registered ``parameter_sets`` or ``models` ,
and return the entry point for the parameter set/model, loading it needed."""
if key not in self._all_entries:
raise KeyError(f"Unknown parameter set or model: {key}")
ps = self._all_entries[key]
try:
ps = self._all_entries[key] = ps.load()
except AttributeError:
pass
return ps

def __iter__(self):
return self._all_entries.__iter__()

def __len__(self) -> int:
return len(self._all_entries)

def get_docstring(self, key):
"""Return the docstring for the ``key`` parameter set or model"""
return textwrap.dedent(self._load_entry_point(key).__doc__)

def __getattribute__(self, name):
try:
# For backwards compatibility, parameter sets that used to be defined in
# this file now return the name as a string, which will load the same
# parameter set as before when passed to `ParameterValues`
# Bypass the overloaded __getitem__ and __iter__ to avoid recursion
_all_entries = super().__getattribute__("_all_entries")
if name in _all_entries:
msg = (
f"Parameter sets should be called directly by their name ({name}), "
f"instead of via pybamm.parameter_sets (pybamm.parameter_sets.{name})."
)
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return name
except AttributeError:
pass # Handle the attribute error normally

return super().__getattribute__(name)


#: Singleton Instance of :class:EntryPoint initialised with pybamm_parameter_sets"""
parameter_sets = EntryPoint(group="pybamm_parameter_sets")

#: Singleton Instance of :class:EntryPoint initialised with pybamm_models"""
models = EntryPoint(group="pybamm_models")


def Model(model: str): # doctest: +SKIP
"""
Returns the loaded model object

Parameters
----------
model : str
The model name or author name of the model mentioned at the model entry point.
Returns
-------
pybamm.model
Model object of the initialised model.
Examples
--------
Listing available models:
>>> import pybamm
>>> list(pybamm.dispatch.models)
['SPM']
>>> pybamm.Model('SPM') # doctest: +SKIP
<pybamm.models.full_battery_models.lithium_ion.spm.SPM object>
"""
return models[model]

Check warning on line 143 in src/pybamm/dispatch/entry_points.py

View check run for this annotation

Codecov / codecov/patch

src/pybamm/dispatch/entry_points.py#L143

Added line #L143 was not covered by tests
100 changes: 0 additions & 100 deletions src/pybamm/parameters/parameter_sets.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,35 @@ def test_all_registered(self):
def test_get_docstring(self):
"""Test that :meth:`pybamm.parameter_sets.get_doctstring` works"""
docstring = pybamm.parameter_sets.get_docstring("Marquis2019")
print(docstring)
assert re.search("Parameters for a Kokam SLPB78205130H cell", docstring)

def test_iter(self):
"""Test that iterating `pybamm.parameter_sets` iterates over keys"""
for k in pybamm.parameter_sets:
assert isinstance(k, str)


class TestModelEntryPoints:
def test_all_registered(self):
"""Check that all models have been registered with the
``pybamm_models`` entry point"""
known_entry_points = set(
ep.name for ep in pybamm.dispatch.models.get_entries("pybamm_models")
)
assert set(pybamm.dispatch.models.keys()) == known_entry_points
assert len(known_entry_points) == len(pybamm.dispatch.models)

def test_get_docstring(self):
"""Test that :meth:`pybamm.dispatch.models.get_doctstring` works"""
docstring = pybamm.dispatch.models.get_docstring("SPM")
print(docstring)
assert re.search(
"Single Particle Model",
docstring,
)

def test_iter(self):
"""Test that iterating `pybamm.models` iterates over keys"""
for k in pybamm.dispatch.models:
assert isinstance(k, str)
Loading