diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index d627e859..eaf8cdbd 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -265,7 +265,7 @@ def run(args: argparse.Namespace) -> None: args.loss = "universal" if ( args.foundation_model in ["small", "medium", "large"] - or args.pt_train_file is None + or args.pt_train_file == "mp" ): logging.info( "Using foundation model for multiheads finetuning with Materials Project data"