Added a Scikit-learn like train_test_split method #12

Merged · 4 commits · Apr 5, 2024
2 changes: 2 additions & 0 deletions docs/api/plot.md
@@ -1,3 +1,5 @@
# `splito.plot`

::: splito.plot
    options:
        filters: ["!^_"]
8 changes: 6 additions & 2 deletions docs/api/simpd.md
@@ -1,11 +1,15 @@
# `splito.simpd`

::: splito.simpd.SIMPDSplitter
    options:
        filters: ["!^_"]

---

::: splito.simpd.run_SIMPD
    options:
        filters: ["!^_"]

---

::: splito.simpd.DEFAULT_SIMPD_DESCRIPTORS
    options:
        filters: ["!^_"]
12 changes: 12 additions & 0 deletions docs/api/splito.md
@@ -1,3 +1,15 @@
# `splito`

## Basic usage

::: splito
    options:
        filters: ["train_test_split"]

---

## Advanced usage

::: splito
    options:
        filters: ["!^_", "!train_test_split"]
12 changes: 7 additions & 5 deletions splito/__init__.py
@@ -1,12 +1,12 @@
-from ._mood_split import MOODSplitter
+from ._distribution_split import StratifiedDistributionSplit
from ._kmeans_split import KMeansSplit
-from ._perimeter_split import PerimeterSplit
from ._max_dissimilarity_split import MaxDissimilaritySplit
-from ._scaffold_split import ScaffoldSplit
from ._min_max_split import MolecularMinMaxSplit
from ._molecular_weight import MolecularWeightSplit
-from ._distribution_split import StratifiedDistributionSplit
-
+from ._mood_split import MOODSplitter
+from ._perimeter_split import PerimeterSplit
+from ._scaffold_split import ScaffoldSplit
+from ._split import train_test_split, train_test_split_indices

__all__ = [
"MOODSplitter",
@@ -17,4 +17,6 @@
"StratifiedDistributionSplit",
"MolecularWeightSplit",
"MolecularMinMaxSplit",
"train_test_split",
"train_test_split_indices",
]
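
With this, both new helpers are importable from the package root. A minimal sketch of the new entry points (assuming the `datamol` FreeSolv example data, as in the docstring examples further down):

```python
import datamol as dm
import numpy as np

from splito import train_test_split, train_test_split_indices

data = dm.data.freesolv()
X = np.array([dm.to_fp(dm.to_mol(smi)) for smi in data["smiles"]])
y = data["expt"].values

# Convenience wrapper: returns the four split arrays directly
X_train, X_test, y_train, y_test = train_test_split(X, y, method="random", seed=0)

# Lower-level variant: returns only the train/test indices
train_ind, test_ind = train_test_split_indices(X, y, method="random", seed=0)
```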
33 changes: 24 additions & 9 deletions splito/_distance_split_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import abc
+from typing import Callable, Optional, Sequence, Union
+
+import numpy as np
import datamol as dm
-
-from typing import Callable, Union, Optional, Sequence
-
-import numpy as np
import pandas as pd
from numpy.random import RandomState
from sklearn.metrics import pairwise_distances
from sklearn.model_selection import GroupShuffleSplit
@@ -13,12 +12,26 @@

from .utils import get_kmeans_clusters


# In case users provide a list of SMILES instead of features, we rely on ECFP4 and the Tanimoto distance by default
MOLECULE_DEFAULT_FEATURIZER = dict(name="ecfp", kwargs=dict(radius=2, nBits=2048))
MOLECULE_DEFAULT_DISTANCE_METRIC = "jaccard"


def guess_distance_metric(example):
"""Guess the appropriate distance metric given an exemplary datapoint"""

# By default we use the Euclidean distance
metric = "euclidean"

# For binary vectors we use jaccard
if isinstance(example, pd.DataFrame):
example = example.values # DataFrames would require all().all() otherwise
if ((example == 0) | (example == 1)).all():
metric = "jaccard"

return metric


def convert_to_default_feats_if_smiles(
X: Union[Sequence[str], np.ndarray], metric: str, n_jobs: Optional[int] = None
):
@@ -46,7 +59,7 @@ class DistanceSplitBase(GroupShuffleSplit, abc.ABC):
def __init__(
self,
n_splits=10,
-metric: Optional[Union[str, Callable]] = "euclidean",
+metric: Optional[Union[str, Callable]] = None,
n_jobs: Optional[int] = None,
test_size: Optional[Union[float, int]] = None,
train_size: Optional[Union[float, int]] = None,
@@ -108,10 +121,12 @@ def _iter_indices(
if base_seed is None:
base_seed = 0

-for i in range(self.n_splits):
-    # Convert to ECFP4 if X is a list of smiles
-    X, self._metric = convert_to_default_feats_if_smiles(X, self._metric, n_jobs=self._n_jobs)
+# Convert to ECFP4 if X is a list of smiles
+X, self._metric = convert_to_default_feats_if_smiles(X, self._metric, n_jobs=self._n_jobs)
+if self._metric is None:
+    self._metric = guess_distance_metric(X[0])
+
+for i in range(self.n_splits):
    # Possibly group the data to improve computation efficiency
    groups = self.reduce(X, base_seed + i)

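The new metric guess inspects a single exemplary datapoint: a vector containing only 0s and 1s (e.g. a fingerprint) gets `jaccard`, anything else falls back to `euclidean`. A quick sketch of that intended behavior (illustrative only, not part of the PR's test suite; the import path is the private module touched by this diff):

```python
import numpy as np

from splito._distance_split_base import guess_distance_metric

binary_example = np.array([0, 1, 1, 0, 1])       # fingerprint-like, binary
continuous_example = np.array([0.3, -1.7, 2.4])  # real-valued features

assert guess_distance_metric(binary_example) == "jaccard"
assert guess_distance_metric(continuous_example) == "euclidean"
```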
132 changes: 132 additions & 0 deletions splito/_split.py
@@ -0,0 +1,132 @@
from enum import Enum, unique
from typing import Optional, Sequence, Union

import datamol as dm
import numpy as np
from sklearn.model_selection import ShuffleSplit

from ._distribution_split import StratifiedDistributionSplit
from ._kmeans_split import KMeansSplit
from ._max_dissimilarity_split import MaxDissimilaritySplit
from ._min_max_split import MolecularMinMaxSplit
from ._molecular_weight import MolecularWeightSplit
from ._perimeter_split import PerimeterSplit
from ._scaffold_split import ScaffoldSplit


@unique
class SimpleSplittingMethod(Enum):
RANDOM = ShuffleSplit
KMEANS = KMeansSplit
PERIMETER = PerimeterSplit
MAX_DISSIMILARITY = MaxDissimilaritySplit
SCAFFOLD = ScaffoldSplit
STRATIFIED_DISTRIBUTION = StratifiedDistributionSplit
MOLECULAR_WEIGHT = MolecularWeightSplit
MIN_MAX_DIVERSITY_SPLIT = MolecularMinMaxSplit


def train_test_split_indices(
X: np.ndarray,
y: np.ndarray,
molecules: Optional[Sequence[Union[str, dm.Mol]]] = None,
method: Union[str, SimpleSplittingMethod] = "random",
test_size: float = 0.2,
seed: Optional[int] = None,
n_jobs: Optional[int] = None,
) -> tuple[np.ndarray, np.ndarray]:
"""
Returns the indices of the train and test sets.

Unlike scikit-learn's API, this function accounts for data types that are not represented as
numpy arrays and cannot be directly indexed the way [`train_test_split`][splito.train_test_split]
does. It returns just the indices, leaving the actual split to the caller.

See [`train_test_split`][splito.train_test_split] for more information.
"""
X = np.array(X)
y = np.array(y)

method = SimpleSplittingMethod[method.upper()] if isinstance(method, str) else method

splitter_kwargs = {"test_size": test_size, "random_state": seed}
if method in [
SimpleSplittingMethod.MOLECULAR_WEIGHT,
SimpleSplittingMethod.MIN_MAX_DIVERSITY_SPLIT,
SimpleSplittingMethod.SCAFFOLD,
]:
if molecules is None:
raise ValueError(f"{method.name} requires a list of molecules to be provided.")
if isinstance(molecules[0], dm.Mol):
molecules = dm.utils.parallelized(dm.to_smiles, molecules, n_jobs=n_jobs)
splitter_kwargs["smiles"] = molecules

splitter_cls = method.value
splitter = splitter_cls(**splitter_kwargs)

return next(splitter.split(X, y))


def train_test_split(
X: np.ndarray,
y: np.ndarray,
molecules: Optional[Sequence[Union[str, dm.Mol]]] = None,
method: Union[str, SimpleSplittingMethod] = "random",
test_size: float = 0.2,
seed: Optional[int] = None,
n_jobs: Optional[int] = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
Splits a set of molecules into a train and test set.

Inspired by sklearn.model_selection.train_test_split, this convenience function provides a
less verbose way of using the different splitters.

**Examples**:

Let's first create a toy dataset

```python
import datamol as dm
import numpy as np

data = dm.data.freesolv()
smiles = data["smiles"].values
X = np.array([dm.to_fp(dm.to_mol(smi)) for smi in smiles])
y = data["expt"].values
```

Now we can split our data.

```python
X_train, X_test, y_train, y_test = train_test_split(X, y, method="random")
```

With more parameters:
```python
X_train, X_test, y_train, y_test = train_test_split(X, y, method="random", test_size=0.1, seed=42)
```

Scaffold split (note that you need to specify `molecules`):
```python
X_train, X_test, y_train, y_test = train_test_split(X, y, molecules=smiles, method="scaffold")
```

Distance-based split:
```python
X_train, X_test, y_train, y_test = train_test_split(X, y, method="kmeans")
```

Args:
X: The feature matrix.
y: The target values.
molecules: A list of molecules (SMILES strings or RDKit Mol objects) to be used for the split. Required for the scaffold, molecular weight and min-max diversity methods.
method: The splitting method to use. Defaults to "random".
test_size: The proportion of the dataset to include in the test split.
seed: The seed to use for the random number generator.
n_jobs: The number of jobs to run in parallel.
"""
train_indices, test_indices = train_test_split_indices(
X, y, molecules=molecules, method=method, test_size=test_size, seed=seed, n_jobs=n_jobs
)
return X[train_indices], X[test_indices], y[train_indices], y[test_indices]
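
Because `train_test_split_indices` leaves the indexing to the caller, it also covers containers that cannot be fancy-indexed like numpy arrays. A minimal sketch with a plain Python list (the SMILES and targets are made up for illustration):

```python
import numpy as np

from splito import train_test_split_indices

smiles = ["CCO", "CCN", "c1ccccc1", "CC(=O)O", "CCCC", "CCOC"]
y = np.array([0.1, 0.4, 1.2, 0.8, 0.3, 0.5])

train_ind, test_ind = train_test_split_indices(smiles, y, method="random", test_size=0.5, seed=0)

# Take care of the split manually for the non-numpy container
smiles_train = [smiles[i] for i in train_ind]
smiles_test = [smiles[i] for i in test_ind]
```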
52 changes: 39 additions & 13 deletions tests/conftest.py
@@ -1,11 +1,47 @@
+import datamol as dm
+import numpy as np
import pytest
-
-import datamol as dm

@pytest.fixture(scope="module")
def test_dataset():
data = dm.data.freesolv()
data["mol"] = [dm.to_mol(smi) for smi in data["smiles"]]
data = data.dropna()
return data


@pytest.fixture(scope="module")
def test_dataset_smiles(test_dataset):
return test_dataset["smiles"].values


@pytest.fixture(scope="module")
def test_dataset_targets(test_dataset):
return test_dataset["expt"].values


@pytest.fixture(scope="module")
def test_dataset_features(test_dataset):
return np.array([dm.to_fp(mol) for mol in test_dataset["mol"].values])


@pytest.fixture(scope="module")
-def test_data():
-    return dm.data.freesolv()[:100]
+def test_deployment_set():
+    data = dm.data.solubility()
+    data["mol"] = [dm.to_mol(smi) for smi in data["smiles"]]
+    data = data.dropna()
+    return data


@pytest.fixture(scope="module")
def test_deployment_smiles(test_deployment_set):
return test_deployment_set["smiles"].values


@pytest.fixture(scope="module")
def test_deployment_features(test_deployment_set):
return np.array([dm.to_fp(mol) for mol in test_deployment_set["mol"].values])


@pytest.fixture(scope="module")
@@ -20,13 +56,3 @@ def manual_smiles():
"CN1C=NC2=C1C(=O)NC(=O)N2C",
"CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
]


-@pytest.fixture(scope="module")
-def dataset_smiles(test_data):
-    return test_data["smiles"].values
-
-
-@pytest.fixture(scope="module")
-def dataset_targets(test_data):
-    return test_data["expt"].values
6 changes: 3 additions & 3 deletions tests/test_distribution_split.py
@@ -5,11 +5,11 @@


@pytest.mark.parametrize("algorithm", list(Clustering1D))
-def test_splits_stratified_distribution(dataset_smiles, dataset_targets, algorithm):
+def test_splits_stratified_distribution(test_dataset_smiles, test_dataset_targets, algorithm):
splitter = StratifiedDistributionSplit(algorithm=algorithm, n_splits=2)

-for train_ind, test_ind in splitter.split(dataset_smiles, y=dataset_targets):
-    assert len(train_ind) + len(test_ind) == len(dataset_targets)
+for train_ind, test_ind in splitter.split(test_dataset_smiles, y=test_dataset_targets):
+    assert len(train_ind) + len(test_ind) == len(test_dataset_targets)
assert len(set(train_ind).intersection(set(test_ind))) == 0
assert len(train_ind) > 0 and len(test_ind) > 0

6 changes: 3 additions & 3 deletions tests/test_kmeans_split.py
@@ -3,11 +3,11 @@
from splito import KMeansSplit


-def test_splits_kmeans_default_feats(dataset_smiles):
+def test_splits_kmeans_default_feats(test_dataset_smiles):
splitter = KMeansSplit(n_splits=2)

-for train_ind, test_ind in splitter.split(dataset_smiles):
-    assert len(train_ind) + len(test_ind) == len(dataset_smiles)
+for train_ind, test_ind in splitter.split(test_dataset_smiles):
+    assert len(train_ind) + len(test_ind) == len(test_dataset_smiles)
assert len(set(train_ind).intersection(set(test_ind))) == 0
assert len(train_ind) > 0 and len(test_ind) > 0
assert splitter._cluster_metric == "jaccard"
6 changes: 3 additions & 3 deletions tests/test_max_dissimilarity_split.py
@@ -3,11 +3,11 @@
from splito import MaxDissimilaritySplit


-def test_splits_max_dissimilar_default_feats(dataset_smiles):
+def test_splits_max_dissimilar_default_feats(test_dataset_smiles):
splitter = MaxDissimilaritySplit(n_splits=2)

-for train_ind, test_ind in splitter.split(dataset_smiles):
-    assert len(train_ind) + len(test_ind) == len(dataset_smiles)
+for train_ind, test_ind in splitter.split(test_dataset_smiles):
+    assert len(train_ind) + len(test_ind) == len(test_dataset_smiles)
assert len(set(train_ind).intersection(set(test_ind))) == 0
assert len(train_ind) > 0 and len(test_ind) > 0

6 changes: 3 additions & 3 deletions tests/test_min_max_split.py
@@ -1,10 +1,10 @@
from splito import MolecularMinMaxSplit


-def test_splits_min_max(dataset_smiles):
+def test_splits_min_max(test_dataset_smiles):
splitter = MolecularMinMaxSplit(n_splits=2)

-for train_ind, test_ind in splitter.split(dataset_smiles):
-    assert len(train_ind) + len(test_ind) == len(dataset_smiles)
+for train_ind, test_ind in splitter.split(test_dataset_smiles):
+    assert len(train_ind) + len(test_ind) == len(test_dataset_smiles)
assert len(set(train_ind).intersection(set(test_ind))) == 0
assert len(train_ind) > 0 and len(test_ind) > 0