diff --git a/janus_core/calculations/single_point.py b/janus_core/calculations/single_point.py index e09b50c6..b43e43c3 100644 --- a/janus_core/calculations/single_point.py +++ b/janus_core/calculations/single_point.py @@ -216,10 +216,18 @@ def _get_potential_energy(self) -> MaybeList[float]: MaybeList[float] Potential energy of structure(s). """ + tag = f"{self.architecture}_energy" if isinstance(self.struct, list): - return [struct.get_potential_energy() for struct in self.struct] + energies = [struct.get_potential_energy() for struct in self.struct] + for energy, struct in zip(energies, self.struct): + struct.info[tag] = energy + struct.calc.results = {} + return energies - return self.struct.get_potential_energy() + energy = self.struct.get_potential_energy() + self.struct.info[tag] = energy + self.struct.calc.results = {} + return energy def _get_forces(self) -> MaybeList[ndarray]: """ @@ -230,10 +238,18 @@ def _get_forces(self) -> MaybeList[ndarray]: MaybeList[ndarray] Forces of structure(s). """ + tag = f"{self.architecture}_forces" if isinstance(self.struct, list): - return [struct.get_forces() for struct in self.struct] + forces = [struct.get_forces() for struct in self.struct] + for force, struct in zip(forces, self.struct): + struct.arrays[tag] = force + struct.calc.results = {} + return forces - return self.struct.get_forces() + force = self.struct.get_forces() + self.struct.arrays[tag] = force + self.struct.calc.results = {} + return force def _get_stress(self) -> MaybeList[ndarray]: """ @@ -244,10 +260,18 @@ def _get_stress(self) -> MaybeList[ndarray]: MaybeList[ndarray] Stress of structure(s). """ + tag = f"{self.architecture}_stress" if isinstance(self.struct, list): - return [struct.get_stress() for struct in self.struct] + stresses = [struct.get_stress() for struct in self.struct] + for stress, struct in zip(stresses, self.struct): + struct.info[tag] = stress + struct.calc.results = {} + return stresses - return self.struct.get_stress() + stress = self.struct.get_stress() + self.struct.info[tag] = stress + self.struct.calc.results = {} + return stress @staticmethod def _remove_invalid_props( @@ -272,8 +296,8 @@ def _remove_invalid_props( # Find any properties with non-finite values rm_keys = [ prop - for prop in struct.calc.results - if not isfinite(struct.calc.results[prop]).all() + for prop in properties + if (prop in struct.info) and (not isfinite(struct.info[prop]).all()) ] # Raise error if property was explicitly requested, otherwise remove @@ -282,7 +306,7 @@ def _remove_invalid_props( raise ValueError( f"'{prop}' contains non-finite values for this structure." ) - del struct.calc.results[prop] + del struct.info[prop] if prop in results: del results[prop] @@ -360,7 +384,9 @@ def run( results["stress"] = self._get_stress() # Remove meaningless values from results e.g. stress for non-periodic systems - self._clean_results(results, properties=properties) + self._clean_results( + results, properties=[f"{self.architecture}_{prop}" for prop in properties] + ) if self.logger: self.logger.info("Single point calculation complete") diff --git a/tests/test_single_point.py b/tests/test_single_point.py index 366fc066..4458fd7a 100644 --- a/tests/test_single_point.py +++ b/tests/test_single_point.py @@ -50,6 +50,8 @@ def test_potential_energy( raise ValueError(f"Invalid index: {idx}") else: assert results == pytest.approx(expected) + results_2 = single_point.struct.info["mace_energy"] + assert results_2 == pytest.approx(expected) def test_single_point_none(): @@ -76,7 +78,7 @@ def test_single_point_clean(): results = single_point.run() for prop in ["energy", "forces"]: assert prop in results - assert "stress" not in results + assert "mace_stress" not in results def test_single_point_traj(): @@ -91,6 +93,12 @@ def test_single_point_traj(): results = single_point.run("energy") assert results["energy"][0] == pytest.approx(-76.0605725422795) assert results["energy"][1] == pytest.approx(-74.80419118083256) + assert single_point.struct[0].info["mace_energy"] == pytest.approx( + -76.0605725422795 + ) + assert single_point.struct[1].info["mace_energy"] == pytest.approx( + -74.80419118083256 + ) def test_single_point_write(): @@ -104,13 +112,26 @@ def test_single_point_write(): architecture="mace", calc_kwargs={"model": MODEL_PATH}, ) - assert "forces" not in single_point.struct.calc.results + assert "mace_forces" not in single_point.struct.arrays single_point.run(write_results=True) - atoms = read_atoms(results_path) - assert atoms.get_potential_energy() is not None - assert "forces" in atoms.calc.results + assert "mace_forces" in atoms.arrays + assert atoms.info["mace_energy"] == pytest.approx(-27.035127799332745) + assert "mace_stress" in atoms.info + assert atoms.info["mace_stress"] == pytest.approx( + [ + -0.004783275999053391, + -0.004783275999053417, + -0.004783275999053412, + -2.3858882876234007e-19, + -5.02032761017409e-19, + -2.29070171362209e-19, + ] + ) + assert atoms.arrays["mace_forces"][0] == pytest.approx( + [4.11996826e-18, 1.79977561e-17, 1.80139537e-17] + ) def test_single_point_write_kwargs(tmp_path): @@ -123,12 +144,12 @@ def test_single_point_write_kwargs(tmp_path): architecture="mace", calc_kwargs={"model": MODEL_PATH}, ) - assert "forces" not in single_point.struct.calc.results + assert "mace_forces" not in single_point.struct.arrays single_point.run(write_results=True, write_kwargs={"filename": results_path}) atoms = read(results_path) - assert atoms.get_potential_energy() is not None - assert "forces" in atoms.calc.results + assert atoms.info["mace_energy"] is not None + assert "mace_forces" in atoms.arrays def test_single_point_write_nan(tmp_path): @@ -147,9 +168,9 @@ def test_single_point_write_nan(tmp_path): single_point.run(write_results=True, write_kwargs={"filename": results_path}) atoms = read(results_path) - assert atoms.get_potential_energy() is not None - assert "forces" in atoms.calc.results - assert "stress" not in atoms.calc.results + assert atoms.info["mace_energy"] is not None + assert "mace_forces" in atoms.arrays + assert "mace_stress" not in atoms.info def test_invalid_prop(): diff --git a/tests/test_singlepoint_cli.py b/tests/test_singlepoint_cli.py index 05bd105d..4003f82a 100644 --- a/tests/test_singlepoint_cli.py +++ b/tests/test_singlepoint_cli.py @@ -47,8 +47,8 @@ def test_singlepoint(tmp_path): # Check atoms can read read, then delete file atoms = read_atoms(results_path) assert result.exit_code == 0 - assert atoms.get_potential_energy() is not None - assert "forces" in atoms.calc.results + assert "mace_mp_energy" in atoms.info + assert "mace_mp_forces" in atoms.arrays def test_properties(tmp_path): @@ -78,7 +78,7 @@ def test_properties(tmp_path): assert result.exit_code == 0 atoms = read(results_path_1) - assert atoms.get_potential_energy() is not None + assert "mace_mp_energy" in atoms.info result = runner.invoke( app,