diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index 51f27ab0..3ebddbed 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -42,7 +42,7 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str: checkpoint_url = ( urls.get(model, urls["medium"]) - if model in (None, "small", "medium", "large") + if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") else model ) @@ -106,7 +106,7 @@ def mace_mp( MACECalculator: trained on the MPtrj dataset (unless model otherwise specified). """ try: - if model in (None, "small", "medium", "large") or str(model).startswith( + if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") or str(model).startswith( "https:" ): model_path = download_mace_mp_checkpoint(model)