-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from datamol-io/doc
Splito tutorials and docs
- Loading branch information
Showing
9 changed files
with
1,836 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,36 @@ | ||
# Getting Started | ||
# Overview | ||
|
||
Welcome to the splito documentation! | ||
Splito is a python library designed for aiding in drug discovery by providing powerful methods for parsing and splitting datasets. It enables researchers and chemists to efficiently process data for their ML projects. | ||
|
||
Splito is part of the Datamol ecosystem: <https://datamol.io>. | ||
|
||
## Installation | ||
|
||
Use conda: | ||
|
||
```bash | ||
mamba install -c conda-forge splito | ||
``` | ||
|
||
_**Note:** We highly recommend using a [Conda Python distribution](https://github.com/conda-forge/miniforge) to install Datamol. The package is also pip installable if you need it: `pip install splito`._ | ||
|
||
## Quick API Tour | ||
|
||
```python | ||
import datamol as dm | ||
from splito import ScaffoldSplit | ||
|
||
|
||
# Load some data | ||
data = dm.data.chembl_drugs() | ||
|
||
# Initialize a splitter | ||
splitter = ScaffoldSplit(smiles=data["smiles"].tolist(), n_jobs=-1, test_size=0.2, random_state=111) | ||
|
||
# Generate indices for training set and test set | ||
train_idx, test_idx = next(splitter.split(X=data.smiles.values)) | ||
``` | ||
|
||
## Tutorials | ||
|
||
Check out the [tutorials](tutorials/The_Basics.ipynb) to get started. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "5711a447-d6fb-4b41-be0e-fece8b0733eb", | ||
"metadata": {}, | ||
"source": [ | ||
"As model selection is often argued to improve generalization, we investigate what molecular splitting strategy mimics the deployment distribution the best. The investigation measures the representativeness of various candidate splitting methods.\n", | ||
"\n", | ||
"1. **Compute the distance of each molecule in the deployment set(s) to the training set.** This step gives the “deployment-to-train” distribution which is the target distance distribution that should be mimicked during model selection to better generalize during deployment. If the final model will be retrained on the full-dataset before deployment, the distances must be computed w.r.t the full dataset instead of just the training partition.\n", | ||
"2. **Characterize each splitting method by splitting the dataset into a train and test sets.** Then, compute the distance of each test sample to the training set to get the “test-to-train” distribution. For small datasets, this step should be repeated with multiple seeds to get more reliable estimates of the test-to-train distribution before doing the final split that will be used for training.\n", | ||
"3. **Score the different splitting methods by measuring the distance between their test-to-train distribution and the deployment-to-train distance distribution.** Then, select the splitting method that has the lowest distance for model selection. Here, we use the Jenssen-Shannon distance between the distributions.\n", | ||
"\n", | ||
"This protocol is implemented in the MOODSplitter. See an example of how to use the it below:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "1418374e-1889-43c5-b271-09d80cdea64f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"import datamol as dm\n", | ||
"\n", | ||
"from sklearn.model_selection import ShuffleSplit\n", | ||
"\n", | ||
"import splito" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "9ce88ac7-9547-488a-9814-a8fc0fd944ed", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load the training dataset\n", | ||
"dataset = dm.data.solubility()\n", | ||
"dataset_feat = [dm.to_fp(mol) for mol in dataset.mol]\n", | ||
"\n", | ||
"# Load the deployment set\n", | ||
"# Alternatively, you can also load an array of deployment-to-dataset distance\n", | ||
"deployment_feat = [dm.to_fp(mol) for mol in dm.data.chembl_drugs()[\"smiles\"]]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "51eb16eb-59b2-4c48-b71b-4228546a41e1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Define the candidate splitters\n", | ||
"# Since we use the scikit-learn interface, this can also be sklearn Splitters\n", | ||
"splitters = {\n", | ||
" \"Random\": ShuffleSplit(),\n", | ||
" \"Scaffold\": splito.ScaffoldSplit(dataset.mol.values),\n", | ||
" \"Perimeter\": splito.PerimeterSplit(),\n", | ||
" \"MaxDissimilarity\": splito.MaxDissimilaritySplit(),\n", | ||
"}\n", | ||
"\n", | ||
"splitter = splito.MOODSplitter(splitters)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "81a6b829-196e-4506-918b-e478e1bad419", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[32m2023-09-22 08:57:15.795\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msplito._mood_split\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m308\u001b[0m - \u001b[1mRanked all different splitting methods:\n", | ||
" split representativeness best rank\n", | ||
"0 Random 0.375938 False 4.0\n", | ||
"1 Scaffold 0.492793 False 3.0\n", | ||
"2 Perimeter 0.526232 False 2.0\n", | ||
"3 MaxDissimilarity 0.552740 True 1.0\u001b[0m\n", | ||
"\u001b[32m2023-09-22 08:57:15.795\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msplito._mood_split\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m309\u001b[0m - \u001b[1mSelected MaxDissimilarity as the most representative splitting method\u001b[0m\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"<div>\n", | ||
"<style scoped>\n", | ||
" .dataframe tbody tr th:only-of-type {\n", | ||
" vertical-align: middle;\n", | ||
" }\n", | ||
"\n", | ||
" .dataframe tbody tr th {\n", | ||
" vertical-align: top;\n", | ||
" }\n", | ||
"\n", | ||
" .dataframe thead th {\n", | ||
" text-align: right;\n", | ||
" }\n", | ||
"</style>\n", | ||
"<table border=\"1\" class=\"dataframe\">\n", | ||
" <thead>\n", | ||
" <tr style=\"text-align: right;\">\n", | ||
" <th></th>\n", | ||
" <th>split</th>\n", | ||
" <th>representativeness</th>\n", | ||
" <th>best</th>\n", | ||
" <th>rank</th>\n", | ||
" </tr>\n", | ||
" </thead>\n", | ||
" <tbody>\n", | ||
" <tr>\n", | ||
" <th>0</th>\n", | ||
" <td>Random</td>\n", | ||
" <td>0.375938</td>\n", | ||
" <td>False</td>\n", | ||
" <td>4.0</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>1</th>\n", | ||
" <td>Scaffold</td>\n", | ||
" <td>0.492793</td>\n", | ||
" <td>False</td>\n", | ||
" <td>3.0</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>2</th>\n", | ||
" <td>Perimeter</td>\n", | ||
" <td>0.526232</td>\n", | ||
" <td>False</td>\n", | ||
" <td>2.0</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>3</th>\n", | ||
" <td>MaxDissimilarity</td>\n", | ||
" <td>0.552740</td>\n", | ||
" <td>True</td>\n", | ||
" <td>1.0</td>\n", | ||
" </tr>\n", | ||
" </tbody>\n", | ||
"</table>\n", | ||
"</div>" | ||
], | ||
"text/plain": [ | ||
" split representativeness best rank\n", | ||
"0 Random 0.375938 False 4.0\n", | ||
"1 Scaffold 0.492793 False 3.0\n", | ||
"2 Perimeter 0.526232 False 2.0\n", | ||
"3 MaxDissimilarity 0.552740 True 1.0" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# get the rank of the splitting methods with the givent deployment set\n", | ||
"splitter.fit(X=np.stack(dataset_feat), X_deployment=np.stack(deployment_feat))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "73b49b85-fd20-4e52-b257-1f3a207cf0f3", | ||
"metadata": {}, | ||
"source": [ | ||
"With the given deployment, the best splitting method to ensure the generalization is the `PerimeterSplit`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8d24ad86", | ||
"metadata": {}, | ||
"source": [ | ||
"---\n", | ||
"\n", | ||
"- The End :-)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.