From 5016abfe4458192675f0983e7f5b52332de21fbe Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 2 May 2024 17:45:57 +0100 Subject: [PATCH] fix urls and extract E0s --- mace/cli/run_train.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 6b4d5290..cb098eb6 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -43,6 +43,7 @@ load_foundations_elements, extract_config_mace_model, ) +from mace.tools.utils import AtomicNumberTable def main() -> None: @@ -187,8 +188,8 @@ def main() -> None: args.heads = heads logging.info(f"Using heads: {heads}") try: - checkpoint_url = "https://tinyurl.com/mw2wetc5" - descriptors_url = "https://tinyurl.com/mpe7br4d" + checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" + descriptors_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/descriptors.npy" cache_dir = os.path.expanduser("~/.cache/mace") checkpoint_url_name = "".join( c @@ -309,9 +310,21 @@ def main() -> None: else: atomic_energies_dict = get_atomic_energies(args.E0s, None, z_table, heads) if args.multiheads_finetuning: - with open(r"mace\calculators\foundations_models\mp_vasp_e0.json", "r") as file: - E0s_mp = json.load(file) - atomic_energies_dict["pbe_mp"] = {z: E0s_mp["pbe"][f"{z}"] for z in z_table.zs} + assert ( + model_foundation is not None + ), "Model foundation must be provided for multiheads finetuning" + logging.info( + "Using atomic energies from foundation model for multiheads finetuning" + ) + z_table_foundation = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + atomic_energies_dict["pbe_mp"] = { + z: model_foundation.atomic_energies_fn.atomic_energies[ + z_table_foundation.z_to_index(z) + ].item() + for z in z_table.zs + } if args.model == "AtomicDipolesMACE": atomic_energies = None