diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 8c8200a..5c7f682 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -25,6 +25,9 @@ jobs: python -m pip install tox - name: build documentation + env: + # Use the CPU only version of torch when building/running the code + PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu run: tox -e docs - name: put documentation in the website diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1847d7c..c7854ea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,8 +23,6 @@ jobs: python-version: "3.12" - os: macos-14 python-version: "3.12" - - os: windows-2019 - python-version: "3.12" steps: - uses: actions/checkout@v4 with: @@ -42,3 +40,5 @@ jobs: - name: run tests run: tox + env: + PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu diff --git a/CONTRIBUTORS b/CONTRIBUTORS new file mode 100644 index 0000000..12d49ee --- /dev/null +++ b/CONTRIBUTORS @@ -0,0 +1,4 @@ +Matthias Kellner +Yuxuan Zhang +Ruben Rodriguez Madrid +Guillaume Fraux \ No newline at end of file diff --git a/README.md b/README.md index 6698b47..fd60cec 100644 --- a/README.md +++ b/README.md @@ -2,3 +2,80 @@ [![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg)](https://lab-cosmo.github.io/ShiftML/latest/) ![Tests](https://img.shields.io/github/check-runs/lab-cosmo/ShiftML/main?logo=github&label=tests) + +**Disclaimer: This package is still under development and should be used with caution.** + +Welcome to ShitML, a python package for the prediction of chemical shieldings of organic solids and beyond. + + + +## Usage + +Use ShiftML with the atomsitic simulation environment to obtain fast estimates of chemical shieldings: + +```python + +from ase.build import bulk +from shiftml.ase import ShiftML + +frame = bulk("C", "diamond", a=3.566) +calculator = ShiftML("ShiftML1.0") + +cs_iso = calc.get_cs_iso(frame) + +print(cs_iso) + +``` + +## IMPORTANT: Install pre-instructions before PiPy release + +Rascaline-torch, one of the main dependence of ShiftML, requires CXX and Rust compilers to be built from source. +Most systems come already with configured C/C++ compilers (make sure that some environment variables CC and CXX are set +and gcc can be found), but Rust typically needs to be installed manually. +For ease of use we strongly recommend to use some sort of package manager to install Rust, such as conda and a fresh environment. + + +```bash + +conda create -n shiftml python=3.10 +conda activate shiftml +conda install -c conda-forge rust + +``` + + +## Installation + +To install ShiftML, you can use clone this repository and install it using pip, a pipy release will follow soon: + +``` +pip install --extra-index-url https://download.pytorch.org/whl/cpu . +``` + +## The code that makes it work + +This project would not have been possible without the following packages: + +- Metadata and model handling: [metatensor](https://github.com/lab-cosmo/metatensor) +- Atomic descriptor engine: [rascaline](https://github.com/Luthaf/rascaline) + +## Documentation + +The documentation is available [here](https://lab-cosmo.github.io/ShiftML/latest/). + +## Contributors + +Matthias Kellner\ +Yuxuan Zhang\ +Ruben Rodriguez Madrid\ +Guillaume Fraux + +## References + +This package is based on the following papers: + +- Chemical shifts in molecular solids by machine learning - Paruzzo et al. [[1](https://doi.org/10.1038%2Fs41467-018-06972-x)] +- A Bayesian approach to NMR crystal structure determination - Engel et al. [[2](https://doi.org/10.1039%2Fc9cp04489b)] +- A Machine Learning Model of Chemical Shifts for Chemically and\ +Structurally Diverse Molecular Solids - Cordova et al. [[3](https://doi.org/10.1021/acs.jpcc.2c03854)] + diff --git a/pyproject.toml b/pyproject.toml index ada7c00..84b3a3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,11 +6,16 @@ authors = [ ] dependencies = [ + "numpy", + "ase==3.22.1", + "metatensor[torch]", + "platformdirs", + "rascaline-torch@git+https://github.com/luthaf/rascaline@e215461#subdirectory=python/rascaline-torch", ] readme = "README.md" license = {text = "BSD-3-Clause"} -description = "TODO" +description = "Predictions of chemical shieldings using machine learning" [project.urls] diff --git a/src/shiftml/ase/__init__.py b/src/shiftml/ase/__init__.py new file mode 100644 index 0000000..daa39be --- /dev/null +++ b/src/shiftml/ase/__init__.py @@ -0,0 +1 @@ +from .calculator import ShiftML # noqa: F401 diff --git a/src/shiftml/ase/calculator.py b/src/shiftml/ase/calculator.py new file mode 100644 index 0000000..e0eac0b --- /dev/null +++ b/src/shiftml/ase/calculator.py @@ -0,0 +1,158 @@ +import logging +import os +import urllib.request + +from metatensor.torch.atomistic import ModelOutput +from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator +from platformdirs import user_cache_path + +# For now we set the logging level to DEBUG +logformat = "%(asctime)s - %(levelname)s - %(message)s" +logging.basicConfig(level=logging.DEBUG, format=logformat) + + +url_resolve = { + "ShiftML1.0": "https://tinyurl.com/3xwec68f", + "ShiftML1.1": "https://tinyurl.com/53ymkhvd", +} + +resolve_outputs = { + "ShiftML1.0": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)}, + "ShiftML1.1": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)}, +} + +resolve_fitted_species = { + "ShiftML1.0": set([1, 6, 7, 8, 16]), + "ShiftML1.1": set([1, 6, 7, 8, 16]), +} + + +class ShiftML(MetatensorCalculator): + """ + ShiftML calculator for ASE + """ + + def __init__(self, model_version, force_download=False): + """ + Initialize the ShiftML calculator + + Parameters + ---------- + model_version : str + The version of the ShiftML model to use. Supported versions are + "ShiftML1.0". + force_download : bool, optional + If True, the model will be downloaded even if it is already in the cache. + The chache-dir will be determined via the platformdirs library and should + comply with user settings such as XDG_CACHE_HOME. + Default is False. + """ + + try: + # The rascline import is necessary because + # it is required for the scripted model + import rascaline.torch + + logging.info(rascaline.torch.__version__) + logging.info("rascaline-torch is installed, importing rascaline-torch") + + assert ( + rascaline.torch.__version__ == "0.1.0.dev558" + ), "wrong rascaline-torch installed" + + except ImportError: + raise ImportError( + "rascaline-torch is required for ShiftML calculators,\ + please install it using\ + pip install git+https://github.com/luthaf/rascaline#subdirectory\ + =python/rascaline-torch" + ) + + try: + url = url_resolve[model_version] + self.outputs = resolve_outputs[model_version] + self.fitted_species = resolve_fitted_species[model_version] + logging.info("Found model version in url_resolve") + logging.info("Resolving model version to model files at url: ", url) + except KeyError: + raise ValueError( + f"Model version {model_version} is not supported.\ + Supported versions are {list(url_resolve.keys())}" + ) + + cachedir = os.path.expanduser( + os.path.join(user_cache_path(), "shiftml", str(model_version)) + ) + + # check if model is already downloaded + try: + if not os.path.exists(cachedir): + os.makedirs(cachedir) + model_file = os.path.join(cachedir, model_version + ".pt") + + if os.path.exists(model_file) and force_download: + logging.info( + f"Found {model_version} in cache, but force_download is set to True" + ) + logging.info( + f"Removing {model_version} from cache and downloading it again" + ) + os.remove(model_file) + download = True + + else: + if os.path.exists(model_file): + logging.info( + f"Found {model_version} in cache,\ + and importing it from here: {cachedir}" + ) + download = False + else: + logging.info("Model not found in cache, downloading it") + download = True + + if download: + urllib.request.urlretrieve(url, model_file) + logging.info(f"Downloaded {model_version} and saved to {cachedir}") + + except urllib.error.URLError as e: + logging.error( + f"Failed to download {model_version} from {url}. URL Error: {e.reason}" + ) + raise e + except urllib.error.HTTPError as e: + logging.error( + f"Failed to download {model_version} from {url}.\ + HTTP Error: {e.code} - {e.reason}" + ) + raise e + except Exception as e: + logging.error( + f"An unexpected error occurred while downloading\ + {model_version} from {url}: {e}" + ) + raise e + + super().__init__(model_file) + + def get_cs_iso(self, atoms): + """ + Compute the shielding values for the given atoms object + """ + + assert ( + "mtt::cs_iso" in self.outputs.keys() + ), "model does not support chemical shielding prediction" + + if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species): + raise ValueError( + f"Model is fitted only for the following atomic numbers:\ + {self.fitted_species}. The atomic numbers in the atoms object are:\ + {set(atoms.get_atomic_numbers())}. Please provide an atoms object\ + with only the fitted species." + ) + + out = self.run_model(atoms, self.outputs) + cs_iso = out["mtt::cs_iso"].block(0).values.detach().numpy() + + return cs_iso diff --git a/tests/test_ase.py b/tests/test_ase.py new file mode 100644 index 0000000..69f8ca3 --- /dev/null +++ b/tests/test_ase.py @@ -0,0 +1,75 @@ +# TODO: test for rotational invariance, translation invariance, +# and permutation, as well as size extensivity +import numpy as np +import pytest +from ase.build import bulk + +from shiftml.ase import ShiftML + +expected_output = np.array([137.5415, 137.5415]) + + +def test_shiftml1_regression(): + """Regression test for the ShiftML1.0 model.""" + + frame = bulk("C", "diamond", a=3.566) + model = ShiftML("ShiftML1.0", force_download=True) + out = model.get_cs_iso(frame) + + assert np.allclose( + out.flatten(), expected_output + ), "ShiftML1 failed regression test" + + +def test_shiftml1_rotational_invariance(): + """Rotational invariance test for the ShiftML1.0 model.""" + + frame = bulk("C", "diamond", a=3.566) + model = ShiftML("ShiftML1.0") + out = model.get_cs_iso(frame) + + assert np.allclose( + out.flatten(), expected_output + ), "ShiftML1 failed regression test" + + # Rotate the frame by 90 degrees about the z-axis + frame.rotate(90, "z") + + out_rotated = model.get_cs_iso(frame) + + assert np.allclose( + out_rotated.flatten(), expected_output + ), "ShiftML1 failed rotational invariance test" + + +def test_shiftml1_size_extensivity_test(): + """Test ShiftML1.0 for translational invariance.""" + + frame = bulk("C", "diamond", a=3.566) + model = ShiftML("ShiftML1.0") + out = model.get_cs_iso(frame) + + assert np.allclose( + out.flatten(), expected_output + ), "ShiftML1 failed regression test" + + frame = frame * (2, 1, 1) + out = model.get_cs_iso(frame) + + assert np.allclose( + out.flatten(), np.stack([expected_output, expected_output]).flatten() + ), "ShiftML1 failed size extensivity test" + + +def test_shftml1_fail_invalid_species(): + """Test ShiftML1.o for non-fitted species""" + + frame = bulk("Si", "diamond", a=3.566) + model = ShiftML("ShiftML1.0") + with pytest.raises(ValueError) as exc_info: + model.get_cs_iso(frame) + + assert exc_info.type == ValueError + assert "Model is fitted only for the following atomic numbers:" in str( + exc_info.value + ) diff --git a/tests/test_dummy.py b/tests/test_dummy.py deleted file mode 100644 index 95c5ab7..0000000 --- a/tests/test_dummy.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_example(): - assert 1 + 1 == 2 diff --git a/tests/test_regression_pretrained.py b/tests/test_regression_pretrained.py new file mode 100644 index 0000000..eca1e79 --- /dev/null +++ b/tests/test_regression_pretrained.py @@ -0,0 +1,37 @@ +import numpy as np +from ase.build import bulk + +from shiftml.ase import ShiftML + +expected_outputs = { + "ShiftML1.0": np.array([137.5415, 137.5415]), + "ShiftML1.1": np.array([163.07251, 163.07251]), +} + + +def test_shiftml1_regression(): + """Regression test for the ShiftML1.0 model.""" + + MODEL = "ShiftML1.0" + + frame = bulk("C", "diamond", a=3.566) + model = ShiftML(MODEL, force_download=True) + out = model.get_cs_iso(frame) + + assert np.allclose( + out.flatten(), expected_outputs[MODEL] + ), "ShiftML1.0 failed regression test" + + +def test_shiftml1_1_regression(): + """Regression test for the ShiftML1.1 model.""" + + MODEL = "ShiftML1.1" + + frame = bulk("C", "diamond", a=3.566) + model = ShiftML(MODEL, force_download=True) + out = model.get_cs_iso(frame) + + assert np.allclose( + out.flatten(), expected_outputs[MODEL] + ), "ShiftML1.1 failed regression test"