Skip to content

Commit

Permalink
name ML properties in the results file by architecture, energy become…
Browse files Browse the repository at this point in the history
…s mace_energy,...
  • Loading branch information
alinelena committed Jun 5, 2024
1 parent aaaebc3 commit 92233c2
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 24 deletions.
46 changes: 36 additions & 10 deletions janus_core/calculations/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,18 @@ def _get_potential_energy(self) -> MaybeList[float]:
MaybeList[float]
Potential energy of structure(s).
"""
tag = f"{self.architecture}_energy"
if isinstance(self.struct, list):
return [struct.get_potential_energy() for struct in self.struct]
energies = [struct.get_potential_energy() for struct in self.struct]
for energy, struct in zip(energies, self.struct):
struct.info[tag] = energy
struct.calc.results = {}
return energies

return self.struct.get_potential_energy()
energy = self.struct.get_potential_energy()
self.struct.info[tag] = energy
self.struct.calc.results = {}
return energy

def _get_forces(self) -> MaybeList[ndarray]:
"""
Expand All @@ -230,10 +238,18 @@ def _get_forces(self) -> MaybeList[ndarray]:
MaybeList[ndarray]
Forces of structure(s).
"""
tag = f"{self.architecture}_forces"
if isinstance(self.struct, list):
return [struct.get_forces() for struct in self.struct]
forces = [struct.get_forces() for struct in self.struct]
for force, struct in zip(forces, self.struct):
struct.arrays[tag] = force
struct.calc.results = {}
return forces

return self.struct.get_forces()
force = self.struct.get_forces()
self.struct.arrays[tag] = force
self.struct.calc.results = {}
return force

def _get_stress(self) -> MaybeList[ndarray]:
"""
Expand All @@ -244,10 +260,18 @@ def _get_stress(self) -> MaybeList[ndarray]:
MaybeList[ndarray]
Stress of structure(s).
"""
tag = f"{self.architecture}_stress"
if isinstance(self.struct, list):
return [struct.get_stress() for struct in self.struct]
stresses = [struct.get_stress() for struct in self.struct]
for stress, struct in zip(stresses, self.struct):
struct.info[tag] = stress
struct.calc.results = {}
return stresses

return self.struct.get_stress()
stress = self.struct.get_stress()
self.struct.info[tag] = stress
self.struct.calc.results = {}
return stress

@staticmethod
def _remove_invalid_props(
Expand All @@ -272,8 +296,8 @@ def _remove_invalid_props(
# Find any properties with non-finite values
rm_keys = [
prop
for prop in struct.calc.results
if not isfinite(struct.calc.results[prop]).all()
for prop in properties
if (prop in struct.info) and (not isfinite(struct.info[prop]).all())
]

# Raise error if property was explicitly requested, otherwise remove
Expand All @@ -282,7 +306,7 @@ def _remove_invalid_props(
raise ValueError(
f"'{prop}' contains non-finite values for this structure."
)
del struct.calc.results[prop]
del struct.info[prop]
if prop in results:
del results[prop]

Expand Down Expand Up @@ -360,7 +384,9 @@ def run(
results["stress"] = self._get_stress()

# Remove meaningless values from results e.g. stress for non-periodic systems
self._clean_results(results, properties=properties)
self._clean_results(
results, properties=[f"{self.architecture}_{prop}" for prop in properties]
)

if self.logger:
self.logger.info("Single point calculation complete")
Expand Down
43 changes: 32 additions & 11 deletions tests/test_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def test_potential_energy(
raise ValueError(f"Invalid index: {idx}")
else:
assert results == pytest.approx(expected)
results_2 = single_point.struct.info["mace_energy"]
assert results_2 == pytest.approx(expected)


def test_single_point_none():
Expand All @@ -76,7 +78,7 @@ def test_single_point_clean():
results = single_point.run()
for prop in ["energy", "forces"]:
assert prop in results
assert "stress" not in results
assert "mace_stress" not in results


def test_single_point_traj():
Expand All @@ -91,6 +93,12 @@ def test_single_point_traj():
results = single_point.run("energy")
assert results["energy"][0] == pytest.approx(-76.0605725422795)
assert results["energy"][1] == pytest.approx(-74.80419118083256)
assert single_point.struct[0].info["mace_energy"] == pytest.approx(
-76.0605725422795
)
assert single_point.struct[1].info["mace_energy"] == pytest.approx(
-74.80419118083256
)


def test_single_point_write():
Expand All @@ -104,13 +112,26 @@ def test_single_point_write():
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
)
assert "forces" not in single_point.struct.calc.results
assert "mace_forces" not in single_point.struct.arrays

single_point.run(write_results=True)

atoms = read_atoms(results_path)
assert atoms.get_potential_energy() is not None
assert "forces" in atoms.calc.results
assert "mace_forces" in atoms.arrays
assert atoms.info["mace_energy"] == pytest.approx(-27.035127799332745)
assert "mace_stress" in atoms.info
assert atoms.info["mace_stress"] == pytest.approx(
[
-0.004783275999053391,
-0.004783275999053417,
-0.004783275999053412,
-2.3858882876234007e-19,
-5.02032761017409e-19,
-2.29070171362209e-19,
]
)
assert atoms.arrays["mace_forces"][0] == pytest.approx(
[4.11996826e-18, 1.79977561e-17, 1.80139537e-17]
)


def test_single_point_write_kwargs(tmp_path):
Expand All @@ -123,12 +144,12 @@ def test_single_point_write_kwargs(tmp_path):
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
)
assert "forces" not in single_point.struct.calc.results
assert "mace_forces" not in single_point.struct.arrays

single_point.run(write_results=True, write_kwargs={"filename": results_path})
atoms = read(results_path)
assert atoms.get_potential_energy() is not None
assert "forces" in atoms.calc.results
assert atoms.info["mace_energy"] is not None
assert "mace_forces" in atoms.arrays


def test_single_point_write_nan(tmp_path):
Expand All @@ -147,9 +168,9 @@ def test_single_point_write_nan(tmp_path):

single_point.run(write_results=True, write_kwargs={"filename": results_path})
atoms = read(results_path)
assert atoms.get_potential_energy() is not None
assert "forces" in atoms.calc.results
assert "stress" not in atoms.calc.results
assert atoms.info["mace_energy"] is not None
assert "mace_forces" in atoms.arrays
assert "mace_stress" not in atoms.info


def test_invalid_prop():
Expand Down
6 changes: 3 additions & 3 deletions tests/test_singlepoint_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def test_singlepoint(tmp_path):
# Check atoms can read read, then delete file
atoms = read_atoms(results_path)
assert result.exit_code == 0
assert atoms.get_potential_energy() is not None
assert "forces" in atoms.calc.results
assert "mace_mp_energy" in atoms.info
assert "mace_mp_forces" in atoms.arrays


def test_properties(tmp_path):
Expand Down Expand Up @@ -78,7 +78,7 @@ def test_properties(tmp_path):
assert result.exit_code == 0

atoms = read(results_path_1)
assert atoms.get_potential_energy() is not None
assert "mace_mp_energy" in atoms.info

result = runner.invoke(
app,
Expand Down

0 comments on commit 92233c2

Please sign in to comment.