From 08b6b555bef5279311ce1e56bf2d81ed2c9f46da Mon Sep 17 00:00:00 2001 From: Matthias Kellner Date: Fri, 19 Jul 2024 16:09:13 +0200 Subject: [PATCH 01/11] Adding ShiftML1.0 --- README.md | 72 +++++++++++++++++++ pyproject.toml | 6 +- src/shiftml/ase/__init__.py | 1 + src/shiftml/ase/calculator.py | 128 ++++++++++++++++++++++++++++++++++ tests/test_ase.py | 70 +++++++++++++++++++ 5 files changed, 276 insertions(+), 1 deletion(-) create mode 100644 src/shiftml/ase/__init__.py create mode 100644 src/shiftml/ase/calculator.py create mode 100644 tests/test_ase.py diff --git a/README.md b/README.md index 6698b47..b8ad5b9 100644 --- a/README.md +++ b/README.md @@ -2,3 +2,75 @@ [![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg)](https://lab-cosmo.github.io/ShiftML/latest/) ![Tests](https://img.shields.io/github/check-runs/lab-cosmo/ShiftML/main?logo=github&label=tests) + +**Disclaimer: This package is still under development and should be used with caution.** + +Welcome to ShitML, a python package for the prediction of chemical shieldings of organic solids and beyond. + + + +## Usage + +Use ShiftML with the atomsitic simulation environment to obtain fast estimates of chemical shieldings: + +```python + +from ase.build import bulk +from shiftml.ase import ShiftML + +frame = bulk("C", "diamond", a=3.566) +calculator = ShiftML("ShiftML1.0") + +cs_iso = calc.get_cs_iso(frame) + +print(cs_iso) + +``` + +## IMPORTANT: Install pre-instructions before PiPy release + +Rascaline-torch, one of the main dependence of ShiftML, requires CXX and Rust compilers to be built from source. +Most systems come already with configured C/C++ compilers (make sure that some environment variables CC and CXX are set +and gcc can be found), but Rust typically needs to be installed manually. +For ease of use we strongly recommend to use some sort of package manager to install Rust, such as conda and a fresh environment. + + +```bash + +conda create -n shiftml python=3.10 +conda activate shiftml +conda install -c conda-forge rust + +``` + + +## Installation + +To install ShiftML, you can use clone this repository and install it using pip, a pipy release will follow soon: + +``` +pip install . +``` + +## The code that makes it work + +This project would not have been possible without the following packages: + +- [metatensor](https://github.com/lab-cosmo/metatensor) +- [rascaline](https://github.com/Luthaf/rascaline) + +## Documentation + +The documentation is available [here](https://lab-cosmo.github.io/ShiftML/latest/). + +## Contributors + +Matthias Kellner\ +Yuxuan Zhang\ +Ruben Rodriguez Madrid\ +Guillaume Fraux + +## References + +This package is based on the following publications: + diff --git a/pyproject.toml b/pyproject.toml index ada7c00..fb8923b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,11 +6,15 @@ authors = [ ] dependencies = [ + "numpy", + "ase==3.22.1", + "metatensor[torch]", + "rascaline-torch@git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch", ] readme = "README.md" license = {text = "BSD-3-Clause"} -description = "TODO" +description = "Predictions of chemical shieldings using machine learning" [project.urls] diff --git a/src/shiftml/ase/__init__.py b/src/shiftml/ase/__init__.py new file mode 100644 index 0000000..daa39be --- /dev/null +++ b/src/shiftml/ase/__init__.py @@ -0,0 +1 @@ +from .calculator import ShiftML # noqa: F401 diff --git a/src/shiftml/ase/calculator.py b/src/shiftml/ase/calculator.py new file mode 100644 index 0000000..e892e5e --- /dev/null +++ b/src/shiftml/ase/calculator.py @@ -0,0 +1,128 @@ +import os +import urllib.request + +from metatensor.torch.atomistic import ModelOutput +from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator + +url_resolve = { + "ShiftML1.0": "https://tinyurl.com/3xwec68f", +} + +resolve_outputs = { + "ShiftML1.0": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)}, +} + +resolve_fitted_species = { + "ShiftML1.0": set([1, 6, 7, 8, 16]), +} + + +class ShiftML(MetatensorCalculator): + """ + ShiftML calculator for ASE + """ + + def __init__(self, model_version, force_download=False): + + try: + # The rascline import is necessary because + # it is required for the scripted model + import rascaline.torch + + print(rascaline.torch.__version__) + print("rascaline-torch is installed, importing rascaline-torch") + + except ImportError: + raise ImportError( + "rascaline-torch is required for ShiftML calculators,\ + please install it using\ + pip install git+https://github.com/luthaf/rascaline#subdirectory\ + =python/rascaline-torch" + ) + + try: + url = url_resolve[model_version] + self.outputs = resolve_outputs[model_version] + self.fitted_species = resolve_fitted_species[model_version] + print("Found model version in url_resolve") + print("Resolving model version to model files at url: ", url) + except KeyError: + raise ValueError( + f"Model version {model_version} is not supported.\ + Supported versions are {list(url_resolve.keys())}" + ) + + cachedir = os.path.expanduser( + os.path.join("~", ".cache", "shiftml", str(model_version)) + ) + + # check if model is already downloaded + try: + if not os.path.exists(cachedir): + os.makedirs(cachedir) + model_file = os.path.join(cachedir, "model.pt") + + if os.path.exists(model_file) and force_download: + print( + f"Found {model_version} in cache, but force_download is set to True" + ) + print(f"Removing {model_version} from cache and downloading it again") + os.remove(os.path.join(cachedir, "model.pt")) + download = True + + else: + if os.path.exists(model_file): + print( + f"Found {model_version} in cache,\ + and importing it from here: {cachedir}" + ) + download = False + else: + print("Model not found in cache, downloading it") + download = True + + if download: + urllib.request.urlretrieve(url, os.path.join(cachedir, "model.pt")) + print(f"Downloaded {model_version} and saved to {cachedir}") + + except urllib.error.URLError as e: + print( + f"Failed to download {model_version} from {url}. URL Error: {e.reason}" + ) + raise e + except urllib.error.HTTPError as e: + print( + f"Failed to download {model_version} from {url}.\ + HTTP Error: {e.code} - {e.reason}" + ) + raise e + except Exception as e: + print( + f"An unexpected error occurred while downloading\ + {model_version} from {url}: {e}" + ) + raise e + + super().__init__(model_file) + + def get_cs_iso(self, atoms): + """ + Compute the shielding values for the given atoms object + """ + + assert ( + "mtt::cs_iso" in self.outputs.keys() + ), "model does not support chemical shielding prediction" + + if not set(atoms.get_atomic_numbers()).issubset(self.fitted_species): + raise ValueError( + f"Model is fitted only for the following atomic numbers:\ + {self.fitted_species}. The atomic numbers in the atoms object are:\ + {set(atoms.get_atomic_numbers())}. Please provide an atoms object\ + with only the fitted species." + ) + + out = self.run_model(atoms, self.outputs) + cs_iso = out["mtt::cs_iso"].block(0).values.detach().numpy() + + return cs_iso diff --git a/tests/test_ase.py b/tests/test_ase.py new file mode 100644 index 0000000..4910b18 --- /dev/null +++ b/tests/test_ase.py @@ -0,0 +1,70 @@ +# TODO: test for rotational invariance, translation invariance, +# and permutation, as well as size extensivity +import numpy as np +import pytest +from ase.build import bulk + +from shiftml.ase import ShiftML + +expected_output = np.array([137.5415, 137.5415]) + + +def test_shiftml1_regression(): + """Regression test for the ShiftML1.0 model.""" + + frame = bulk("C", "diamond", a=3.566) + model = ShiftML("ShiftML1.0", force_download=True) + out = model.get_cs_iso(frame) + + assert np.allclose( + out.flatten(), expected_output + ), "ShiftML1 failed regression test" + + +def test_shiftml1_rotational_invariance(): + """Rotational invariance test for the ShiftML1.0 model.""" + + frame = bulk("C", "diamond", a=3.566) + model = ShiftML("ShiftML1.0") + out = model.get_cs_iso(frame) + + assert np.allclose( + out.flatten(), expected_output + ), "ShiftML1 failed regression test" + + # Rotate the frame by 90 degrees about the z-axis + frame.rotate(90, "z") + + out_rotated = model.get_cs_iso(frame) + + assert np.allclose( + out_rotated.flatten(), expected_output + ), "ShiftML1 failed rotational invariance test" + + +def test_shiftml1_size_extensivity_test(): + """Test ShiftML1.0 for translational invariance.""" + + frame = bulk("C", "diamond", a=3.566) + model = ShiftML("ShiftML1.0") + out = model.get_cs_iso(frame) + + assert np.allclose( + out.flatten(), expected_output + ), "ShiftML1 failed regression test" + + frame = frame * (2, 1, 1) + out = model.get_cs_iso(frame) + + assert np.allclose( + out.flatten(), np.stack([expected_output, expected_output]).flatten() + ), "ShiftML1 failed size extensivity test" + + +def test_shftml1_fail_invalid_species(): + """Test ShiftML1.o for non-fitted species""" + + frame = bulk("Si", "diamond", a=3.566) + model = ShiftML("ShiftML1.0") + with pytest.raises(ValueError): + model.get_cs_iso(frame) From 8a3b93b41de967bf4a9a62755baea0a254097750 Mon Sep 17 00:00:00 2001 From: Matthias Kellner Date: Fri, 19 Jul 2024 16:14:14 +0200 Subject: [PATCH 02/11] remove dummy tests --- tests/test_dummy.py | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 tests/test_dummy.py diff --git a/tests/test_dummy.py b/tests/test_dummy.py deleted file mode 100644 index 95c5ab7..0000000 --- a/tests/test_dummy.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_example(): - assert 1 + 1 == 2 From f70327472db812e06c5d1ea29a741a62a359392b Mon Sep 17 00:00:00 2001 From: Matthias Kellner Date: Fri, 19 Jul 2024 16:37:28 +0200 Subject: [PATCH 03/11] disable windows test --- .github/workflows/tests.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1847d7c..938f8f3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,8 +23,6 @@ jobs: python-version: "3.12" - os: macos-14 python-version: "3.12" - - os: windows-2019 - python-version: "3.12" steps: - uses: actions/checkout@v4 with: From 8a50ec195f3cb57197e8ffdfc0a04a9789c332a5 Mon Sep 17 00:00:00 2001 From: Matthias Kellner Date: Fri, 19 Jul 2024 16:49:40 +0200 Subject: [PATCH 04/11] add extra-url --- .github/workflows/tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 938f8f3..bd8aa24 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,3 +40,6 @@ jobs: - name: run tests run: tox + env: + PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu + METATENSOR_TESTS_TORCH_VERSION: ${{ matrix.torch-version }} From 6f9dbb084fd8dfa33ce12c5f7804b576d257c671 Mon Sep 17 00:00:00 2001 From: Matthias Kellner Date: Fri, 19 Jul 2024 16:55:48 +0200 Subject: [PATCH 05/11] reanable windows-test --- .github/workflows/tests.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bd8aa24..ba439d1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,6 +23,8 @@ jobs: python-version: "3.12" - os: macos-14 python-version: "3.12" + - os: windows-2019 + python-version: "3.12" steps: - uses: actions/checkout@v4 with: @@ -42,4 +44,3 @@ jobs: run: tox env: PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu - METATENSOR_TESTS_TORCH_VERSION: ${{ matrix.torch-version }} From ff5e8f182861a30d8825a7c629dd9368cb626a2d Mon Sep 17 00:00:00 2001 From: Matthias Kellner Date: Fri, 19 Jul 2024 16:59:42 +0200 Subject: [PATCH 06/11] disable windows test again --- .github/workflows/tests.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ba439d1..c7854ea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,8 +23,6 @@ jobs: python-version: "3.12" - os: macos-14 python-version: "3.12" - - os: windows-2019 - python-version: "3.12" steps: - uses: actions/checkout@v4 with: From b777c19217f5ce41104e9606551bffe4c05e4140 Mon Sep 17 00:00:00 2001 From: Matthias Kellner Date: Fri, 19 Jul 2024 17:13:20 +0200 Subject: [PATCH 07/11] also fixes doc runjob --- .github/workflows/docs.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 8c8200a..5c7f682 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -25,6 +25,9 @@ jobs: python -m pip install tox - name: build documentation + env: + # Use the CPU only version of torch when building/running the code + PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu run: tox -e docs - name: put documentation in the website From a1ef384102e32254afb01a03b4d2550ca685b255 Mon Sep 17 00:00:00 2001 From: Matthias Kellner Date: Fri, 19 Jul 2024 22:10:25 +0200 Subject: [PATCH 08/11] adds docstring --- src/shiftml/ase/calculator.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/shiftml/ase/calculator.py b/src/shiftml/ase/calculator.py index e892e5e..ad0ce77 100644 --- a/src/shiftml/ase/calculator.py +++ b/src/shiftml/ase/calculator.py @@ -23,6 +23,23 @@ class ShiftML(MetatensorCalculator): """ def __init__(self, model_version, force_download=False): + #add .rst doc string with code snippet + + """ + Initialize the ShiftML calculator + + Parameters + ---------- + model_version : str + The version of the ShiftML model to use. Supported versions are + "ShiftML1.0". + + force_download : bool, optional + If True, the model will be downloaded even if it is already in the cache. + Default is False. + + """ + try: # The rascline import is necessary because From 1ef39d08f92e36629624fea0cda270ed011ca7f0 Mon Sep 17 00:00:00 2001 From: Matthias Kellner Date: Fri, 19 Jul 2024 22:27:32 +0200 Subject: [PATCH 09/11] lint --- src/shiftml/ase/calculator.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/shiftml/ase/calculator.py b/src/shiftml/ase/calculator.py index ad0ce77..130cdf3 100644 --- a/src/shiftml/ase/calculator.py +++ b/src/shiftml/ase/calculator.py @@ -23,8 +23,6 @@ class ShiftML(MetatensorCalculator): """ def __init__(self, model_version, force_download=False): - #add .rst doc string with code snippet - """ Initialize the ShiftML calculator @@ -33,14 +31,11 @@ def __init__(self, model_version, force_download=False): model_version : str The version of the ShiftML model to use. Supported versions are "ShiftML1.0". - force_download : bool, optional If True, the model will be downloaded even if it is already in the cache. Default is False. - """ - try: # The rascline import is necessary because # it is required for the scripted model From 201056e87b38dc9755b7ed2480651c045c8a0013 Mon Sep 17 00:00:00 2001 From: Matthias Kellner Date: Mon, 22 Jul 2024 11:12:30 +0200 Subject: [PATCH 10/11] adds regressiontest, and ShiftML1.1 --- README.md | 13 ++++++---- src/shiftml/ase/calculator.py | 3 +++ tests/test_regression_pretrained.py | 37 +++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 4 deletions(-) create mode 100644 tests/test_regression_pretrained.py diff --git a/README.md b/README.md index b8ad5b9..fd60cec 100644 --- a/README.md +++ b/README.md @@ -49,15 +49,15 @@ conda install -c conda-forge rust To install ShiftML, you can use clone this repository and install it using pip, a pipy release will follow soon: ``` -pip install . +pip install --extra-index-url https://download.pytorch.org/whl/cpu . ``` ## The code that makes it work This project would not have been possible without the following packages: -- [metatensor](https://github.com/lab-cosmo/metatensor) -- [rascaline](https://github.com/Luthaf/rascaline) +- Metadata and model handling: [metatensor](https://github.com/lab-cosmo/metatensor) +- Atomic descriptor engine: [rascaline](https://github.com/Luthaf/rascaline) ## Documentation @@ -72,5 +72,10 @@ Guillaume Fraux ## References -This package is based on the following publications: +This package is based on the following papers: + +- Chemical shifts in molecular solids by machine learning - Paruzzo et al. [[1](https://doi.org/10.1038%2Fs41467-018-06972-x)] +- A Bayesian approach to NMR crystal structure determination - Engel et al. [[2](https://doi.org/10.1039%2Fc9cp04489b)] +- A Machine Learning Model of Chemical Shifts for Chemically and\ +Structurally Diverse Molecular Solids - Cordova et al. [[3](https://doi.org/10.1021/acs.jpcc.2c03854)] diff --git a/src/shiftml/ase/calculator.py b/src/shiftml/ase/calculator.py index 130cdf3..fbd070b 100644 --- a/src/shiftml/ase/calculator.py +++ b/src/shiftml/ase/calculator.py @@ -6,14 +6,17 @@ url_resolve = { "ShiftML1.0": "https://tinyurl.com/3xwec68f", + "ShiftML1.1": "https://tinyurl.com/53ymkhvd", } resolve_outputs = { "ShiftML1.0": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)}, + "ShiftML1.1": {"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True)}, } resolve_fitted_species = { "ShiftML1.0": set([1, 6, 7, 8, 16]), + "ShiftML1.1": set([1, 6, 7, 8, 16]), } diff --git a/tests/test_regression_pretrained.py b/tests/test_regression_pretrained.py new file mode 100644 index 0000000..eca1e79 --- /dev/null +++ b/tests/test_regression_pretrained.py @@ -0,0 +1,37 @@ +import numpy as np +from ase.build import bulk + +from shiftml.ase import ShiftML + +expected_outputs = { + "ShiftML1.0": np.array([137.5415, 137.5415]), + "ShiftML1.1": np.array([163.07251, 163.07251]), +} + + +def test_shiftml1_regression(): + """Regression test for the ShiftML1.0 model.""" + + MODEL = "ShiftML1.0" + + frame = bulk("C", "diamond", a=3.566) + model = ShiftML(MODEL, force_download=True) + out = model.get_cs_iso(frame) + + assert np.allclose( + out.flatten(), expected_outputs[MODEL] + ), "ShiftML1.0 failed regression test" + + +def test_shiftml1_1_regression(): + """Regression test for the ShiftML1.1 model.""" + + MODEL = "ShiftML1.1" + + frame = bulk("C", "diamond", a=3.566) + model = ShiftML(MODEL, force_download=True) + out = model.get_cs_iso(frame) + + assert np.allclose( + out.flatten(), expected_outputs[MODEL] + ), "ShiftML1.1 failed regression test" From 3284dc918ad7436da464e6e63bfabf9d1384a786 Mon Sep 17 00:00:00 2001 From: Matthias Kellner Date: Fri, 26 Jul 2024 12:11:48 +0200 Subject: [PATCH 11/11] Works in suggested changes: -uses logging class -uses platformdirs for cache path resolution -asserts correct rascaline.torch version -checks for correct rascaline.torch -model files named after model version -tests for correct assertion in invalid species test --- CONTRIBUTORS | 4 +++ pyproject.toml | 3 ++- src/shiftml/ase/calculator.py | 47 +++++++++++++++++++++++------------ tests/test_ase.py | 7 +++++- 4 files changed, 43 insertions(+), 18 deletions(-) create mode 100644 CONTRIBUTORS diff --git a/CONTRIBUTORS b/CONTRIBUTORS new file mode 100644 index 0000000..12d49ee --- /dev/null +++ b/CONTRIBUTORS @@ -0,0 +1,4 @@ +Matthias Kellner +Yuxuan Zhang +Ruben Rodriguez Madrid +Guillaume Fraux \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index fb8923b..84b3a3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,8 @@ dependencies = [ "numpy", "ase==3.22.1", "metatensor[torch]", - "rascaline-torch@git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch", + "platformdirs", + "rascaline-torch@git+https://github.com/luthaf/rascaline@e215461#subdirectory=python/rascaline-torch", ] readme = "README.md" diff --git a/src/shiftml/ase/calculator.py b/src/shiftml/ase/calculator.py index fbd070b..e0eac0b 100644 --- a/src/shiftml/ase/calculator.py +++ b/src/shiftml/ase/calculator.py @@ -1,8 +1,15 @@ +import logging import os import urllib.request from metatensor.torch.atomistic import ModelOutput from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator +from platformdirs import user_cache_path + +# For now we set the logging level to DEBUG +logformat = "%(asctime)s - %(levelname)s - %(message)s" +logging.basicConfig(level=logging.DEBUG, format=logformat) + url_resolve = { "ShiftML1.0": "https://tinyurl.com/3xwec68f", @@ -36,6 +43,8 @@ def __init__(self, model_version, force_download=False): "ShiftML1.0". force_download : bool, optional If True, the model will be downloaded even if it is already in the cache. + The chache-dir will be determined via the platformdirs library and should + comply with user settings such as XDG_CACHE_HOME. Default is False. """ @@ -44,8 +53,12 @@ def __init__(self, model_version, force_download=False): # it is required for the scripted model import rascaline.torch - print(rascaline.torch.__version__) - print("rascaline-torch is installed, importing rascaline-torch") + logging.info(rascaline.torch.__version__) + logging.info("rascaline-torch is installed, importing rascaline-torch") + + assert ( + rascaline.torch.__version__ == "0.1.0.dev558" + ), "wrong rascaline-torch installed" except ImportError: raise ImportError( @@ -59,8 +72,8 @@ def __init__(self, model_version, force_download=False): url = url_resolve[model_version] self.outputs = resolve_outputs[model_version] self.fitted_species = resolve_fitted_species[model_version] - print("Found model version in url_resolve") - print("Resolving model version to model files at url: ", url) + logging.info("Found model version in url_resolve") + logging.info("Resolving model version to model files at url: ", url) except KeyError: raise ValueError( f"Model version {model_version} is not supported.\ @@ -68,51 +81,53 @@ def __init__(self, model_version, force_download=False): ) cachedir = os.path.expanduser( - os.path.join("~", ".cache", "shiftml", str(model_version)) + os.path.join(user_cache_path(), "shiftml", str(model_version)) ) # check if model is already downloaded try: if not os.path.exists(cachedir): os.makedirs(cachedir) - model_file = os.path.join(cachedir, "model.pt") + model_file = os.path.join(cachedir, model_version + ".pt") if os.path.exists(model_file) and force_download: - print( + logging.info( f"Found {model_version} in cache, but force_download is set to True" ) - print(f"Removing {model_version} from cache and downloading it again") - os.remove(os.path.join(cachedir, "model.pt")) + logging.info( + f"Removing {model_version} from cache and downloading it again" + ) + os.remove(model_file) download = True else: if os.path.exists(model_file): - print( + logging.info( f"Found {model_version} in cache,\ and importing it from here: {cachedir}" ) download = False else: - print("Model not found in cache, downloading it") + logging.info("Model not found in cache, downloading it") download = True if download: - urllib.request.urlretrieve(url, os.path.join(cachedir, "model.pt")) - print(f"Downloaded {model_version} and saved to {cachedir}") + urllib.request.urlretrieve(url, model_file) + logging.info(f"Downloaded {model_version} and saved to {cachedir}") except urllib.error.URLError as e: - print( + logging.error( f"Failed to download {model_version} from {url}. URL Error: {e.reason}" ) raise e except urllib.error.HTTPError as e: - print( + logging.error( f"Failed to download {model_version} from {url}.\ HTTP Error: {e.code} - {e.reason}" ) raise e except Exception as e: - print( + logging.error( f"An unexpected error occurred while downloading\ {model_version} from {url}: {e}" ) diff --git a/tests/test_ase.py b/tests/test_ase.py index 4910b18..69f8ca3 100644 --- a/tests/test_ase.py +++ b/tests/test_ase.py @@ -66,5 +66,10 @@ def test_shftml1_fail_invalid_species(): frame = bulk("Si", "diamond", a=3.566) model = ShiftML("ShiftML1.0") - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: model.get_cs_iso(frame) + + assert exc_info.type == ValueError + assert "Model is fitted only for the following atomic numbers:" in str( + exc_info.value + )