diff --git a/tests/test_run_train.py b/tests/test_run_train.py index bb599a98..80d7eaa7 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -686,15 +686,19 @@ def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): p = subprocess.run(cmd.split(), env=run_env, check=True) assert p.returncode == 0 - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" - ) - Es = [] for at in fitting_configs: + config_head = at.info.get("head", "MP2") + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head=config_head, + ) at.calc = calc Es.append(at.get_potential_energy()) + print("Es", Es) # from a run on 20/08/2024 on commit ref_Es = [