From b91754f47b10b1140d6d81aa8633d6c7350529f4 Mon Sep 17 00:00:00 2001 From: sovietdevil <2558390548@qq.com> Date: Mon, 29 Jul 2024 11:23:49 +0200 Subject: [PATCH] Unify formats of different ShiftML versions, correct mistakes and simplify calculator functions. --- src/shiftml/ase/calculator.py | 97 ++++++++++++----------------------- tests/test_ase.py | 2 +- 2 files changed, 33 insertions(+), 66 deletions(-) diff --git a/src/shiftml/ase/calculator.py b/src/shiftml/ase/calculator.py index 02e3627..e091ff2 100644 --- a/src/shiftml/ase/calculator.py +++ b/src/shiftml/ase/calculator.py @@ -14,14 +14,14 @@ url_resolve = { "ShiftML1.0": "https://tinyurl.com/3xwec68f", "ShiftML1.1": "https://tinyurl.com/53ymkhvd", - "ShiftML2.0": "https://tinyurl.com/2mp8emsd", + "ShiftML2.0": "https://tinyurl.com/bdcp647w", } resolve_outputs = { "ShiftML1.0": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)}, "ShiftML1.1": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)}, "ShiftML2.0": { - "mtt::cs_iso_mean": ModelOutput(quantity="", unit="ppm", per_atom=True), + "mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True), "mtt::cs_iso_std": ModelOutput(quantity="", unit="ppm", per_atom=True), "mtt::cs_iso_ensemble": ModelOutput(quantity="", unit="ppm", per_atom=True), }, @@ -34,6 +34,16 @@ } +def is_fitted_on(atoms, fitted_species): + if not set(atoms.get_atomic_numbers()).issubset(fitted_species): + raise ValueError( + f"Model is fitted only for the following atomic numbers:\ + {fitted_species}. The atomic numbers in the atoms object are:\ + {set(atoms.get_atomic_numbers())}. Please provide an atoms object\ + with only the fitted species." + ) + + class ShiftML(MetatensorCalculator): """ ShiftML calculator for ASE @@ -163,80 +173,37 @@ def get_cs_iso(self, atoms): """ Compute the shielding values for the given atoms object """ - if self.model_version == "ShiftML1.0" or self.model_version == "ShiftML1.1": - assert ( - "mtt::cs_iso" in self.outputs.keys() - ), "model does not support chemical shielding prediction" - - if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species): - raise ValueError( - f"Model is fitted only for the following atomic numbers:\ - {self.fitted_species}. The atomic numbers in the atoms object are:\ - {set(atoms.get_atomic_numbers())}. Please provide an atoms object\ - with only the fitted species." - ) - - out = self.run_model(atoms, self.outputs) - cs_iso = out["mtt::cs_iso"].block(0).values.detach().numpy() + assert ( + "mtt::cs_iso" in self.outputs.keys() + ), "model does not support chemical shielding prediction" - elif self.model_version == "ShiftML2.0": - assert ( - "mtt::cs_iso_mean" in self.outputs.keys() - ), "model does not support chemical shielding prediction" - - if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species): - raise ValueError( - f"Model is fitted only for the following atomic numbers:\ - {self.fitted_species}. The atomic numbers in the atoms object are:\ - {set(atoms.get_atomic_numbers())}. Please provide an atoms object\ - with only the fitted species." - ) + is_fitted_on(atoms, self.fitted_species) - out = self.run_model(atoms, self.outputs) - cs_iso = out["mtt::cs_iso_mean"].block(0).values.detach().numpy() + out = self.run_model(atoms, self.outputs) + cs_iso = out["mtt::cs_iso"].block(0).values.detach().numpy() return cs_iso def get_cs_iso_std(self, atoms): - if self.model_version == "ShiftML2.0": - assert ( - "mtt::cs_iso_std" in self.outputs.keys() - ), "model does not support chemical shielding prediction" - - if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species): - raise ValueError( - f"Model is fitted only for the following atomic numbers:\ - {self.fitted_species}. The atomic numbers in the atoms object are:\ - {set(atoms.get_atomic_numbers())}. Please provide an atoms object\ - with only the fitted species." - ) + assert ( + "mtt::cs_iso_std" in self.outputs.keys() + ), "model does not support chemical shielding prediction" - out = self.run_model(atoms, self.outputs) - cs_iso_std = out["mtt::cs_iso_std"].block(0).values.detach().numpy() - else: - raise RuntimeError("Version not supporting uncertainty quantification.") + is_fitted_on(atoms, self.fitted_species) + + out = self.run_model(atoms, self.outputs) + cs_iso_std = out["mtt::cs_iso_std"].block(0).values.detach().numpy() return cs_iso_std def get_cs_iso_ensemble(self, atoms): - if self.model_version == "ShiftML2.0": - assert ( - "mtt::cs_iso_ensemble" in self.outputs.keys() - ), "model does not support chemical shielding prediction" - - if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species): - raise ValueError( - f"Model is fitted only for the following atomic numbers:\ - {self.fitted_species}. The atomic numbers in the atoms object are:\ - {set(atoms.get_atomic_numbers())}. Please provide an atoms object\ - with only the fitted species." - ) + assert ( + "mtt::cs_iso_ensemble" in self.outputs.keys() + ), "model does not support chemical shielding prediction" - out = self.run_model(atoms, self.outputs) - cs_iso_ensemble = ( - out["mtt::cs_iso_ensemble"].block(0).values.detach().numpy() - ) - else: - raise RuntimeError("Version not supporting uncertainty quantification.") + is_fitted_on(atoms, self.fitted_species) + + out = self.run_model(atoms, self.outputs) + cs_iso_ensemble = out["mtt::cs_iso_ensemble"].block(0).values.detach().numpy() return cs_iso_ensemble diff --git a/tests/test_ase.py b/tests/test_ase.py index 482a944..06c18e7 100644 --- a/tests/test_ase.py +++ b/tests/test_ase.py @@ -231,5 +231,5 @@ def test_shiftml2_regression_mean(): ), "ShiftML2 failed regression variance test" assert np.allclose( - out_ensemble.flatten(), expected_ensemble_v2 + out_ensemble.flatten(), expected_ensemble_v2.flatten() ), "ShiftML2 failed regression ensemble test"