Skip to content

Commit

Permalink
name ML properties in the results file by architecture, energy become…
Browse files Browse the repository at this point in the history
…s mace_energy,...
  • Loading branch information
alinelena committed May 31, 2024
1 parent ee0d76d commit 404fffe
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 16 deletions.
38 changes: 31 additions & 7 deletions janus_core/calculations/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,18 @@ def _get_potential_energy(self) -> MaybeList[float]:
MaybeList[float]
Potential energy of structure(s).
"""
tag = f"{self.architecture}_energy"
if isinstance(self.struct, list):
return [struct.get_potential_energy() for struct in self.struct]
energies = [struct.get_potential_energy() for struct in self.struct]
for e, s in zip(energies, self.struct):
s.info[tag] = e
s.calc.results = {}
return energies

return self.struct.get_potential_energy()
energy = self.struct.get_potential_energy()
self.struct.info[tag] = energy
self.struct.calc.results = {}
return energy

def _get_forces(self) -> MaybeList[ndarray]:
"""
Expand All @@ -230,10 +238,18 @@ def _get_forces(self) -> MaybeList[ndarray]:
MaybeList[ndarray]
Forces of structure(s).
"""
tag = f"{self.architecture}_forces"
if isinstance(self.struct, list):
return [struct.get_forces() for struct in self.struct]
forces = [struct.get_forces() for struct in self.struct]
for force, s in zip(forces, self.struct):
s.arrays[tag] = force
s.calc.results = {}
return forces

return self.struct.get_forces()
force = self.struct.get_forces()
self.struct.arrays[tag] = force
self.struct.calc.results = {}
return force

def _get_stress(self) -> MaybeList[ndarray]:
"""
Expand All @@ -244,10 +260,18 @@ def _get_stress(self) -> MaybeList[ndarray]:
MaybeList[ndarray]
Stress of structure(s).
"""
tag = f"{self.architecture}_stress"
if isinstance(self.struct, list):
return [struct.get_stress() for struct in self.struct]

return self.struct.get_stress()
stresses = [struct.get_stress() for struct in self.struct]
for stress, s in zip(stresses, self.struct):
s.info[tag] = stress
s.calc.results = {}
return stresses

stress = self.struct.get_stress()
self.struct.info[tag] = stress
self.struct.calc.results = {}
return stress

@staticmethod
def _remove_invalid_props(
Expand Down
38 changes: 29 additions & 9 deletions tests/test_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def test_potential_energy(
raise ValueError(f"Invalid index: {idx}")
else:
assert results == pytest.approx(expected)
results_2 = single_point.struct.info["mace_energy"]
assert results_2 == pytest.approx(expected)


def test_single_point_none():
Expand All @@ -76,7 +78,7 @@ def test_single_point_clean():
results = single_point.run()
for prop in ["energy", "forces"]:
assert prop in results
assert "stress" not in results
assert "mace_stress" not in results


def test_single_point_traj():
Expand All @@ -91,6 +93,12 @@ def test_single_point_traj():
results = single_point.run("energy")
assert results["energy"][0] == pytest.approx(-76.0605725422795)
assert results["energy"][1] == pytest.approx(-74.80419118083256)
assert single_point.struct[0].info["mace_energy"] == pytest.approx(
-76.0605725422795
)
assert single_point.struct[1].info["mace_energy"] == pytest.approx(
-74.80419118083256
)


def test_single_point_write():
Expand All @@ -104,13 +112,25 @@ def test_single_point_write():
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
)
assert "forces" not in single_point.struct.arrays
assert "mace_forces" not in single_point.struct.arrays

single_point.run(write_results=True)

atoms = read_atoms(results_path)
assert atoms.get_potential_energy() is not None
assert "forces" in atoms.arrays
assert atoms.info["mace_energy"] == pytest.approx(-27.035127799332745)
assert "mace_stress" in atoms.info
assert atoms.info["mace_stress"] == pytest.approx(
[
-0.004783275999053391,
-0.004783275999053417,
-0.004783275999053412,
-2.3858882876234007e-19,
-5.02032761017409e-19,
-2.29070171362209e-19,
]
)
assert atoms.arrays["mace_forces"][0] == pytest.approx(
[4.11996826e-18, 1.79977561e-17, 1.80139537e-17]
)


def test_single_point_write_kwargs(tmp_path):
Expand All @@ -123,12 +143,12 @@ def test_single_point_write_kwargs(tmp_path):
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
)
assert "forces" not in single_point.struct.arrays
assert "mace_forces" not in single_point.struct.arrays

single_point.run(write_results=True, write_kwargs={"filename": results_path})
atoms = read(results_path)
assert atoms.get_potential_energy() is not None
assert "forces" in atoms.arrays
assert atoms.info['mace_energy'] is not None
assert "mace_forces" in atoms.arrays


def test_single_point_write_nan(tmp_path):
Expand All @@ -147,7 +167,7 @@ def test_single_point_write_nan(tmp_path):

single_point.run(write_results=True, write_kwargs={"filename": results_path})
atoms = read(results_path)
assert atoms.get_potential_energy() is not None
assert atoms.info['mace_energy'] is not None
assert "forces" in atoms.calc.results
assert "stress" not in atoms.calc.results

Expand Down

0 comments on commit 404fffe

Please sign in to comment.