Merge pull request #53 from EBjerrum/handle_invalid_molecules
Handle invalid molecules
EBjerrum authored Oct 13, 2024
2 parents 64c0ded + c4e4c69 commit af28645
Showing 17 changed files with 2,890 additions and 491 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -82,6 +82,7 @@ There are a collection of notebooks in the notebooks directory which demonstrate
- [Using skopt for hyperparameter tuning](https://github.com/EBjerrum/scikit-mol/tree/main/notebooks/08_external_library_skopt.ipynb)
- [Testing different fingerprints as part of the hyperparameter optimization](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/09_Combinatorial_Method_Usage_with_FingerPrint_Transformers.ipynb)
- [Using pandas output for easy feature importance analysis and combine pre-existing values with new computations](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/10_pipeline_pandas_output.ipynb)
- [Working with pipelines and estimators in safe inference mode for handling predictions on batches with invalid SMILES or molecules](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/11_safe_inference.ipynb)

We also put a software note on ChemRxiv. [https://doi.org/10.26434/chemrxiv-2023-fzqwd](https://doi.org/10.26434/chemrxiv-2023-fzqwd)

1,023 changes: 1,023 additions & 0 deletions notebooks/11_safe_inference.ipynb

Large diffs are not rendered by default.

145 changes: 145 additions & 0 deletions notebooks/11_safe_inference.py
@@ -0,0 +1,145 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.1
# kernelspec:
# display_name: vscode
# language: python
# name: python3
# ---

# %% [markdown]
# # Safe inference mode
#
# I think everyone who has worked with SMILES and RDKit sooner or later comes across a SMILES that doesn't parse. It can happen if the SMILES was produced with a different toolkit that is less strict with e.g. valence rules, or maybe a character went missing when it was copied from an email. During curation of a training dataset these SMILES need to be identified and eventually fixed or removed. But what happens when we are finished with our modelling? What kind of molecules and SMILES will users send to the model once it's in deployment? What kind of SMILES will a generative model produce that we need to predict on? We don't know and we won't know. So it's crucial to be able to handle these situations. Scikit-Learn models usually simply fail on the entire batch being predicted if a single sample is invalid. This is why safe_inference_mode was introduced in Scikit-Mol: all transformers got a safe inference mode in which they handle invalid input. How they handle it depends a bit on the transformer, so we will go through the usual steps and see how things have changed with the introduction of the safe inference mode.
#
# NOTE! In the following demonstration I switch on the safe inference mode for each object individually, purely for demonstration purposes. I would not recommend doing that while building and training models; instead I would switch it on _after_ training and evaluation (more on that later). Otherwise there's a risk of silently training on only the 2% of a dataset that didn't fail....
#
# First some imports and test SMILES and molecules.

# %%
from rdkit import Chem
from scikit_mol.conversions import SmilesToMolTransformer

# We have some deprecation warnings; we are addressing them, but they would only distract from this demonstration
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

smiles = ["C1=CC=C(C=C1)F", "C1=CC=C(C=C1)O", "C1=CC=C(C=C1)N", "C1=CC=C(C=C1)Cl"]
smiles_with_invalid = smiles + ["N(C)(C)(C)C", "I'm not a SMILES"]

smi2mol = SmilesToMolTransformer(safe_inference_mode=True)

mols_with_invalid = smi2mol.transform(smiles_with_invalid)
mols_with_invalid

# %% [markdown]
# Without the safe inference mode, the transformation would simply fail, but now we get the expected array back with our RDKit molecules, plus a last entry which is an object of type InvalidMol. InvalidMol is simply a placeholder that records which step failed the conversion and the associated error. InvalidMol evaluates to `False` in boolean contexts, so it is easy to filter away and handle in `if` statements and list comprehensions. For example:

# %%
[mol for mol in mols_with_invalid if mol]

# %% [markdown]
# or

# %%
mask = mols_with_invalid.astype(bool)
mols_with_invalid[mask]
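
# %% [markdown]
# A small illustrative sketch (added here, not part of the original notebook): since `InvalidMol` entries are falsy, they can also be collected and printed to see which step failed and why (the exact printed representation of `InvalidMol` is an assumption and may differ between versions).

# %%
# Collect the entries that evaluate to False, i.e. the InvalidMol placeholders
invalid_entries = [mol for mol in mols_with_invalid.flatten() if not mol]
for entry in invalid_entries:
    print(entry)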

# %% [markdown]
# Having a failsafe SmilesToMol conversion leads us to the next step, featurization. In safe inference mode the transformers return a NumPy masked array instead of a regular NumPy array. The incoming mols are simply evaluated in a boolean context, so e.g. `None`, `np.nan` and other Python objects that evaluate to False will also get masked (e.g. if you use a dataframe with an ROMol column produced with the PandasTools utility).

# %%
from scikit_mol.fingerprints import MorganFingerprintTransformer

mfp = MorganFingerprintTransformer(radius=2, nBits=25, safe_inference_mode=True)
fps = mfp.transform(mols_with_invalid)
fps


# %% [markdown]
# However, scikit-learn models currently accept masked arrays, but they do not respect the mask! So if you fed the masked array directly to a model for training, it would seemingly work, but the invalid samples would all carry the fill_value, meaning you could get surprising results (the short sketch below shows the raw data a model would actually receive). Instead we need the last piece of the puzzle, the SafeInferenceWrapper class.
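
# %% [markdown]
# A quick peek (an illustrative sketch, not part of the original notebook): `filled()` returns the plain array an estimator would actually receive, with the masked rows replaced by the fill value instead of being skipped.

# %%
print("Fill value:", fps.fill_value)
print(fps.filled()[-2:])  # the last two rows correspond to the two invalid inputs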

# %%
from scikit_mol.safeinference import SafeInferenceWrapper
from sklearn.linear_model import LogisticRegression
import numpy as np

regressor = LogisticRegression()
wrapper = SafeInferenceWrapper(regressor, safe_inference_mode=True)
wrapper.fit(fps, [0,1,0,1,0,1])
wrapper.predict(fps)


# %% [markdown]
# Both fit and predict went fine, and the prediction shows `nan` for the invalid entries. However, please note that fitting in safe_inference_mode is not recommended for a training session; you are warned rather than blocked, because maybe you know what you are doing and do it on purpose.
# The SafeInferenceWrapper not only handles rows that are masked in masked arrays, it also checks rows for nonfinite values and filters those away. Some descriptors may occasionally return an inf or nan even though the molecule itself is valid. The masking of nonfinite values can be switched off, e.g. if you are using a model that can handle missing data and only want to filter away invalid molecules.
#
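#
# As a small follow-up sketch (added for illustration, not part of the original notebook): because the wrapper returns `nan` for the rows it could not handle, the valid predictions can be recovered afterwards with a simple finite-value mask.

# %%
predictions = np.asarray(wrapper.predict(fps), dtype=float)  # cast defensively; nan marks skipped rows
valid_mask = np.isfinite(predictions)
print("Valid predictions:", predictions[valid_mask])
print("Rows skipped due to invalid input:", int((~valid_mask).sum()))

# %% [markdown]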
# ## Setting safe_inference_mode post-training
# As I said before, I believe in catching errors and fixing them during training, but what do we do when we need to switch on safe inference mode for all objects in a pipeline? There's of course a tool for that, so let's demo it:

# %%
from scikit_mol.safeinference import set_safe_inference_mode
from sklearn.pipeline import Pipeline

pipe = Pipeline([
("smi2mol", SmilesToMolTransformer()),
("mfp", MorganFingerprintTransformer(radius=2, nBits=25)),
("safe_regressor", SafeInferenceWrapper(LogisticRegression()))
])

pipe.fit(smiles, [1,0,1,0])

print("Without safe inference mode:")
try:
pipe.predict(smiles_with_invalid)
except Exception as e:
print("Prediction failed with exception: ", e)
print()

set_safe_inference_mode(pipe, True)

print("With safe inference mode:")
print(pipe.predict(smiles_with_invalid))

# %% [markdown]
# We see that the prediction fails without safe inference mode, and proceeds once it is conveniently enabled by the `set_safe_inference_mode` utility. The model is now ready to be saved and reused in a more failsafe manner :-)
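
# %% [markdown]
# A minimal sketch of the save/reload step (not part of the original notebook; the file name is just an example): a fitted pipeline with safe inference mode switched on can be persisted like any other scikit-learn estimator, e.g. with joblib.

# %%
import joblib

joblib.dump(pipe, "safe_inference_pipeline.joblib")  # illustrative file name
loaded_pipe = joblib.load("safe_inference_pipeline.joblib")
print(loaded_pipe.predict(smiles_with_invalid))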

# %% [markdown]
# ## Combining safe_inference_mode with pandas output
# A potential issue can appear when we combine safe_inference_mode with the pandas output mode of the transformers. It will work, but depending on the batch something surprising can happen due to the way pandas converts masked NumPy arrays. Let me demonstrate the issue; first we transform a batch without any errors.

# %%
mfp.set_output(transform="pandas")

mols = smi2mol.transform(smiles)

fps = mfp.transform(mols)
fps

# %% [markdown]
# Then let's see what happens when we transform a batch with an invalid molecule:

# %%
fps = mfp.transform(mols_with_invalid)
fps

# %% [markdown]
# The second output no longer contains integers, but floats. As most sklearn models cast input arrays to float32 internally, this difference is likely benign, but that's not guaranteed! So if you want to use pandas output for your production models, do check that the final outputs are the same for the valid rows with and without an invalid row in the batch. Alternatively, the dtype of the transformer output can be switched to float for consistency.

# %%
mfp_float = MorganFingerprintTransformer(radius=2, nBits=25, safe_inference_mode=True, dtype=np.float32)
mfp_float.set_output(transform="pandas")
fps = mfp_float.transform(mols)
fps
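
# %% [markdown]
# A small consistency check along the lines suggested above (an illustrative sketch, not part of the original notebook): the fingerprints of the valid molecules should be identical whether or not invalid entries are present in the batch.

# %%
fps_clean = mfp_float.transform(mols)
fps_mixed = mfp_float.transform(mols_with_invalid)
# Compare the valid rows of both batches; the invalid rows come last in the mixed batch
print(np.allclose(fps_clean.values, fps_mixed.values[: len(fps_clean)]))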

# %% [markdown]
# I hope this new feature of Scikit-Mol will make it even easier to deploy models, even in environments where the validity of the SMILES or molecules is not guaranteed.
1 change: 1 addition & 0 deletions notebooks/README.md
@@ -14,3 +14,4 @@ This is a collection of notebooks in the notebooks directory which demonstrates
- [Using skopt for hyperparameter tuning](https://github.com/EBjerrum/scikit-mol/tree/main/notebooks/08_external_library_skopt.ipynb)
- [Testing different fingerprints as part of the hyperparameter optimization](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/09_Combinatorial_Method_Usage_with_FingerPrint_Transformers.ipynb)
- [Using pandas output for easy feature importance analysis and combine pre-existing values with new computations](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/10_pipeline_pandas_output.ipynb)
- [Working with pipelines and estimators in safe inference mode](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/11_safe_inference.ipynb)
108 changes: 84 additions & 24 deletions scikit_mol/conversions.py
@@ -2,18 +2,44 @@
import multiprocessing
from typing import Union
from rdkit import Chem
from rdkit.rdBase import BlockLogs

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin

from scikit_mol.core import check_transform_input, feature_names_default_mol ,DEFAULT_MOL_COLUMN_NAME
from scikit_mol.core import (
check_transform_input,
feature_names_default_mol,
DEFAULT_MOL_COLUMN_NAME,
InvalidMol,
)

# from scikit_mol._invalid import InvalidMol

class SmilesToMolTransformer(BaseEstimator, TransformerMixin):

def __init__(self, parallel: Union[bool, int] = False):
class SmilesToMolTransformer(BaseEstimator, TransformerMixin):
"""
Transformer for converting SMILES strings to RDKit mol objects.
This transformer can be included in pipelines during development and training,
but the safe inference mode should only be enabled when deploying models for
inference in production environments.
Parameters:
-----------
parallel : Union[bool, int], default=False
If True or int > 1, enables parallel processing.
safe_inference_mode : bool, default=False
If True, enables safeguards for handling invalid data during inference.
This should only be set to True when deploying models to production.
"""

def __init__(
self, parallel: Union[bool, int] = False, safe_inference_mode: bool = False
):
self.parallel = parallel
self.start_method = None #TODO implement handling of start_method
self.start_method = None # TODO implement handling of start_method
self.safe_inference_mode = safe_inference_mode

@feature_names_default_mol
def get_feature_names_out(self, input_features=None):
@@ -39,39 +39,73 @@ def transform(self, X_smiles_list, y=None):
Raises
------
ValueError
Raises ValueError if a SMILES string is unparsable by RDKit
Raises ValueError if a SMILES string is unparsable by RDKit and safe_inference_mode is False
"""


if not self.parallel:
return self._transform(X_smiles_list)
elif self.parallel:
n_processes = self.parallel if self.parallel > 1 else None # Pool(processes=None) autodetects
n_chunks = n_processes*2 if n_processes is not None else multiprocessing.cpu_count()*2 #TODO, tune the number of chunks per child process
n_processes = (
self.parallel if self.parallel > 1 else None
) # Pool(processes=None) autodetects
n_chunks = (
n_processes * 2
if n_processes is not None
else multiprocessing.cpu_count() * 2
) # TODO, tune the number of chunks per child process
with get_context(self.start_method).Pool(processes=n_processes) as pool:
x_chunks = np.array_split(X_smiles_list, n_chunks)
arrays = pool.map(self._transform, x_chunks) #is the helper function a safer way of handling the pickling and child process communication
arr = np.concatenate(arrays)
return arr
x_chunks = np.array_split(X_smiles_list, n_chunks)
arrays = pool.map(
self._transform, x_chunks
) # is the helper function a safer way of handling the pickling and child process communication
arr = np.concatenate(arrays)
return arr

@check_transform_input
def _transform(self, X):
X_out = []
for smiles in X:
mol = Chem.MolFromSmiles(smiles)
if mol:
X_out.append(mol)
else:
raise ValueError(f'Issue with parsing SMILES {smiles}\nYou probably should use the scikit-mol.sanitizer.Sanitizer on your dataset first')

return np.array(X_out).reshape(-1,1)
with BlockLogs():
for smiles in X:
mol = Chem.MolFromSmiles(smiles, sanitize=False)
if mol:
errors = Chem.DetectChemistryProblems(mol)
if errors:
error_message = "\n".join(error.Message() for error in errors)
message = f"Invalid Molecule: {error_message}"
X_out.append(InvalidMol(str(self), message))
else:
Chem.SanitizeMol(mol)
X_out.append(mol)
else:
message = f"Invalid SMILES: {smiles}"
X_out.append(InvalidMol(str(self), message))
if not self.safe_inference_mode and not all(X_out):
fails = [x for x in X_out if not x]
raise ValueError(
f"Invalid input found: {fails}."
) # TODO with this approach we get all errors, but we do process ALL the smiles first which could be slow
return np.array(X_out).reshape(-1, 1)

@check_transform_input
def inverse_transform(self, X_mols_list, y=None): #TODO, maybe the inverse transform should be configurable e.g. isomericSmiles etc.?
def inverse_transform(self, X_mols_list, y=None):
X_out = []

for mol in X_mols_list:
smiles = Chem.MolToSmiles(mol)
X_out.append(smiles)
if isinstance(mol, Chem.Mol):
try:
smiles = Chem.MolToSmiles(mol)
X_out.append(smiles)
except Exception as e:
X_out.append(
InvalidMol(
str(self), f"Error converting Mol to SMILES: {str(e)}"
)
)
else:
X_out.append(InvalidMol(str(self), f"Not a Mol: {mol}"))

if not self.safe_inference_mode and not all(isinstance(x, str) for x in X_out):
fails = [x for x in X_out if not isinstance(x, str)]
raise ValueError(f"Invalid Mols found: {fails}.")

return np.array(X_out).reshape(-1,1)
return np.array(X_out).reshape(-1, 1)