Skip to content

Commit

Permalink
Unify formats of different ShiftML versions, correct mistakes and sim…
Browse files Browse the repository at this point in the history
…plify calculator functions.
  • Loading branch information
sovietdevil committed Jul 29, 2024
1 parent 5673530 commit b91754f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 66 deletions.
97 changes: 32 additions & 65 deletions src/shiftml/ase/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
url_resolve = {
"ShiftML1.0": "https://tinyurl.com/3xwec68f",
"ShiftML1.1": "https://tinyurl.com/53ymkhvd",
"ShiftML2.0": "https://tinyurl.com/2mp8emsd",
"ShiftML2.0": "https://tinyurl.com/bdcp647w",
}

resolve_outputs = {
"ShiftML1.0": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)},
"ShiftML1.1": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)},
"ShiftML2.0": {
"mtt::cs_iso_mean": ModelOutput(quantity="", unit="ppm", per_atom=True),
"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True),
"mtt::cs_iso_std": ModelOutput(quantity="", unit="ppm", per_atom=True),
"mtt::cs_iso_ensemble": ModelOutput(quantity="", unit="ppm", per_atom=True),
},
Expand All @@ -34,6 +34,16 @@
}


def is_fitted_on(atoms, fitted_species):
if not set(atoms.get_atomic_numbers()).issubset(fitted_species):
raise ValueError(
f"Model is fitted only for the following atomic numbers:\
{fitted_species}. The atomic numbers in the atoms object are:\
{set(atoms.get_atomic_numbers())}. Please provide an atoms object\
with only the fitted species."
)


class ShiftML(MetatensorCalculator):
"""
ShiftML calculator for ASE
Expand Down Expand Up @@ -163,80 +173,37 @@ def get_cs_iso(self, atoms):
"""
Compute the shielding values for the given atoms object
"""
if self.model_version == "ShiftML1.0" or self.model_version == "ShiftML1.1":
assert (
"mtt::cs_iso" in self.outputs.keys()
), "model does not support chemical shielding prediction"

if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species):
raise ValueError(
f"Model is fitted only for the following atomic numbers:\
{self.fitted_species}. The atomic numbers in the atoms object are:\
{set(atoms.get_atomic_numbers())}. Please provide an atoms object\
with only the fitted species."
)

out = self.run_model(atoms, self.outputs)
cs_iso = out["mtt::cs_iso"].block(0).values.detach().numpy()
assert (
"mtt::cs_iso" in self.outputs.keys()
), "model does not support chemical shielding prediction"

elif self.model_version == "ShiftML2.0":
assert (
"mtt::cs_iso_mean" in self.outputs.keys()
), "model does not support chemical shielding prediction"

if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species):
raise ValueError(
f"Model is fitted only for the following atomic numbers:\
{self.fitted_species}. The atomic numbers in the atoms object are:\
{set(atoms.get_atomic_numbers())}. Please provide an atoms object\
with only the fitted species."
)
is_fitted_on(atoms, self.fitted_species)

out = self.run_model(atoms, self.outputs)
cs_iso = out["mtt::cs_iso_mean"].block(0).values.detach().numpy()
out = self.run_model(atoms, self.outputs)
cs_iso = out["mtt::cs_iso"].block(0).values.detach().numpy()

return cs_iso

def get_cs_iso_std(self, atoms):
if self.model_version == "ShiftML2.0":
assert (
"mtt::cs_iso_std" in self.outputs.keys()
), "model does not support chemical shielding prediction"

if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species):
raise ValueError(
f"Model is fitted only for the following atomic numbers:\
{self.fitted_species}. The atomic numbers in the atoms object are:\
{set(atoms.get_atomic_numbers())}. Please provide an atoms object\
with only the fitted species."
)
assert (
"mtt::cs_iso_std" in self.outputs.keys()
), "model does not support chemical shielding prediction"

out = self.run_model(atoms, self.outputs)
cs_iso_std = out["mtt::cs_iso_std"].block(0).values.detach().numpy()
else:
raise RuntimeError("Version not supporting uncertainty quantification.")
is_fitted_on(atoms, self.fitted_species)

out = self.run_model(atoms, self.outputs)
cs_iso_std = out["mtt::cs_iso_std"].block(0).values.detach().numpy()

return cs_iso_std

def get_cs_iso_ensemble(self, atoms):
if self.model_version == "ShiftML2.0":
assert (
"mtt::cs_iso_ensemble" in self.outputs.keys()
), "model does not support chemical shielding prediction"

if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species):
raise ValueError(
f"Model is fitted only for the following atomic numbers:\
{self.fitted_species}. The atomic numbers in the atoms object are:\
{set(atoms.get_atomic_numbers())}. Please provide an atoms object\
with only the fitted species."
)
assert (
"mtt::cs_iso_ensemble" in self.outputs.keys()
), "model does not support chemical shielding prediction"

out = self.run_model(atoms, self.outputs)
cs_iso_ensemble = (
out["mtt::cs_iso_ensemble"].block(0).values.detach().numpy()
)
else:
raise RuntimeError("Version not supporting uncertainty quantification.")
is_fitted_on(atoms, self.fitted_species)

out = self.run_model(atoms, self.outputs)
cs_iso_ensemble = out["mtt::cs_iso_ensemble"].block(0).values.detach().numpy()

return cs_iso_ensemble
2 changes: 1 addition & 1 deletion tests/test_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,5 +231,5 @@ def test_shiftml2_regression_mean():
), "ShiftML2 failed regression variance test"

assert np.allclose(
out_ensemble.flatten(), expected_ensemble_v2
out_ensemble.flatten(), expected_ensemble_v2.flatten()
), "ShiftML2 failed regression ensemble test"

0 comments on commit b91754f

Please sign in to comment.