From 6288b026e4335208bf35ca72707393e2936ad577 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 29 Aug 2024 14:40:28 +0100 Subject: [PATCH] fix model_config lit --- mace/cli/run_train.py | 12 +++--------- mace/tools/scripts_utils.py | 3 +-- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 17b39667..1db77a06 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -350,10 +350,7 @@ def run(args: argparse.Namespace) -> None: z_table_foundation.z_to_index(z) ].item() for z in z_table.zs - } - logging.info( - f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}" - ) + } else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: @@ -373,11 +370,7 @@ def run(args: argparse.Namespace) -> None: ].item() for z in z_table.zs } - atomic_energies_dict_pt = atomic_energies_dict["pt_head"] - logging.info( - f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict_pt[z]}' for z in z_table_foundation.zs])}" - ) - + if args.model == "AtomicDipolesMACE": atomic_energies = None dipole_only = True @@ -560,6 +553,7 @@ def run(args: argparse.Namespace) -> None: args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"] args.model = "FoundationMACE" model_config_foundation["heads"] = heads + model_config = model_config_foundation logging.info("Model configuration extracted from foundation model") logging.info("Using universal loss function for fine-tuning") logging.info( diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 025b3453..f44390a6 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -19,8 +19,7 @@ from prettytable import PrettyTable from torch.optim.swa_utils import SWALR, AveragedModel -from mace import data, modules -from mace import tools +from mace import data, modules, tools from mace.tools import evaluate from mace.tools.train import SWAContainer