Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ase calculator #2

Merged
merged 11 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ jobs:
python -m pip install tox

- name: build documentation
env:
# Use the CPU only version of torch when building/running the code
PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu
run: tox -e docs

- name: put documentation in the website
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ jobs:
python-version: "3.12"
- os: macos-14
python-version: "3.12"
- os: windows-2019
python-version: "3.12"
steps:
- uses: actions/checkout@v4
with:
Expand All @@ -42,3 +40,5 @@ jobs:

- name: run tests
run: tox
env:
PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu
77 changes: 77 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,80 @@

[![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg)](https://lab-cosmo.github.io/ShiftML/latest/)
![Tests](https://img.shields.io/github/check-runs/lab-cosmo/ShiftML/main?logo=github&label=tests)

**Disclaimer: This package is still under development and should be used with caution.**

Welcome to ShitML, a python package for the prediction of chemical shieldings of organic solids and beyond.



## Usage

Use ShiftML with the atomsitic simulation environment to obtain fast estimates of chemical shieldings:

```python

from ase.build import bulk
from shiftml.ase import ShiftML

frame = bulk("C", "diamond", a=3.566)
calculator = ShiftML("ShiftML1.0")

cs_iso = calc.get_cs_iso(frame)

print(cs_iso)

```

## IMPORTANT: Install pre-instructions before PiPy release

Rascaline-torch, one of the main dependence of ShiftML, requires CXX and Rust compilers to be built from source.
Most systems come already with configured C/C++ compilers (make sure that some environment variables CC and CXX are set
and gcc can be found), but Rust typically needs to be installed manually.
For ease of use we strongly recommend to use some sort of package manager to install Rust, such as conda and a fresh environment.
bananenpampe marked this conversation as resolved.
Show resolved Hide resolved


```bash

conda create -n shiftml python=3.10
conda activate shiftml
conda install -c conda-forge rust

```


## Installation

To install ShiftML, you can use clone this repository and install it using pip, a pipy release will follow soon:

```
pip install --extra-index-url https://download.pytorch.org/whl/cpu .
```

## The code that makes it work

This project would not have been possible without the following packages:

- Metadata and model handling: [metatensor](https://github.com/lab-cosmo/metatensor)
- Atomic descriptor engine: [rascaline](https://github.com/Luthaf/rascaline)

## Documentation

The documentation is available [here](https://lab-cosmo.github.io/ShiftML/latest/).

## Contributors

Matthias Kellner\
Yuxuan Zhang\
Ruben Rodriguez Madrid\
Guillaume Fraux
bananenpampe marked this conversation as resolved.
Show resolved Hide resolved

## References

This package is based on the following papers:

- Chemical shifts in molecular solids by machine learning - Paruzzo et al. [[1](https://doi.org/10.1038%2Fs41467-018-06972-x)]
- A Bayesian approach to NMR crystal structure determination - Engel et al. [[2](https://doi.org/10.1039%2Fc9cp04489b)]
- A Machine Learning Model of Chemical Shifts for Chemically and\
Structurally Diverse Molecular Solids - Cordova et al. [[3](https://doi.org/10.1021/acs.jpcc.2c03854)]

6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ authors = [
]

dependencies = [
"numpy",
"ase==3.22.1",
"metatensor[torch]",
"rascaline-torch@git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch",
bananenpampe marked this conversation as resolved.
Show resolved Hide resolved
]

readme = "README.md"
license = {text = "BSD-3-Clause"}
description = "TODO"
description = "Predictions of chemical shieldings using machine learning"


[project.urls]
Expand Down
1 change: 1 addition & 0 deletions src/shiftml/ase/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .calculator import ShiftML # noqa: F401
143 changes: 143 additions & 0 deletions src/shiftml/ase/calculator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import os
import urllib.request

from metatensor.torch.atomistic import ModelOutput
from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator

url_resolve = {
"ShiftML1.0": "https://tinyurl.com/3xwec68f",
"ShiftML1.1": "https://tinyurl.com/53ymkhvd",
Comment on lines +15 to +16
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so you re-created the exact previous models using the new infrastructure? Should these have a different version (like ShiftML1.0+metatensor) just for clarity or are you sure this will produce the same outputs?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will discuss this with everyone involved - but in principle I agree, that there should be a ShiftML1.0rev or something that makes it clear.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will open an issue to remind me of that.

}

resolve_outputs = {
"ShiftML1.0": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)},
"ShiftML1.1": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)},
}

resolve_fitted_species = {
"ShiftML1.0": set([1, 6, 7, 8, 16]),
"ShiftML1.1": set([1, 6, 7, 8, 16]),
}


class ShiftML(MetatensorCalculator):
"""
ShiftML calculator for ASE
"""

def __init__(self, model_version, force_download=False):
"""
Initialize the ShiftML calculator

Parameters
----------
model_version : str
The version of the ShiftML model to use. Supported versions are
"ShiftML1.0".
force_download : bool, optional
If True, the model will be downloaded even if it is already in the cache.
bananenpampe marked this conversation as resolved.
Show resolved Hide resolved
Default is False.
"""

try:
# The rascline import is necessary because
# it is required for the scripted model
import rascaline.torch
bananenpampe marked this conversation as resolved.
Show resolved Hide resolved

print(rascaline.torch.__version__)
print("rascaline-torch is installed, importing rascaline-torch")

except ImportError:
raise ImportError(
"rascaline-torch is required for ShiftML calculators,\
please install it using\
pip install git+https://github.com/luthaf/rascaline#subdirectory\
=python/rascaline-torch"
)

try:
url = url_resolve[model_version]
self.outputs = resolve_outputs[model_version]
self.fitted_species = resolve_fitted_species[model_version]
print("Found model version in url_resolve")
print("Resolving model version to model files at url: ", url)
except KeyError:
raise ValueError(
f"Model version {model_version} is not supported.\
Supported versions are {list(url_resolve.keys())}"
)

cachedir = os.path.expanduser(
os.path.join("~", ".cache", "shiftml", str(model_version))
)
bananenpampe marked this conversation as resolved.
Show resolved Hide resolved

# check if model is already downloaded
try:
if not os.path.exists(cachedir):
os.makedirs(cachedir)
model_file = os.path.join(cachedir, "model.pt")
bananenpampe marked this conversation as resolved.
Show resolved Hide resolved

if os.path.exists(model_file) and force_download:
print(
f"Found {model_version} in cache, but force_download is set to True"
)
print(f"Removing {model_version} from cache and downloading it again")
os.remove(os.path.join(cachedir, "model.pt"))
download = True

else:
if os.path.exists(model_file):
print(
f"Found {model_version} in cache,\
and importing it from here: {cachedir}"
)
download = False
else:
print("Model not found in cache, downloading it")
bananenpampe marked this conversation as resolved.
Show resolved Hide resolved
download = True

if download:
urllib.request.urlretrieve(url, os.path.join(cachedir, "model.pt"))
print(f"Downloaded {model_version} and saved to {cachedir}")

except urllib.error.URLError as e:
print(
f"Failed to download {model_version} from {url}. URL Error: {e.reason}"
)
raise e
except urllib.error.HTTPError as e:
print(
f"Failed to download {model_version} from {url}.\
HTTP Error: {e.code} - {e.reason}"
)
raise e
except Exception as e:
print(
f"An unexpected error occurred while downloading\
{model_version} from {url}: {e}"
)
raise e

super().__init__(model_file)

def get_cs_iso(self, atoms):
"""
Compute the shielding values for the given atoms object
"""

assert (
"mtt::cs_iso" in self.outputs.keys()
), "model does not support chemical shielding prediction"

if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species):
raise ValueError(
f"Model is fitted only for the following atomic numbers:\
{self.fitted_species}. The atomic numbers in the atoms object are:\
{set(atoms.get_atomic_numbers())}. Please provide an atoms object\
with only the fitted species."
)

out = self.run_model(atoms, self.outputs)
cs_iso = out["mtt::cs_iso"].block(0).values.detach().numpy()

return cs_iso
70 changes: 70 additions & 0 deletions tests/test_ase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# TODO: test for rotational invariance, translation invariance,
# and permutation, as well as size extensivity
import numpy as np
import pytest
from ase.build import bulk

from shiftml.ase import ShiftML

expected_output = np.array([137.5415, 137.5415])


def test_shiftml1_regression():
"""Regression test for the ShiftML1.0 model."""

frame = bulk("C", "diamond", a=3.566)
model = ShiftML("ShiftML1.0", force_download=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why force_download=True?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was to make sure that the local tests do not resuse a local potentially old cached file.

out = model.get_cs_iso(frame)

assert np.allclose(
out.flatten(), expected_output
), "ShiftML1 failed regression test"


def test_shiftml1_rotational_invariance():
"""Rotational invariance test for the ShiftML1.0 model."""

frame = bulk("C", "diamond", a=3.566)
model = ShiftML("ShiftML1.0")
out = model.get_cs_iso(frame)

assert np.allclose(
out.flatten(), expected_output
), "ShiftML1 failed regression test"

# Rotate the frame by 90 degrees about the z-axis
frame.rotate(90, "z")

out_rotated = model.get_cs_iso(frame)

assert np.allclose(
out_rotated.flatten(), expected_output
), "ShiftML1 failed rotational invariance test"


def test_shiftml1_size_extensivity_test():
"""Test ShiftML1.0 for translational invariance."""

frame = bulk("C", "diamond", a=3.566)
model = ShiftML("ShiftML1.0")
out = model.get_cs_iso(frame)

assert np.allclose(
out.flatten(), expected_output
), "ShiftML1 failed regression test"

frame = frame * (2, 1, 1)
out = model.get_cs_iso(frame)

assert np.allclose(
out.flatten(), np.stack([expected_output, expected_output]).flatten()
), "ShiftML1 failed size extensivity test"


def test_shftml1_fail_invalid_species():
"""Test ShiftML1.o for non-fitted species"""

frame = bulk("Si", "diamond", a=3.566)
model = ShiftML("ShiftML1.0")
with pytest.raises(ValueError):
bananenpampe marked this conversation as resolved.
Show resolved Hide resolved
model.get_cs_iso(frame)
2 changes: 0 additions & 2 deletions tests/test_dummy.py

This file was deleted.

37 changes: 37 additions & 0 deletions tests/test_regression_pretrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np
from ase.build import bulk

from shiftml.ase import ShiftML

expected_outputs = {
"ShiftML1.0": np.array([137.5415, 137.5415]),
"ShiftML1.1": np.array([163.07251, 163.07251]),
}


def test_shiftml1_regression():
"""Regression test for the ShiftML1.0 model."""

MODEL = "ShiftML1.0"

frame = bulk("C", "diamond", a=3.566)
model = ShiftML(MODEL, force_download=True)
out = model.get_cs_iso(frame)

assert np.allclose(
out.flatten(), expected_outputs[MODEL]
), "ShiftML1.0 failed regression test"


def test_shiftml1_1_regression():
"""Regression test for the ShiftML1.1 model."""

MODEL = "ShiftML1.1"

frame = bulk("C", "diamond", a=3.566)
model = ShiftML(MODEL, force_download=True)
out = model.get_cs_iso(frame)

assert np.allclose(
out.flatten(), expected_outputs[MODEL]
), "ShiftML1.1 failed regression test"