diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index eaf8cdbd..ad2b96f5 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -160,6 +160,12 @@ def run(args: argparse.Namespace) -> None: args.E0s != "average" ), "average atomic energies cannot be used for multiheads finetuning" # check that the foundation model has a single head, if not, use the first head + args.lr = 0.001 + args.ema = True + args.ema_decay = 0.999 + logging.info( + "Using multiheads finetuning mode, setting learning rate to 0.001 and EMA to True" + ) if hasattr(model_foundation, "heads"): if len(model_foundation.heads) > 1: logging.warning(