Skip to content

Commit

Permalink
name ML properties in the results file by architecture, energy become…
Browse files Browse the repository at this point in the history
…s mace_energy,...
  • Loading branch information
alinelena committed Jun 5, 2024
1 parent aaaebc3 commit 0d75c65
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 24 deletions.
48 changes: 38 additions & 10 deletions janus_core/calculations/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,16 @@ def _get_potential_energy(self) -> MaybeList[float]:
MaybeList[float]
Potential energy of structure(s).
"""
tag = f"{self.architecture}_energy"
if isinstance(self.struct, list):
return [struct.get_potential_energy() for struct in self.struct]
energies = [struct.get_potential_energy() for struct in self.struct]
for energy, struct in zip(energies, self.struct):
struct.info[tag] = energy
return energies

return self.struct.get_potential_energy()
energy = self.struct.get_potential_energy()
self.struct.info[tag] = energy
return energy

def _get_forces(self) -> MaybeList[ndarray]:
"""
Expand All @@ -230,10 +236,16 @@ def _get_forces(self) -> MaybeList[ndarray]:
MaybeList[ndarray]
Forces of structure(s).
"""
tag = f"{self.architecture}_forces"
if isinstance(self.struct, list):
return [struct.get_forces() for struct in self.struct]
forces = [struct.get_forces() for struct in self.struct]
for force, struct in zip(forces, self.struct):
struct.arrays[tag] = force
return forces

return self.struct.get_forces()
force = self.struct.get_forces()
self.struct.arrays[tag] = force
return force

def _get_stress(self) -> MaybeList[ndarray]:
"""
Expand All @@ -244,13 +256,19 @@ def _get_stress(self) -> MaybeList[ndarray]:
MaybeList[ndarray]
Stress of structure(s).
"""
tag = f"{self.architecture}_stress"
if isinstance(self.struct, list):
return [struct.get_stress() for struct in self.struct]
stresses = [struct.get_stress() for struct in self.struct]
for stress, struct in zip(stresses, self.struct):
struct.info[tag] = stress
return stresses

return self.struct.get_stress()
stress = self.struct.get_stress()
self.struct.info[tag] = stress
return stress

@staticmethod
def _remove_invalid_props(
self,
struct: Atoms,
results: CalcResults = None,
properties: Collection[str] = (),
Expand All @@ -275,19 +293,21 @@ def _remove_invalid_props(
for prop in struct.calc.results
if not isfinite(struct.calc.results[prop]).all()
]

# Raise error if property was explicitly requested, otherwise remove
for prop in rm_keys:
if prop in properties:
raise ValueError(
f"'{prop}' contains non-finite values for this structure."
)
del struct.calc.results[prop]
if prop in results:
del struct.info[f"{self.architecture}_{prop}"]
del results[prop]

def _clean_results(
self, results: CalcResults = None, properties: Collection[str] = ()
self,
results: CalcResults = None,
properties: Collection[str] = (),
invalidate_calc: bool = False,
) -> None:
"""
Remove NaN and inf values from results and calc.results dictionaries.
Expand All @@ -298,14 +318,22 @@ def _clean_results(
Dictionary of calculated results. Default is {}.
properties : Collection[str]
Physical properties requested to be calculated. Default is ().
invalidate_calc : bool
Remove calculator results if True. When True Atoms object loses
its property methods and true values are in info and arrays.
Default is False.
"""
results = results if results else {}

if isinstance(self.struct, list):
for image in self.struct:
self._remove_invalid_props(image, results, properties)
if invalidate_calc:
image.calc.results = {}
else:
self._remove_invalid_props(self.struct, results, properties)
if invalidate_calc:
self.struct.calc.results = {}

def run(
self,
Expand Down
43 changes: 32 additions & 11 deletions tests/test_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def test_potential_energy(
raise ValueError(f"Invalid index: {idx}")
else:
assert results == pytest.approx(expected)
results_2 = single_point.struct.info["mace_energy"]
assert results_2 == pytest.approx(expected)


def test_single_point_none():
Expand All @@ -76,7 +78,7 @@ def test_single_point_clean():
results = single_point.run()
for prop in ["energy", "forces"]:
assert prop in results
assert "stress" not in results
assert "mace_stress" not in results


def test_single_point_traj():
Expand All @@ -91,6 +93,12 @@ def test_single_point_traj():
results = single_point.run("energy")
assert results["energy"][0] == pytest.approx(-76.0605725422795)
assert results["energy"][1] == pytest.approx(-74.80419118083256)
assert single_point.struct[0].info["mace_energy"] == pytest.approx(
-76.0605725422795
)
assert single_point.struct[1].info["mace_energy"] == pytest.approx(
-74.80419118083256
)


def test_single_point_write():
Expand All @@ -104,13 +112,26 @@ def test_single_point_write():
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
)
assert "forces" not in single_point.struct.calc.results
assert "mace_forces" not in single_point.struct.arrays

single_point.run(write_results=True)

atoms = read_atoms(results_path)
assert atoms.get_potential_energy() is not None
assert "forces" in atoms.calc.results
assert "mace_forces" in atoms.arrays
assert atoms.info["mace_energy"] == pytest.approx(-27.035127799332745)
assert "mace_stress" in atoms.info
assert atoms.info["mace_stress"] == pytest.approx(
[
-0.004783275999053391,
-0.004783275999053417,
-0.004783275999053412,
-2.3858882876234007e-19,
-5.02032761017409e-19,
-2.29070171362209e-19,
]
)
assert atoms.arrays["mace_forces"][0] == pytest.approx(
[4.11996826e-18, 1.79977561e-17, 1.80139537e-17]
)


def test_single_point_write_kwargs(tmp_path):
Expand All @@ -123,12 +144,12 @@ def test_single_point_write_kwargs(tmp_path):
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
)
assert "forces" not in single_point.struct.calc.results
assert "mace_forces" not in single_point.struct.arrays

single_point.run(write_results=True, write_kwargs={"filename": results_path})
atoms = read(results_path)
assert atoms.get_potential_energy() is not None
assert "forces" in atoms.calc.results
assert atoms.info["mace_energy"] is not None
assert "mace_forces" in atoms.arrays


def test_single_point_write_nan(tmp_path):
Expand All @@ -147,9 +168,9 @@ def test_single_point_write_nan(tmp_path):

single_point.run(write_results=True, write_kwargs={"filename": results_path})
atoms = read(results_path)
assert atoms.get_potential_energy() is not None
assert "forces" in atoms.calc.results
assert "stress" not in atoms.calc.results
assert atoms.info["mace_energy"] == pytest.approx(-14.035236305927514)
assert "mace_forces" in atoms.arrays
assert "mace_stress" not in atoms.info


def test_invalid_prop():
Expand Down
6 changes: 3 additions & 3 deletions tests/test_singlepoint_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def test_singlepoint(tmp_path):
# Check atoms can read read, then delete file
atoms = read_atoms(results_path)
assert result.exit_code == 0
assert atoms.get_potential_energy() is not None
assert "forces" in atoms.calc.results
assert "mace_mp_energy" in atoms.info
assert "mace_mp_forces" in atoms.arrays


def test_properties(tmp_path):
Expand Down Expand Up @@ -78,7 +78,7 @@ def test_properties(tmp_path):
assert result.exit_code == 0

atoms = read(results_path_1)
assert atoms.get_potential_energy() is not None
assert "mace_mp_energy" in atoms.info

result = runner.invoke(
app,
Expand Down

0 comments on commit 0d75c65

Please sign in to comment.