Merge branch 'ACEsuit:develop' into develop
Hongyu-yu authored Oct 29, 2024
2 parents 1c5b0fd + ae84fa2 commit 1a98dce
Showing 4 changed files with 65 additions and 40 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -315,7 +315,7 @@ We are happy to accept pull requests under an [MIT license](https://choosealicen

If you use this code, please cite our papers:

-```text
+```bibtex
@inproceedings{Batatia2022mace,
title={{MACE}: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields},
author={Ilyes Batatia and David Peter Kovacs and Gregor N. C. Simm and Christoph Ortner and Gabor Csanyi},
71 changes: 42 additions & 29 deletions mace/calculators/foundations_models.py
@@ -101,8 +101,15 @@ def mace_mp(
MACECalculator: trained on the MPtrj dataset (unless model otherwise specified).
"""
try:
-        model_path = download_mace_mp_checkpoint(model)
-        print(f"Using Materials Project MACE for MACECalculator with {model_path}")
+        if model in (None, "small", "medium", "large") or str(model).startswith(
+            "https:"
+        ):
+            model_path = download_mace_mp_checkpoint(model)
+            print(f"Using Materials Project MACE for MACECalculator with {model_path}")
+        else:
+            if not Path(model).exists():
+                raise FileNotFoundError(f"{model} not found locally")
+            model_path = model
except Exception as exc:
raise RuntimeError("Model download failed and no local model found") from exc
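
With this change, `mace_mp` accepts a plain local checkpoint path in addition to the named sizes ("small", "medium", "large") and download URLs. A minimal usage sketch; the local path below is a placeholder, not a file shipped with the package:

```python
from mace.calculators import mace_mp

# Named size: resolved and fetched via download_mace_mp_checkpoint, then cached.
calc = mace_mp(model="medium", device="cpu", default_dtype="float64")

# Local checkpoint (placeholder path): used directly; a missing file raises
# FileNotFoundError, re-raised as RuntimeError("Model download failed and no local model found").
calc_local = mace_mp(model="/path/to/my_mace_mp.model", device="cpu", default_dtype="float64")
```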

@@ -173,36 +180,42 @@ def mace_off(
MACECalculator: trained on the MACE-OFF23 dataset
"""
try:
-        urls = dict(
-            small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true",
-            medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
-            large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true",
-        )
-        checkpoint_url = (
-            urls.get(model, urls["medium"])
-            if model in (None, "small", "medium", "large")
-            else model
-        )
-        cache_dir = os.path.expanduser("~/.cache/mace")
-        checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
-        cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
-        if not os.path.isfile(cached_model_path):
-            os.makedirs(cache_dir, exist_ok=True)
-            # download and save to disk
-            print(f"Downloading MACE model from {checkpoint_url!r}")
-            print(
-                "The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license."
-            )
-            print(
-                "ASL is based on the Gnu Public License, but does not permit commercial use"
-            )
-            urllib.request.urlretrieve(checkpoint_url, cached_model_path)
-            print(f"Cached MACE model to {cached_model_path}")
-        model = cached_model_path
-        msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}"
-        print(msg)
+        if model in (None, "small", "medium", "large") or str(model).startswith(
+            "https:"
+        ):
+            urls = dict(
+                small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true",
+                medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
+                large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true",
+            )
+            checkpoint_url = (
+                urls.get(model, urls["medium"])
+                if model in (None, "small", "medium", "large")
+                else model
+            )
+            cache_dir = os.path.expanduser("~/.cache/mace")
+            checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
+            cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
+            if not os.path.isfile(cached_model_path):
+                os.makedirs(cache_dir, exist_ok=True)
+                # download and save to disk
+                print(f"Downloading MACE model from {checkpoint_url!r}")
+                print(
+                    "The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license."
+                )
+                print(
+                    "ASL is based on the Gnu Public License, but does not permit commercial use"
+                )
+                urllib.request.urlretrieve(checkpoint_url, cached_model_path)
+                print(f"Cached MACE model to {cached_model_path}")
+            model = cached_model_path
+            msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}"
+            print(msg)
+        else:
+            if not Path(model).exists():
+                raise FileNotFoundError(f"{model} not found locally")
     except Exception as exc:
-        raise RuntimeError("Model download failed") from exc
+        raise RuntimeError("Model download failed and no local model found") from exc

device = device or ("cuda" if torch.cuda.is_available() else "cpu")

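`mace_off` follows the same branching: named sizes and `https:` URLs go through the `~/.cache/mace` download path (printing the ASL license notice on first download), while anything else is treated as a local file that must already exist. A hedged sketch, again with a placeholder path:

```python
from mace.calculators import mace_off

# Named size: downloaded once, then reused from ~/.cache/mace on later calls.
calc = mace_off(model="medium", device="cpu", default_dtype="float64")

# Local checkpoint (placeholder path): no download is attempted; a missing file
# surfaces as RuntimeError("Model download failed and no local model found").
calc_local = mace_off(model="/path/to/MACE-OFF23_finetuned.model", device="cpu")
```
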
20 changes: 10 additions & 10 deletions mace/tools/train.py
@@ -60,7 +60,7 @@ def valid_err_log(
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A"
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A"
)
elif (
log_errors == "PerAtomRMSEstressvirials"
@@ -70,7 +70,7 @@
error_f = eval_metrics["rmse_f"] * 1e3
error_stress = eval_metrics["rmse_stress"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_stress={error_stress:8.1f} meV / A^3",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_stress={error_stress:8.2f} meV / A^3",
)
elif (
log_errors == "PerAtomRMSEstressvirials"
@@ -80,7 +80,7 @@
error_f = eval_metrics["rmse_f"] * 1e3
error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_virials_per_atom={error_virials:8.1f} meV",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_virials_per_atom={error_virials:8.2f} meV",
)
elif (
log_errors == "PerAtomMAEstressvirials"
@@ -90,7 +90,7 @@
error_f = eval_metrics["mae_f"] * 1e3
error_stress = eval_metrics["mae_stress"] * 1e3
logging.info(
f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_stress={error_stress:8.1f} meV / A^3"
f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_stress={error_stress:8.2f} meV / A^3"
)
elif (
log_errors == "PerAtomMAEstressvirials"
@@ -100,37 +100,37 @@
error_f = eval_metrics["mae_f"] * 1e3
error_virials = eval_metrics["mae_virials"] * 1e3
logging.info(
f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_virials={error_virials:8.1f} meV"
f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_virials={error_virials:8.2f} meV"
)
elif log_errors == "TotalRMSE":
error_e = eval_metrics["rmse_e"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A",
)
elif log_errors == "PerAtomMAE":
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A",
)
elif log_errors == "TotalMAE":
error_e = eval_metrics["mae_e"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, MAE_E={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A",
)
elif log_errors == "DipoleRMSE":
error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye",
)
elif log_errors == "EnergyDipoleRMSE":
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye",
)


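The train.py changes only touch logging precision: the validation loss is now printed with eight decimal places (`:8.8f` instead of `:8.4f`) and the per-quantity errors with two (`:8.2f` instead of `:8.1f`). A small illustration with made-up numbers:

```python
loss, rmse_e = 0.00012345678, 12.3456  # illustrative values only

print(f"loss={loss:8.4f}")  # 'loss=  0.0001'   -> small losses lose their digits
print(f"loss={loss:8.8f}")  # 'loss=0.00012346' -> eight decimals keep them visible
print(f"RMSE_E_per_atom={rmse_e:8.1f} meV")  # 'RMSE_E_per_atom=    12.3 meV'
print(f"RMSE_E_per_atom={rmse_e:8.2f} meV")  # 'RMSE_E_per_atom=   12.35 meV'
```
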
12 changes: 12 additions & 0 deletions tests/test_foundations.py
@@ -1,3 +1,5 @@
+from pathlib import Path
+
import numpy as np
import pytest
import torch
@@ -13,6 +15,14 @@
from mace.tools.scripts_utils import extract_config_mace_model
from mace.tools.utils import AtomicNumberTable

+MODEL_PATH = (
+    Path(__file__).parent.parent
+    / "mace"
+    / "calculators"
+    / "foundations_models"
+    / "2023-12-03-mace-mp.model"
+)
+
torch.set_default_dtype(torch.float64)
config = data.Configuration(
atomic_numbers=molecule("H2COH").numbers,
@@ -172,9 +182,11 @@ def test_multi_reference():
mace_mp(model="small", device="cpu", default_dtype="float64").models[0],
mace_mp(model="medium", device="cpu", default_dtype="float64").models[0],
mace_mp(model="large", device="cpu", default_dtype="float64").models[0],
+mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0],
mace_off(model="small", device="cpu", default_dtype="float64").models[0],
mace_off(model="medium", device="cpu", default_dtype="float64").models[0],
mace_off(model="large", device="cpu", default_dtype="float64").models[0],
+mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0],
],
)
def test_extract_config(model):
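
The new `MODEL_PATH` parametrization points at a checkpoint expected under mace/calculators/foundations_models and exercises the local-path branch added above for both `mace_mp` and `mace_off`. One way to run just these cases, using standard pytest selection (shown via `pytest.main` to stay in Python):

```python
import pytest

# Select only the config-extraction tests, which include the new local-path cases.
pytest.main(["tests/test_foundations.py", "-k", "extract_config"])
```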
