Skip to content

Commit

Permalink
Add mean and standard deviation functions to the ShiftML2.0 calculator.
Browse files Browse the repository at this point in the history
  • Loading branch information
sovietdevil committed Jul 29, 2024
1 parent 4542d74 commit 5673530
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 15 deletions.
93 changes: 79 additions & 14 deletions src/shiftml/ase/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@
url_resolve = {
"ShiftML1.0": "https://tinyurl.com/3xwec68f",
"ShiftML1.1": "https://tinyurl.com/53ymkhvd",
"ShiftML2.0": "https://tinyurl.com/9v8ppnru",
"ShiftML2.0": "https://tinyurl.com/2mp8emsd",
}

resolve_outputs = {
"ShiftML1.0": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)},
"ShiftML1.1": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)},
"ShiftML2.0": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)},
"ShiftML2.0": {
"mtt::cs_iso_mean": ModelOutput(quantity="", unit="ppm", per_atom=True),
"mtt::cs_iso_std": ModelOutput(quantity="", unit="ppm", per_atom=True),
"mtt::cs_iso_ensemble": ModelOutput(quantity="", unit="ppm", per_atom=True),
},
}

resolve_fitted_species = {
Expand Down Expand Up @@ -153,25 +157,86 @@ def __init__(self, model_version, force_download=False):
raise e

super().__init__(model_file)
self.model_version = model_version

def get_cs_iso(self, atoms):
"""
Compute the shielding values for the given atoms object
"""
if self.model_version == "ShiftML1.0" or self.model_version == "ShiftML1.1":
assert (
"mtt::cs_iso" in self.outputs.keys()
), "model does not support chemical shielding prediction"

if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species):
raise ValueError(
f"Model is fitted only for the following atomic numbers:\
{self.fitted_species}. The atomic numbers in the atoms object are:\
{set(atoms.get_atomic_numbers())}. Please provide an atoms object\
with only the fitted species."
)

assert (
"mtt::cs_iso" in self.outputs.keys()
), "model does not support chemical shielding prediction"
out = self.run_model(atoms, self.outputs)
cs_iso = out["mtt::cs_iso"].block(0).values.detach().numpy()

if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species):
raise ValueError(
f"Model is fitted only for the following atomic numbers:\
{self.fitted_species}. The atomic numbers in the atoms object are:\
{set(atoms.get_atomic_numbers())}. Please provide an atoms object\
with only the fitted species."
)
elif self.model_version == "ShiftML2.0":
assert (
"mtt::cs_iso_mean" in self.outputs.keys()
), "model does not support chemical shielding prediction"

if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species):
raise ValueError(
f"Model is fitted only for the following atomic numbers:\
{self.fitted_species}. The atomic numbers in the atoms object are:\
{set(atoms.get_atomic_numbers())}. Please provide an atoms object\
with only the fitted species."
)

out = self.run_model(atoms, self.outputs)
cs_iso = out["mtt::cs_iso"].block(0).values.detach().numpy()
out = self.run_model(atoms, self.outputs)
cs_iso = out["mtt::cs_iso_mean"].block(0).values.detach().numpy()

return cs_iso

def get_cs_iso_std(self, atoms):
if self.model_version == "ShiftML2.0":
assert (
"mtt::cs_iso_std" in self.outputs.keys()
), "model does not support chemical shielding prediction"

if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species):
raise ValueError(
f"Model is fitted only for the following atomic numbers:\
{self.fitted_species}. The atomic numbers in the atoms object are:\
{set(atoms.get_atomic_numbers())}. Please provide an atoms object\
with only the fitted species."
)

out = self.run_model(atoms, self.outputs)
cs_iso_std = out["mtt::cs_iso_std"].block(0).values.detach().numpy()
else:
raise RuntimeError("Version not supporting uncertainty quantification.")

return cs_iso_std

def get_cs_iso_ensemble(self, atoms):
if self.model_version == "ShiftML2.0":
assert (
"mtt::cs_iso_ensemble" in self.outputs.keys()
), "model does not support chemical shielding prediction"

if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species):
raise ValueError(
f"Model is fitted only for the following atomic numbers:\
{self.fitted_species}. The atomic numbers in the atoms object are:\
{set(atoms.get_atomic_numbers())}. Please provide an atoms object\
with only the fitted species."
)

out = self.run_model(atoms, self.outputs)
cs_iso_ensemble = (
out["mtt::cs_iso_ensemble"].block(0).values.detach().numpy()
)
else:
raise RuntimeError("Version not supporting uncertainty quantification.")

return cs_iso_ensemble
162 changes: 161 additions & 1 deletion tests/test_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,144 @@
from shiftml.ase import ShiftML

expected_output = np.array([137.5415, 137.5415])
expected_ensemble_v2 = np.array(
[
[
114.8194808959961,
113.47244262695312,
117.47064208984375,
115.61190795898438,
130.88909912109375,
131.42332458496094,
120.96844482421875,
115.95867919921875,
116.98180389404297,
135.3658447265625,
120.45016479492188,
123.00967407226562,
137.23724365234375,
129.23104858398438,
131.00619506835938,
130.82601928710938,
121.90162658691406,
120.66400909423828,
109.59469604492188,
118.66798400878906,
126.18386840820312,
124.9156494140625,
120.90362548828125,
106.26658630371094,
128.32107543945312,
125.82593536376953,
121.3394775390625,
127.37902069091797,
122.92572784423828,
126.26400756835938,
112.87037658691406,
112.48919677734375,
126.00082397460938,
109.98661804199219,
110.7204818725586,
107.30191040039062,
113.85182189941406,
110.24645233154297,
133.27935791015625,
126.40534973144531,
133.42047119140625,
112.2728271484375,
126.27506256103516,
117.58969116210938,
119.17208099365234,
121.65959167480469,
115.62092590332031,
118.12762451171875,
119.478271484375,
137.32974243164062,
120.26103210449219,
118.25013732910156,
121.78120422363281,
125.66693115234375,
112.0889892578125,
115.92691802978516,
121.31621551513672,
118.76759338378906,
126.86924743652344,
129.01571655273438,
109.53144073486328,
110.71353149414062,
125.9607925415039,
108.36444091796875,
],
[
114.8194808959961,
113.47244262695312,
117.47064208984375,
115.61190795898438,
130.88909912109375,
131.42332458496094,
120.96844482421875,
115.95867919921875,
116.98180389404297,
135.3658447265625,
120.45016479492188,
123.00967407226562,
137.23724365234375,
129.23104858398438,
131.00619506835938,
130.82601928710938,
121.90162658691406,
120.66400909423828,
109.59469604492188,
118.66798400878906,
126.18386840820312,
124.9156494140625,
120.90362548828125,
106.26658630371094,
128.32107543945312,
125.82593536376953,
121.3394775390625,
127.37902069091797,
122.92572784423828,
126.26400756835938,
112.87037658691406,
112.48919677734375,
126.00082397460938,
109.98661804199219,
110.7204818725586,
107.30191040039062,
113.85182189941406,
110.24645233154297,
133.27935791015625,
126.40534973144531,
133.42047119140625,
112.2728271484375,
126.27506256103516,
117.58969116210938,
119.17208099365234,
121.65959167480469,
115.62092590332031,
118.12762451171875,
119.478271484375,
137.32974243164062,
120.26103210449219,
118.25013732910156,
121.78120422363281,
125.66693115234375,
112.0889892578125,
115.92691802978516,
121.31621551513672,
118.76759338378906,
126.86924743652344,
129.01571655273438,
109.53144073486328,
110.71353149414062,
125.9607925415039,
108.36444091796875,
],
]
)
expected_mean_v2 = np.array([120.85137, 120.85137])
expected_std_v2 = np.array([7.7993703, 7.7993703])


def test_shiftml1_regression():
Expand Down Expand Up @@ -62,7 +200,7 @@ def test_shiftml1_size_extensivity_test():


def test_shftml1_fail_invalid_species():
"""Test ShiftML1.o for non-fitted species"""
"""Test ShiftML1.0 for non-fitted species"""

frame = bulk("Si", "diamond", a=3.566)
model = ShiftML("ShiftML1.0")
Expand All @@ -73,3 +211,25 @@ def test_shftml1_fail_invalid_species():
assert "Model is fitted only for the following atomic numbers:" in str(
exc_info.value
)


def test_shiftml2_regression_mean():
"""Regression test for the ShiftML2.0 model."""

frame = bulk("C", "diamond", a=3.566)
model = ShiftML("ShiftML2.0", force_download=True)
out_mean = model.get_cs_iso(frame)
out_std = model.get_cs_iso_std(frame)
out_ensemble = model.get_cs_iso_ensemble(frame)

assert np.allclose(
out_mean.flatten(), expected_mean_v2
), "ShiftML2 failed regression mean test"

assert np.allclose(
out_std.flatten(), expected_std_v2
), "ShiftML2 failed regression variance test"

assert np.allclose(
out_ensemble.flatten(), expected_ensemble_v2
), "ShiftML2 failed regression ensemble test"

0 comments on commit 5673530

Please sign in to comment.