Skip to content

Commit

Permalink
add sevennet support
Browse files Browse the repository at this point in the history
  • Loading branch information
alinelena committed Jul 30, 2024
1 parent f2cd988 commit 17d4180
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 35 deletions.
21 changes: 19 additions & 2 deletions docs/source/developer_guide/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ Converting ``model_path`` into ``path`` is a minimum requirement, but we also ai
- If ``model_path`` is ``None``, we use the ALIGNN's ``default_path``

.. note::
``model_path`` will already be a ``pathlib.Path`` object, if the path exists.
``model_path`` will already be a ``pathlib.Path`` object, if the path exists. you may want to cast it back to a string
```str(model_path)``` if a string is needed

To ensure that the calculator does not receive multiple versions of keywords, it's also necessary to set ``model_path = path``, and remove ``path`` from ``kwargs``.

Expand All @@ -104,7 +105,12 @@ In addition to setting the calculator, ``__version__`` must also imported here,

Tests must be added to ensure that, at a minimum, the new calculator allows an MLIP to be loaded correctly, and that an energy can be calculated.

This can be done by adding the appropriate data as tuples to the ``pytest.mark.parametrize`` lists in the ``tests.test_mlip_calculators`` and ``tests.test_single_point`` modules.
This can be done by adding the appropriate data as tuples to the ``pytest.mark.parametrize`` lists in the ``tests.test_mlip_calculators`` and ``tests.test_single_point`` modules
that reside in files ``tests/test_mlip_calculators.py``` and ``tests/test_single_point.py``, respectively.


Load models - success
^^^^^^^^^^^^^^^^^^^^^

For ``tests.test_mlip_calculators``, ``architecture``, ``device`` and accepted forms of ``model_path`` should be tested, ensuring that the calculator and its version are correctly set::

Expand All @@ -121,28 +127,39 @@ For ``tests.test_mlip_calculators``, ``architecture``, ``device`` and accepted f
)
def test_extra_mlips(architecture, device, kwargs):

not all models may support empty paths, so for some you may want to remove the ``("alignn", "cpu", {})`` test.

Load models - failure
^^^^^^^^^^^^^^^^^^^^^

It is also useful to test that ``model_path``, and ``model`` or and the "standard" MLIP calculator parameter (``path``) cannot be defined simultaneously::

@pytest.mark.extra_mlips
@pytest.mark.parametrize(
"kwargs",
[
{
"architecture": "alignn",
"model_path": "tests/models/v5.27.2024/best_model.pt",
"model": "tests/models/v5.27.2024/best_model.pt",
},
{
"architecture": "alignn",
"model_path": "tests/models/v5.27.2024/best_model.pt",
"path": "tests/models/v5.27.2024/best_model.pt",
},
],
)
def test_extra_mlips_invalid(kwargs):

Test correctness
^^^^^^^^^^^^^^^^

For ``tests.test_single_point``, ``architecture``, ``device``, and the potential energy of NaCl predicted by the MLIP should be defined, ensuring that calculations can be performed::

test_extra_mlips_data = [("alignn", "cpu", -11.148092269897461)]


Running these tests requires an additional flag to be passed to ``pytest``::

pytest -v --run-extra-mlips
Expand Down
4 changes: 3 additions & 1 deletion janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ class CorrelationKwargs(TypedDict, total=True):


# Janus specific
Architectures = Literal["mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn"]
Architectures = Literal[
"mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet"
]
Devices = Literal["cpu", "cuda", "mps", "xpu"]
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh"]

Expand Down
26 changes: 25 additions & 1 deletion janus_core/helpers/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@ def choose_calculator(

# No default `model_path`
if model_path is None:
raise ValueError("Please specify `model_path`")
raise ValueError(
"Please specify `model_path` there is no "
f"default model for {architecture}"
)
# Default to float64 precision
kwargs.setdefault("default_dtype", "float64")

Expand Down Expand Up @@ -203,6 +206,27 @@ def choose_calculator(

calculator = AlignnAtomwiseCalculator(path=path, device=device, **kwargs)

elif architecture == "sevennet":
from sevenn.sevennet_calculator import SevenNetCalculator

__version__ = "0.0.0"

if model_path is None or model_path.name == "":
raise ValueError(
"Please specify `model_path` there is no "
f"default model for {architecture}"
)
if isinstance(model_path, Path):
model = str(model_path)
elif isinstance(model_path, str):
model = model_path
else:
model = None

kwargs.setdefault("file_type", "checkpoint")
kwargs.setdefault("sevennet_config", None)
calculator = SevenNetCalculator(model=model, device=device, **kwargs)

else:
raise ValueError(
f"Unrecognized {architecture=}. Suported architectures "
Expand Down
15 changes: 9 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,23 @@ chgnet = "0.3.8"
dgl = "2.1.0"
mace-torch = "0.3.6"
matgl = "1.1.2"
numpy = "^1.26.4"
numpy = "1.26.4"
pyyaml = "^6.0.1"
typer = "^0.9.0"
typer-config = "^1.4.0"
phonopy = "^2.23.1"
seekpath = "^1.9.7"
spglib = "^2.3.0"
torch-dftd = "^0.4.0"
torch-dftd = "0.4.0"
codecarbon = "^2.5.0"
alignn = { version = "2024.5.27", optional = true }
sevenn = { version = "0.9.3", optional = true }
torch_scatter = { version = "^2.1.2", optional = true }
torch_geometric = { version = "^2.5.3", optional = true }

[tool.poetry.group.extra-mlips]
optional = true
[tool.poetry.group.extra-mlips.dependencies]
alignn = "^2024.5.27"
[tool.poetry.extras]
alignnff = ["alignn"]
sevennet = ["sevenn", "torch_scatter", "torch_geometric"]

[tool.poetry.group.dev.dependencies]
coverage = {extras = ["toml"], version = "^7.4.1"}
Expand Down
Binary file added tests/models/sevennet_0.pth
Binary file not shown.
35 changes: 28 additions & 7 deletions tests/test_mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
CHGNET_PATH = MODEL_PATH / "chgnet_0.3.0_e29f68s314m37.pth.tar"
CHGNET_MODEL = CHGNet.from_file(path=CHGNET_PATH)

SEVENNET_PATH = MODEL_PATH / "sevennet_0.pth"


@pytest.mark.parametrize(
"architecture, device, kwargs",
Expand Down Expand Up @@ -103,9 +105,12 @@ def test_invalid_device(architecture):
[
("alignn", "cpu", {}),
("alignn", "cpu", {"model_path": MODEL_PATH / "v5.27.2024"}),
("alignn", "cpu", {"model_path": MODEL_PATH / "v5.27.2024/best_model.pt"}),
("alignn", "cpu", {"model_path": MODEL_PATH / "v5.27.2024" / "best_model.pt"}),
("alignn", "cpu", {"model": "alignnff_wt10"}),
("alignn", "cpu", {"path": MODEL_PATH / "v5.27.2024"}),
("sevennet", "cpu", {"model": SEVENNET_PATH}),
("sevennet", "cpu", {"path": SEVENNET_PATH}),
("sevennet", "cpu", {"model_path": SEVENNET_PATH}),
],
)
def test_extra_mlips(architecture, device, kwargs):
Expand All @@ -123,16 +128,32 @@ def test_extra_mlips(architecture, device, kwargs):
"kwargs",
[
{
"model_path": MODEL_PATH / "v5.27.2024/best_model.pt",
"model": MODEL_PATH / "v5.27.2024/best_model.pt",
"architecture": "alignn",
"model_path": MODEL_PATH / "v5.27.2024" / "best_model.pt",
"model": MODEL_PATH / "v5.27.2024" / "best_model.pt",
},
{
"architecture": "alignn",
"model_path": MODEL_PATH / "v5.27.2024" / "best_model.pt",
"path": MODEL_PATH / "v5.27.2024" / "best_model.pt",
},
{
"architecture": "sevennet",
"model_path": SEVENNET_PATH,
"path": SEVENNET_PATH,
},
{
"architecture": "sevennet",
"model_path": SEVENNET_PATH,
"model": SEVENNET_PATH,
},
{
"model_path": MODEL_PATH / "v5.27.2024/best_model.pt",
"path": MODEL_PATH / "v5.27.2024/best_model.pt",
"architecture": "sevennet",
"model_path": "",
},
],
)
def test_extra_mlips_invalid(kwargs):
def test_alignff_invalid(kwargs):
"""Test error raised if multiple model paths defined for extra MLIPs."""
with pytest.raises(ValueError):
choose_calculator(architecture="alignn", **kwargs)
choose_calculator(**kwargs)
43 changes: 25 additions & 18 deletions tests/test_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from tests.utils import read_atoms

DATA_PATH = Path(__file__).parent / "data"
MODEL_PATH = Path(__file__).parent / "models" / "mace_mp_small.model"
MODEL_PATH = Path(__file__).parent / "models"

MACE_PATH = MODEL_PATH / "mace_mp_small.model"
SEVENNET_PATH = MODEL_PATH / "sevennet_0.pth"

test_data = [
(DATA_PATH / "benzene.xyz", -76.0605725422795, "energy", "energy", {}, None),
Expand All @@ -34,7 +37,7 @@ def test_potential_energy(
struct_path, expected, properties, prop_key, calc_kwargs, idx
):
"""Test single point energy using MACE calculators."""
calc_kwargs["model"] = MODEL_PATH
calc_kwargs["model"] = MACE_PATH
single_point = SinglePoint(
struct_path=struct_path, architecture="mace", calc_kwargs=calc_kwargs
)
Expand All @@ -59,7 +62,7 @@ def test_single_point_none():
single_point = SinglePoint(
struct_path=DATA_PATH / "NaCl.cif",
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)

results = single_point.run()
Expand All @@ -72,7 +75,7 @@ def test_single_point_clean():
single_point = SinglePoint(
struct_path=DATA_PATH / "H2O.cif",
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)

results = single_point.run()
Expand All @@ -86,7 +89,7 @@ def test_single_point_traj():
single_point = SinglePoint(
struct_path=DATA_PATH / "benzene-traj.xyz",
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)

assert len(single_point.struct) == 2
Expand All @@ -110,7 +113,7 @@ def test_single_point_write():
single_point = SinglePoint(
struct_path=data_path,
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
assert "mace_forces" not in single_point.struct.arrays

Expand Down Expand Up @@ -142,7 +145,7 @@ def test_single_point_write_kwargs(tmp_path):
single_point = SinglePoint(
struct_path=data_path,
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
assert "mace_forces" not in single_point.struct.arrays

Expand All @@ -159,7 +162,7 @@ def test_single_point_molecule(tmp_path):
single_point = SinglePoint(
struct_path=data_path,
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)

assert isfinite(single_point.run("energy")["energy"]).all()
Expand All @@ -177,7 +180,7 @@ def test_invalid_prop():
single_point = SinglePoint(
struct_path=DATA_PATH / "H2O.cif",
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
with pytest.raises(NotImplementedError):
single_point.run("invalid")
Expand All @@ -190,7 +193,7 @@ def test_atoms():
struct=struct,
struct_name="NaCl",
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
assert single_point.struct_name == "NaCl"
assert single_point.run("energy")["energy"] < 0
Expand All @@ -202,7 +205,7 @@ def test_default_atoms_name():
single_point = SinglePoint(
struct=struct,
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
assert single_point.struct_name == "Cl4Na4"

Expand All @@ -213,7 +216,7 @@ def test_default_path_name():
single_point = SinglePoint(
struct_path=struct_path,
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
assert single_point.struct_name == "NaCl"

Expand All @@ -225,7 +228,7 @@ def test_path_specify_name():
struct_path=struct_path,
struct_name="example_name",
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
assert single_point.struct_name == "example_name"

Expand All @@ -239,7 +242,7 @@ def test_atoms_and_path():
struct=struct,
struct_path=struct_path,
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)


Expand All @@ -248,7 +251,7 @@ def test_no_atoms_or_path():
with pytest.raises(ValueError):
SinglePoint(
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)


Expand All @@ -270,17 +273,21 @@ def test_mlips(arch, device, expected_energy):
assert energy == pytest.approx(expected_energy)


test_extra_mlips_data = [("alignn", "cpu", -11.148092269897461)]
test_extra_mlips_data = [
("alignn", "cpu", -11.148092269897461, {}),
("sevennet", "cpu", -27.061979293823242, {"model_path": SEVENNET_PATH}),
]


@pytest.mark.extra_mlips
@pytest.mark.parametrize("arch, device, expected_energy", test_extra_mlips_data)
def test_extra_mlips(arch, device, expected_energy):
@pytest.mark.parametrize("arch, device, expected_energy, kwargs", test_extra_mlips_data)
def test_extra_mlips_alignn(arch, device, expected_energy, kwargs):
"""Test single point energy using ALIGNN-FF calculator."""
single_point = SinglePoint(
struct_path=DATA_PATH / "NaCl.cif",
architecture=arch,
device=device,
**kwargs,
)
energy = single_point.run("energy")["energy"]
assert energy == pytest.approx(expected_energy)

0 comments on commit 17d4180

Please sign in to comment.