diff --git a/mace/tools/model_script_utils.py b/mace/tools/model_script_utils.py index ad54e2f0..d937446c 100644 --- a/mace/tools/model_script_utils.py +++ b/mace/tools/model_script_utils.py @@ -146,6 +146,11 @@ def _build_model( args, model_config, model_config_foundation, heads ): # pylint: disable=too-many-return-statements if args.model == "MACE": + if args.interaction_first not in [ + "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + ]: + args.interaction_first = "RealAgnosticInteractionBlock" return modules.ScaleShiftMACE( **model_config, pair_repulsion=args.pair_repulsion,