Skip to content

Commit

Permalink
Fix setting devices (stfc#208)
Browse files Browse the repository at this point in the history
* Fix setting device for MACE-MP and MACE-OFF

* Update list of valid PyTorch devices

* Add test for invalid device

* Fix device for CHGNet tests
  • Loading branch information
ElliottKasoar authored Jul 12, 2024
1 parent 3b33a77 commit c1653dd
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 6 deletions.
2 changes: 1 addition & 1 deletion janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class PostProcessKwargs(TypedDict, total=False):

# Janus specific
Architectures = Literal["mace", "mace_mp", "mace_off", "m3gnet", "chgnet"]
Devices = Literal["cpu", "cuda", "mps"]
Devices = Literal["cpu", "cuda", "mps", "xpu"]
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh"]


Expand Down
9 changes: 7 additions & 2 deletions janus_core/helpers/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
- https://github.com/Quantum-Accelerators/quacc.git
"""

from typing import get_args

from ase.calculators.calculator import Calculator
import torch

Expand Down Expand Up @@ -47,6 +49,9 @@ def choose_calculator(
if "model" in kwargs and "model_paths" in kwargs:
raise ValueError("Please specify either `model` or `model_paths`")

if not device in get_args(Devices):
raise ValueError(f"`device` must be one of: {get_args(Devices)}")

if architecture == "mace":
from mace import __version__
from mace.calculators import MACECalculator
Expand All @@ -66,7 +71,7 @@ def choose_calculator(
# Otherwise, take `model_paths` if specified, then default to "small"
kwargs.setdefault("model", kwargs.pop("model_paths", "small"))
kwargs.setdefault("default_dtype", "float64")
calculator = mace_mp(**kwargs)
calculator = mace_mp(device=device, **kwargs)

elif architecture == "mace_off":
from mace import __version__
Expand All @@ -76,7 +81,7 @@ def choose_calculator(
# Otherwise, take `model_paths` if specified, then default to "small"
kwargs.setdefault("model", kwargs.pop("model_paths", "small"))
kwargs.setdefault("default_dtype", "float64")
calculator = mace_off(**kwargs)
calculator = mace_off(device=device, **kwargs)

elif architecture == "m3gnet":
from matgl import __version__, load_model
Expand Down
2 changes: 1 addition & 1 deletion tests/test_eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_no_optimize(tmp_path):
)


test_data_potentials = [("m3gnet", "cpu"), ("chgnet", "")]
test_data_potentials = [("m3gnet", "cpu"), ("chgnet", "cpu")]


@pytest.mark.parametrize("arch, device", test_data_potentials)
Expand Down
9 changes: 8 additions & 1 deletion tests/test_mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
("mace_off", "cpu", {"model": "small"}),
]

test_data_extras = [("m3gnet", "cpu"), ("chgnet", "")]
test_data_extras = [("m3gnet", "cpu"), ("chgnet", "cpu")]


@pytest.mark.parametrize("architecture, device, kwargs", test_data_mace)
Expand Down Expand Up @@ -51,3 +51,10 @@ def test_model_model_paths():
model=MODEL_PATH,
model_paths=MODEL_PATH,
)


@pytest.mark.parametrize("architecture", ["mace_mp", "mace_off"])
def test_invalid_device(architecture):
"""Test error raised for invalid device is specified."""
with pytest.raises(ValueError):
choose_calculator(architecture=architecture, device="invalid")
2 changes: 1 addition & 1 deletion tests/test_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def test_no_atoms_or_path():

test_data_potentials = [
("m3gnet", "cpu", -26.729949951171875),
("chgnet", "", -29.331436157226562),
("chgnet", "cpu", -29.331436157226562),
]


Expand Down

0 comments on commit c1653dd

Please sign in to comment.