From 621cc96cf58b2bf45084c72825a0026acc8f8bba Mon Sep 17 00:00:00 2001 From: Thibault de Varax <154365476+thibaultdvx@users.noreply.github.com> Date: Wed, 18 Dec 2024 13:14:47 +0100 Subject: [PATCH] First draft for KFold (#684) * make_split * make_kfold * KFold *SingleSplit --------- Co-authored-by: camillebrianceau Co-authored-by: camillebrianceau <57992134+camillebrianceau@users.noreply.github.com> --- clinicadl/API/complicated_case.py | 137 +++--- clinicadl/API/cross_val.py | 74 +-- .../commandline/modules_options/split.py | 9 +- clinicadl/dataset/dataloader/__init__.py | 1 + clinicadl/dataset/dataloader/config.py | 129 +++++ clinicadl/dataset/dataloader/defaults.py | 14 + clinicadl/dataset/datasets/caps_dataset.py | 49 +- clinicadl/dataset/utils.py | 56 ++- .../experiment_manager/experiment_manager.py | 9 +- clinicadl/experiment_manager/maps_manager.py | 6 - clinicadl/interpret/config.py | 2 +- clinicadl/model/clinicadl_model.py | 2 +- clinicadl/predictor/config.py | 2 +- clinicadl/predictor/old_predictor.py | 3 +- clinicadl/splitter/__init__.py | 3 + clinicadl/splitter/config.py | 71 --- clinicadl/splitter/kfold.py | 24 - clinicadl/splitter/make_splits/__init__.py | 2 + clinicadl/splitter/make_splits/kfold.py | 178 +++++++ .../splitter/make_splits/single_split.py | 439 ++++++++++++++++++ clinicadl/splitter/make_splits/utils.py | 78 ++++ clinicadl/splitter/old_splitter.py | 237 ---------- clinicadl/splitter/split.py | 169 ++++++- clinicadl/splitter/split_utils.py | 113 +++++ clinicadl/splitter/splitter/__init__.py | 2 + clinicadl/splitter/splitter/kfold.py | 108 +++++ clinicadl/splitter/splitter/single_split.py | 96 ++++ clinicadl/splitter/splitter/splitter.py | 256 ++++++++++ clinicadl/splitter/test.py | 168 +++++++ clinicadl/trainer/config/train.py | 4 +- clinicadl/trainer/old_trainer.py | 5 +- .../tsvtools/get_metadata/get_metadata.py | 2 + clinicadl/tsvtools/split/split.py | 29 -- clinicadl/tsvtools/tsvtools_utils.py | 3 +- tests/unittests/dataset/test_config.py | 2 +- .../split_test/split/2_fold/kfold_config.json | 8 + .../split_test/split/2_fold/split-0/train.tsv | 49 ++ .../split/2_fold/split-0/train_baseline.tsv | 49 ++ .../2_fold/split-0/validation_baseline.tsv | 49 ++ .../split_test/split/2_fold/split-1/train.tsv | 49 ++ .../split/2_fold/split-1/train_baseline.tsv | 49 ++ .../2_fold/split-1/validation_baseline.tsv | 49 ++ .../split_test/split/single_split_config.json | 15 + .../split/split_categorical_stats.tsv | 11 + .../split/split_continuous_stats.tsv | 5 + .../split_test/split/test_baseline.tsv | 25 + .../caps_example/split_test/split/train.tsv | 155 +++++++ .../split_test/split/train_baseline.tsv | 97 ++++ .../caps_example/subjects_false.tsv | 2 + .../caps_example/subjects_sessions_list.tsv | 8 - .../ressources/caps_example/subjects_t1.tsv | 65 +++ tests/unittests/splitter/test_make_split.py | 195 ++++++++ tests/unittests/splitter/test_splitter.py | 105 +++++ .../train/trainer/test_training_config.py | 2 - 54 files changed, 2923 insertions(+), 546 deletions(-) create mode 100644 clinicadl/dataset/dataloader/__init__.py create mode 100644 clinicadl/dataset/dataloader/config.py create mode 100644 clinicadl/dataset/dataloader/defaults.py delete mode 100644 clinicadl/splitter/config.py delete mode 100644 clinicadl/splitter/kfold.py create mode 100644 clinicadl/splitter/make_splits/__init__.py create mode 100644 clinicadl/splitter/make_splits/kfold.py create mode 100644 clinicadl/splitter/make_splits/single_split.py create mode 100644 clinicadl/splitter/make_splits/utils.py 
delete mode 100644 clinicadl/splitter/old_splitter.py
 create mode 100644 clinicadl/splitter/splitter/__init__.py
 create mode 100644 clinicadl/splitter/splitter/kfold.py
 create mode 100644 clinicadl/splitter/splitter/single_split.py
 create mode 100644 clinicadl/splitter/splitter/splitter.py
 create mode 100644 clinicadl/splitter/test.py
 create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/kfold_config.json
 create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train.tsv
 create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train_baseline.tsv
 create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/validation_baseline.tsv
 create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train.tsv
 create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train_baseline.tsv
 create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/validation_baseline.tsv
 create mode 100644 tests/unittests/ressources/caps_example/split_test/split/single_split_config.json
 create mode 100644 tests/unittests/ressources/caps_example/split_test/split/split_categorical_stats.tsv
 create mode 100644 tests/unittests/ressources/caps_example/split_test/split/split_continuous_stats.tsv
 create mode 100644 tests/unittests/ressources/caps_example/split_test/split/test_baseline.tsv
 create mode 100644 tests/unittests/ressources/caps_example/split_test/split/train.tsv
 create mode 100644 tests/unittests/ressources/caps_example/split_test/split/train_baseline.tsv
 create mode 100644 tests/unittests/ressources/caps_example/subjects_false.tsv
 create mode 100644 tests/unittests/ressources/caps_example/subjects_t1.tsv
 create mode 100644 tests/unittests/splitter/test_make_split.py
 create mode 100644 tests/unittests/splitter/test_splitter.py

diff --git a/clinicadl/API/complicated_case.py b/clinicadl/API/complicated_case.py
index aa126629b..4afe5050a 100644
--- a/clinicadl/API/complicated_case.py
+++ b/clinicadl/API/complicated_case.py
@@ -2,19 +2,15 @@
 
 import torchio.transforms as transforms
 
-from clinicadl.dataset.caps_dataset import (
-    CapsDatasetPatch,
-    CapsDatasetRoi,
-    CapsDatasetSlice,
-)
-from clinicadl.dataset.caps_reader import CapsReader
-from clinicadl.dataset.concat import ConcatDataset
-from clinicadl.dataset.config.extraction import ExtractionConfig
-from clinicadl.dataset.config.preprocessing import (
-    PreprocessingConfig,
-    T1PreprocessingConfig,
-)
 from clinicadl.dataset.dataloader_config import DataLoaderConfig
+from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.dataset.datasets.concat import ConcatDataset
+from clinicadl.dataset.preprocessing import (
+    PreprocessingCustom,
+    PreprocessingPET,
+    PreprocessingT1,
+)
+from clinicadl.dataset.readers.caps_reader import CapsReader
 from clinicadl.experiment_manager.experiment_manager import ExperimentManager
 from clinicadl.losses.config import CrossEntropyLossConfig
 from clinicadl.losses.factory import get_loss_function
@@ -31,6 +27,7 @@
 from clinicadl.splitter.kfold import KFolder
 from clinicadl.splitter.split import get_single_split, split_tsv
 from clinicadl.trainer.trainer import Trainer
+from clinicadl.transforms.extraction import ROI, BaseExtraction, Image, Patch, Slice
 from clinicadl.transforms.transforms import Transforms
 
 # Create the Maps Manager / Read/write manager /
@@ -40,60 +37,47 @@
 )  # to add to the manager: mlflow / profiler / etc.
 
 caps_directory = Path("caps_directory")  # output of clinica pipelines
-caps_reader = CapsReader(caps_directory, manager=manager)
-
-preprocessing_1 = caps_reader.get_preprocessing("t1-linear")
-caps_reader.prepare_data(
-    preprocessing=preprocessing_1, data_tsv=Path(""), n_proc=2
-)  # don't return anything -> just extract the image tensor and compute some information for each images
-
-transforms_1 = Transforms(
-    object_augmentation=[transforms.Crop, transforms.Transform],
-    image_augmentation=[transforms.Crop, transforms.Transform],
-    extraction=ExtractionPatchConfig(patch_size=3),
-    image_transforms=[transforms.Blur, transforms.Ghosting],
-    object_transforms=[transforms.BiasField, transforms.Motion],
-)  # not mandatory
-
-preprocessing_2 = caps_reader.get_preprocessing("pet-linear")
-caps_reader.prepare_data(
-    preprocessing=preprocessing_1, data_tsv=Path(""), n_proc=2
-)  # to extract the tensor of the PET file this time
-transforms_2 = Transforms(
-    object_augmentation=[transforms.Crop, transforms.Transform],
-    image_augmentation=[transforms.Crop, transforms.Transform],
-    extraction=ExtractionSliceConfig(),
-    image_transforms=[transforms.Blur, transforms.Ghosting],
-    object_transforms=[transforms.BiasField, transforms.Motion],
+
+sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv")
+preprocessing_t1 = PreprocessingT1()
+transforms_image = Transforms(
+    image_augmentation=[transforms.RandomMotion()],
+    extraction=Image(),
+    image_transforms=[transforms.Blur((0.5, 0.6, 0.3))],
+)
+
+print("T1 and image")
+
+dataset_t1_image = CapsDataset(
+    caps_directory=caps_directory,
+    data=sub_ses_t1,
+    preprocessing=preprocessing_t1,
+    transforms=transforms_image,
+)
+dataset_t1_image.prepare_data(n_proc=2)  # to extract the tensor of the T1 file
+
+
+sub_ses_pet_45 = Path(
+    "/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_pet_18FAV45.tsv"
 )
+preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2")
 
-sub_ses_tsv_1 = Path("")
-split_dir_1 = split_tsv(sub_ses_tsv_1)  # -> creer un test.tsv et un train.tsv
-
-sub_ses_tsv_2 = Path("")
-split_dir_2 = split_tsv(sub_ses_tsv_2)
-
-dataset_t1_roi = caps_reader.get_dataset(
-    preprocessing=preprocessing_1,
-    sub_ses_tsv=split_dir_1 / "train.tsv",
-    transforms=transforms_1,
-)  # do we give config or object for transforms ?
-dataset_pet_patch = caps_reader.get_dataset(
-    preprocessing=preprocessing_2,
-    sub_ses_tsv=split_dir_2 / "train.tsv",
-    transforms=transforms_2,
+dataset_pet_image = CapsDataset(
+    caps_directory=caps_directory,
+    data=sub_ses_pet_45,
+    preprocessing=preprocessing_pet_45,
+    transforms=transforms_image,
 )
+dataset_pet_image.prepare_data(n_proc=2)  # to extract the tensor of the PET file
+
 
-dataset_multi_modality_multi_extract = ConcatDataset(
+dataset_multi_modality = ConcatDataset(
     [
-        dataset_t1_roi,
-        dataset_pet_patch,
-        caps_reader.get_dataset_from_json(json_path=Path("dataset.json")),
+        dataset_t1_image,
+        dataset_pet_image,
     ]
 )  # 3 input train.tsv files to concatenate, and the same goes for the transforms, so be careful
 
-# TODO : think about adding transforms in extract_json
-
 config_file = Path("config_file")
 trainer = Trainer.from_json(config_file=config_file, manager=manager)
@@ -106,6 +90,26 @@
 
 dataloader_config = DataLoaderConfig(n_procs=3, batch_size=10)
 
+
+# CASE 1
+
+# Prerequisite: already have files with the train and validation lists
+split_dir = make_kfold(
+    "dataset.tsv"
+)  # reads dataset.tsv => does the k-fold => writes the output to split_dir
+splitter = KFolder(
+    dataset_multi_modality, split_dir
+)  # it is more like an iterable of dataloaders
+
+# CASE 2
+splitter = KFolder(caps_dataset=dataset_t1_image)
+splitter.make_splits(n_splits=3)
+splitter.write(split_dir)
+
+# or
+splitter = KFolder(caps_dataset=dataset_t1_image)
+splitter.read(split_dir)
+
 for split in splitter.get_splits(splits=(0, 3, 4), dataloader_config=dataloader_config):
     # clearly define what the split object contains
@@ -125,19 +129,12 @@
 
 # TEST
-preprocessing_test = caps_reader.get_preprocessing("pet-linear")
-transforms_test = Transforms(
-    object_augmentation=[transforms.Crop, transforms.Transform],
-    image_augmentation=[transforms.Crop, transforms.Transform],
-    extraction=ExtractioImageConfig(),
-    image_transforms=[transforms.Blur, transforms.Ghosting],
-    object_transforms=[transforms.BiasField, transforms.Motion],
-)
-dataset_test = caps_reader.get_dataset(
-    preprocessing=preprocessing_test,
-    sub_ses_tsv=split_dir_1 / "test.tsv",  # test only on data from the first dataset
-    transforms=transforms_test,
+dataset_test = CapsDataset(
+    caps_directory=caps_directory,
+    preprocessing=preprocessing_t1,
+    data=Path("test.tsv"),  # test only on data from the first dataset
+    transforms=transforms_image,
 )
 
 predictor = Predictor(model=model, manager=manager)
diff --git a/clinicadl/API/cross_val.py b/clinicadl/API/cross_val.py
index 0efa2f195..54f7b9d6b 100644
--- a/clinicadl/API/cross_val.py
+++ b/clinicadl/API/cross_val.py
@@ -1,80 +1,38 @@
 from pathlib import Path
 
-import torchio.transforms as transforms
-
-from clinicadl.dataset.caps_dataset import (
-    CapsDatasetPatch,
-    CapsDatasetRoi,
-    CapsDatasetSlice,
-)
-from clinicadl.dataset.caps_reader import CapsReader
-from clinicadl.dataset.concat import ConcatDataset
-from clinicadl.dataset.config.extraction import ExtractionConfig
-from clinicadl.dataset.config.preprocessing import (
-    PreprocessingConfig,
-    T1PreprocessingConfig,
-)
-from clinicadl.dataset.dataloader_config import DataLoaderConfig
+from clinicadl.dataset.datasets.caps_dataset import CapsDataset
 from clinicadl.experiment_manager.experiment_manager import ExperimentManager
-from clinicadl.losses.config import CrossEntropyLossConfig
-from clinicadl.losses.factory import get_loss_function
-from clinicadl.model.clinicadl_model import ClinicaDLModel
-from clinicadl.networks.config import ImplementedNetworks
-from clinicadl.networks.factory import (
-    ConvEncoderOptions,
-    create_network_config,
-    get_network_from_config,
-)
-from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig
-from clinicadl.optimization.optimizer.factory import get_optimizer
 from clinicadl.predictor.predictor import Predictor
-from clinicadl.splitter.kfold import KFolder
-from clinicadl.splitter.split import get_single_split, split_tsv
+from clinicadl.dataset.dataloader import DataLoaderConfig
+from clinicadl.splitter.splitter.kfold import KFold
 from clinicadl.trainer.trainer import Trainer
-from clinicadl.transforms.config import TransformsConfig
 
 # SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING
 
 maps_path = Path("/")
 manager = ExperimentManager(maps_path, overwrite=False)
 
-dataset_t1_image = CapsDatasetPatch.from_json(
-    extraction=Path("test.json"),
-    sub_ses_tsv=Path("split_dir") / "train.tsv",
-)
+dataset_t1_image = CapsDataset.from_json(Path("json_path.json"))
+
 
 config_file = Path("config_file")
 trainer = Trainer.from_json(
     config_file=config_file, manager=manager
 )  # gpu, amp, fsdp, seed
 
-# CAS CROSS-VALIDATION
-splitter = KFolder(caps_dataset=dataset_t1_image, manager=manager)
-split_dir = splitter.make_splits(
-    n_splits=3,
-    output_dir=Path(""),
-    data_tsv=Path("labels.tsv"),
-    subset_name="validation",
-    stratification="",
-)  # Optional data tsv and output_dir
-# n_splits must be >1
-# for the single split case, this method output a path to the directory containing the train and test tsv files so we should have the same output here
+splitter = KFold(dataset=dataset_t1_image)
+splitter.make_splits(n_splits=3)
+split_dir = Path("")
+splitter.write(split_dir)
 
-# CAS EXISTING CROSS-VALIDATION
-splitter = KFolder.from_split_dir(caps_dataset=dataset_t1_image, manager=manager)
+splitter.read(split_dir)
 
 # define the needed parameters for the dataloader
-dataloader_config = DataLoaderConfig(n_procs=3, batch_size=10)
+dataloader_config = DataLoaderConfig(num_workers=3, batch_size=10)
 
-for split in splitter.get_splits(splits=(0, 3, 4), dataloader_config=dataloader_config):
-    # bien définir ce qu'il y a dans l'objet split
-    network_config = create_network_config(ImplementedNetworks.CNN)(
-        in_shape=[2, 2, 2],
-        num_outputs=1,
-        conv_args=ConvEncoderOptions(channels=[3, 2, 2]),
-    )
-    optimizer, _ = get_optimizer(network, AdamConfig())
-    model = ClinicaDLModel(network=network_config, loss=nn.MSE(), optimizer=optimizer)
+for split in splitter.get_splits(splits=(0, 3, 4)):
+    print(split)
 
+    split.build_train_loader(dataloader_config)
+    split.build_val_loader(num_workers=3, batch_size=10)
 
-    trainer.train(model, split)
-    # le trainer va instancier un predictor/valdiator dans le train ou dans le init
+    print(split)
diff --git a/clinicadl/commandline/modules_options/split.py b/clinicadl/commandline/modules_options/split.py
index f7c0a8882..4579a6f7f 100644
--- a/clinicadl/commandline/modules_options/split.py
+++ b/clinicadl/commandline/modules_options/split.py
@@ -2,13 +2,14 @@
 
 from clinicadl.config.config_utils import get_default_from_config_class as get_default
 from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.splitter.config import SplitConfig
+from clinicadl.splitter.splitter.kfold import KFoldConfig
+from clinicadl.splitter.splitter.single_split import SingleSplitConfig
 
 # Cross Validation
 n_splits = click.option(
     "--n_splits",
-    type=get_type("n_splits", SplitConfig),
-    default=get_default("n_splits", SplitConfig),
+    type=get_type("n_splits", KFoldConfig),
+    default=get_default("n_splits", KFoldConfig),
     help="If a value is given for k, data of a k-fold CV will be loaded. "
     "Default value (0) will load a single split.",
     show_default=True,
@@ -17,7 +18,7 @@
     "--split",
     "-s",
     type=int,  # get_type("split", config.ValidationConfig),
-    default=get_default("split", SplitConfig),
+    default=get_default("split", SingleSplitConfig),
     multiple=True,
     help="Train the list of given splits. By default, all the splits are trained.",
     show_default=True,
diff --git a/clinicadl/dataset/dataloader/__init__.py b/clinicadl/dataset/dataloader/__init__.py
new file mode 100644
index 000000000..eda7a269a
--- /dev/null
+++ b/clinicadl/dataset/dataloader/__init__.py
@@ -0,0 +1 @@
+from .config import DataLoaderConfig
diff --git a/clinicadl/dataset/dataloader/config.py b/clinicadl/dataset/dataloader/config.py
new file mode 100644
index 000000000..7c4acfb6d
--- /dev/null
+++ b/clinicadl/dataset/dataloader/config.py
@@ -0,0 +1,129 @@
+from typing import Optional
+
+from pydantic import NonNegativeInt, PositiveInt
+from torch.utils.data import DataLoader, DistributedSampler, Sampler
+from torch.utils.data import WeightedRandomSampler as BaseWeightedRandomSampler
+
+from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.utils.config import ClinicaDLConfig
+from clinicadl.utils.seed import pl_worker_init_function
+
+from .defaults import (
+    BATCH_SIZE,
+    DP_DEGREE,
+    DROP_LAST,
+    NUM_WORKERS,
+    PIN_MEMORY,
+    PREFETCH_FACTOR,
+    RANK,
+    SAMPLING_WEIGHTS,
+    SHUFFLE,
+)
+
+
+class WeightedRandomSampler(BaseWeightedRandomSampler):
+    """
+    Modifies PyTorch's WeightedRandomSampler to have a similar behavior to
+    PyTorch's DistributedSampler.
+    """
+
+    def set_epoch(self, epoch: int) -> None:
+        """
+        No-op method that mimics 'set_epoch' of PyTorch's DistributedSampler,
+        so that sampler.set_epoch() can always be called, no matter the sampler.
+        """
+
+
+class DataLoaderConfig(ClinicaDLConfig):
+    """Config class to parametrize a PyTorch DataLoader."""
+
+    batch_size: PositiveInt = BATCH_SIZE
+    sampling_weights: Optional[str] = SAMPLING_WEIGHTS
+    shuffle: bool = SHUFFLE
+    drop_last: bool = DROP_LAST
+    num_workers: NonNegativeInt = NUM_WORKERS
+    prefetch_factor: Optional[NonNegativeInt] = PREFETCH_FACTOR
+    pin_memory: bool = PIN_MEMORY
+
+    def _generate_sampler(
+        self,
+        dataset: CapsDataset,
+        dp_degree: Optional[int] = DP_DEGREE,
+        rank: Optional[int] = RANK,
+    ) -> Sampler:
+        """
+        Returns a WeightedRandomSampler if self.sampling_weights is not None, otherwise
+        a DistributedSampler, even when data parallelism is not performed (in this case
+        the degree of data parallelism is set to 1, so it is equivalent to a simple PyTorch
+        RandomSampler if self.shuffle is True, or no sampler if self.shuffle is False).
+        """
+        if (rank is not None and dp_degree is None) or (
+            dp_degree is not None and rank is None
+        ):
+            raise ValueError(
+                "For data parallelism, none of 'dp_degree' and 'rank' can be None. "
+                f"Got rank={rank} and dp_degree={dp_degree}"
+            )
+        distributed = dp_degree is not None
+
+        if self.sampling_weights:
+            try:
+                weights = dataset.df[self.sampling_weights].values.astype(float)
+            except KeyError as exc:
+                raise KeyError(
+                    f"Got {self.sampling_weights} for 'sampling_weights' but there is no "
+                    "column named like that in the dataframe of the dataset."
+                ) from exc
+            length = (
+                len(weights) // dp_degree + int(rank < len(weights) % dp_degree)
+                if distributed
+                else len(weights)
+            )
+            sampler = WeightedRandomSampler(weights, num_samples=length)  # type: ignore
+        else:
+            if not distributed:
+                dp_degree = 1
+                rank = 0
+            sampler = DistributedSampler(
+                dataset,
+                num_replicas=dp_degree,
+                rank=rank,
+                shuffle=self.shuffle,
+                drop_last=False,  # not the same as self.drop_last
+            )
+
+        return sampler
+
+    def get_dataloader(
+        self,
+        dataset: CapsDataset,
+        dp_degree: Optional[int] = DP_DEGREE,
+        rank: Optional[int] = RANK,
+    ) -> DataLoader:
+        """
+        To get a dataloader from a dataset. The dataloader is parametrized
+        with the options stored in this configuration class.
+
+        Parameters
+        ----------
+        dataset : CapsDataset
+            The dataset to put in a Dataloader.
+        dp_degree : Optional[int] (optional, default=None)
+            The degree of data parallelism. None if no data parallelism.
+        rank : Optional[int] (optional, default=None)
+            Process id within the data parallelism communicator.
+            None if no data parallelism.
+
+        Returns
+        -------
+        DataLoader
+            The dataloader that wraps the dataset.
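+
+        Examples
+        --------
+        A minimal sketch, assuming ``dataset`` is an already-built CapsDataset
+        (hypothetical object) and there is no data parallelism:
+
+        >>> config = DataLoaderConfig(batch_size=8, num_workers=2, shuffle=True)
+        >>> loader = config.get_dataloader(dataset)
+        >>> loader.batch_size
+        8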
+        """
+        loader = DataLoader(
+            dataset=dataset,
+            sampler=self._generate_sampler(dataset, dp_degree, rank),
+            worker_init_fn=pl_worker_init_function,
+            # 'sampling_weights' is not a DataLoader argument, and shuffling is
+            # already handled by the sampler, so both must be excluded here:
+            **self.model_dump(exclude={"sampling_weights", "shuffle"}),
+        )
+
+        return loader
diff --git a/clinicadl/dataset/dataloader/defaults.py b/clinicadl/dataset/dataloader/defaults.py
new file mode 100644
index 000000000..ed493d9fb
--- /dev/null
+++ b/clinicadl/dataset/dataloader/defaults.py
@@ -0,0 +1,14 @@
+BATCH_SIZE = 1
+
+SAMPLING_WEIGHTS = None  # no weighted sampling
+
+SHUFFLE = False
+DROP_LAST = False
+
+NUM_WORKERS = 0  # main process loads data
+PREFETCH_FACTOR = None
+
+PIN_MEMORY = True  # training is supposed to be on a GPU
+
+DP_DEGREE = None  # no data parallelism
+RANK = None
diff --git a/clinicadl/dataset/datasets/caps_dataset.py b/clinicadl/dataset/datasets/caps_dataset.py
index 95535d73f..5a877c8d2 100644
--- a/clinicadl/dataset/datasets/caps_dataset.py
+++ b/clinicadl/dataset/datasets/caps_dataset.py
@@ -1,5 +1,7 @@
 # coding: utf8
+from __future__ import annotations
+from copy import deepcopy
 from logging import getLogger
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
@@ -12,6 +14,7 @@
 from torch import save as save_tensor
 from torch.utils.data import Dataset
 from tqdm import tqdm
+from typing_extensions import Self
 
 from clinicadl.dataset.preprocessing import BasePreprocessing
 from clinicadl.dataset.readers.caps_reader import CapsReader
@@ -190,7 +193,18 @@ def _get_df_from_input(
             )
             logger.info(f"Creating a subject session TSV file at {data}")
 
-        elif isinstance(data, str):
+        df = self._check_data_instance(data)
+        self.df = df
+
+        if not self._check_preprocessing_config():
+            raise ClinicaDLCAPSError(
+                f"The DataFrame does not match the preprocessing configuration: {self.preprocessing.preprocessing.value}"
+            )
+
+        return df
+
+    def _check_data_instance(self, data: Optional[Union[pd.DataFrame, Path]] = None):
+        if isinstance(data, str):
             data = Path(data)
 
         if isinstance(data, Path):
@@ -200,15 +214,9 @@ def _get_df_from_input(
                     "Please ensure the file path is correct and accessible."
+                )
             df = tsv_to_df(data)
-        elif isinstance(data, pd.DataFrame):
+        if isinstance(data, pd.DataFrame):
             df = check_df(data)
 
-        self.df = df
-        if not self._check_preprocessing_config():
-            raise ClinicaDLCAPSError(
-                f"The DataFrame does not match the preprocessing configuration: {self.preprocessing.preprocessing.value}"
-            )
-
         return df
 
     def _check_preprocessing_config(self) -> bool:
@@ -535,3 +543,28 @@ def prepare_image(participant, session):
             self._get_participants_sessions_couple(), desc="Preparing data"
         )
     )
+
+    def subset(self, data: Optional[Union[pd.DataFrame, Path]] = None) -> CapsDataset:
+        df = self._check_data_instance(data)
+
+        common_rows = pd.merge(df, self.df, how="inner")
+        all_included = len(common_rows) == len(df)
+
+        if not all_included:
+            missing_rows = pd.concat(
+                [df, common_rows], ignore_index=True
+            ).drop_duplicates(keep=False)
+
+            err_message = "Missing rows: \n"
+            for row in missing_rows.itertuples(index=False):
+                err_message += f"  - {row} \n"
+
+            raise ClinicaDLTSVError(
+                "Some couples (participant_id, session_id) are not in the dataset. "
+                + err_message
+            )
+
+        dataset = deepcopy(self)
+        dataset.df = df
+
+        return dataset
diff --git a/clinicadl/dataset/utils.py b/clinicadl/dataset/utils.py
index 6316b00bd..a1f2c1889 100644
--- a/clinicadl/dataset/utils.py
+++ b/clinicadl/dataset/utils.py
@@ -6,6 +6,7 @@
 
 import pandas as pd
 import torch
+import torchio as tio
 from pydantic import BaseModel, ConfigDict
 
 from clinicadl.dataset import preprocessing
@@ -51,6 +52,34 @@ class CapsDatasetSample(BaseModel):
     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
 
 
+def df_to_tsv(name: str, results_path: Path, df, baseline: bool = False) -> None:
+    """
+    Write a DataFrame into a TSV file, dropping duplicates.
+
+    Parameters
+    ----------
+    name: str
+        Name of the TSV file.
+    results_path: Path
+        Path to the folder.
+    df: DataFrame
+        DataFrame to write to a TSV file.
+        Columns must include ["participant_id", "session_id"].
+    baseline: bool
+        If True, only the baseline session of each subject is kept.
+    """
+
+    df.sort_values(by=["participant_id", "session_id"], inplace=True)
+    if baseline:
+        df.drop_duplicates(subset=["participant_id"], keep="first", inplace=True)
+    else:
+        df.drop_duplicates(
+            subset=["participant_id", "session_id"], keep="first", inplace=True
+        )
+    # df = df[["participant_id", "session_id"]]
+    df.to_csv(results_path / name, sep="\t", index=False)
+
+
 def tsv_to_df(tsv_path: Path) -> pd.DataFrame:
     """
     Converts a TSV file to a Pandas DataFrame.
@@ -94,7 +123,32 @@ def check_df(df: pd.DataFrame) -> pd.DataFrame:
             f"The data file is not in the correct format. "
             f"Columns should include {PARTICIPANT_ID, SESSION_ID}"
         )
-    df.reset_index(inplace=True)
+
     return df
+
+
+def reset_index(df: pd.DataFrame) -> pd.DataFrame:
+    """
+    Resets the index of a DataFrame to the default index.
+
+    Args:
+        df (pd.DataFrame): The DataFrame to be reset.
+
+    Returns:
+        pd.DataFrame: The DataFrame with the default index.
+
+    Note:
+        The old index is dropped only when the DataFrame has a MultiIndex whose
+        names differ from ('participant_id', 'session_id'); otherwise the index
+        is inserted back as columns.
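+
+    Example:
+        A minimal sketch with hypothetical data:
+
+        >>> df = pd.DataFrame(
+        ...     {"participant_id": ["sub-01"], "session_id": ["ses-M000"], "age": [72]}
+        ... ).set_index(["participant_id", "session_id"])
+        >>> reset_index(df).columns.tolist()
+        ['participant_id', 'session_id', 'age']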
+ """ + + drop = False + if isinstance(df.index, pd.MultiIndex): + if set(df.index.names) != {PARTICIPANT_ID, SESSION_ID}: + drop = True + + df.reset_index(inplace=True, drop=drop) + return df diff --git a/clinicadl/experiment_manager/experiment_manager.py b/clinicadl/experiment_manager/experiment_manager.py index 6edad60d6..dea504e6e 100644 --- a/clinicadl/experiment_manager/experiment_manager.py +++ b/clinicadl/experiment_manager/experiment_manager.py @@ -8,15 +8,14 @@ import pandas as pd from pydantic import BaseModel -from clinicadl.dataset.caps_reader import CapsReader -from clinicadl.dataset.config.extraction import ExtractionConfig -from clinicadl.dataset.preprocessing import PreprocessingConfig +from clinicadl.dataset.preprocessing import BasePreprocessing +from clinicadl.dataset.readers import CapsReader from clinicadl.metrics.old_metrics.utils import check_selection_metric from clinicadl.model.clinicadl_model import ClinicaDLModel from clinicadl.networks.config import NetworkConfig from clinicadl.networks.factory import get_network_from_config -from clinicadl.splitter.kfold import KFolder from clinicadl.splitter.split_utils import print_description_log +from clinicadl.transforms.extraction import Extraction from clinicadl.utils.exceptions import MAPSError from clinicadl.utils.iotools.data_utils import load_data_test from clinicadl.utils.iotools.utils import path_decoder, path_encoder @@ -71,7 +70,7 @@ def information_log(self) -> Path: def get_info_from_json( self, - ) -> tuple[PreprocessingConfig, ExtractionConfig, CapsReader, ClinicaDLModel]: + ) -> tuple[PreprocessingConfig, Extraction, CapsReader, ClinicaDLModel]: """Reads the maps.json file and returns its content.""" # I don't know if this is a useful function if self.maps_json.is_file(): diff --git a/clinicadl/experiment_manager/maps_manager.py b/clinicadl/experiment_manager/maps_manager.py index c8c9a8eb4..0a3903136 100644 --- a/clinicadl/experiment_manager/maps_manager.py +++ b/clinicadl/experiment_manager/maps_manager.py @@ -9,17 +9,11 @@ import pandas as pd import torch -from clinicadl.dataset.caps_dataset import ( - return_dataset, -) -from clinicadl.dataset.caps_dataset.caps_dataset_utils import read_json from clinicadl.metrics.old_metrics.metric_module import MetricModule from clinicadl.metrics.old_metrics.utils import ( check_selection_metric, ) from clinicadl.predictor.utils import get_prediction -from clinicadl.splitter.config import SplitterConfig -from clinicadl.splitter.old_splitter import Splitter from clinicadl.trainer.tasks_utils import ( ensemble_prediction, evaluation_metrics, diff --git a/clinicadl/interpret/config.py b/clinicadl/interpret/config.py index d5cf76390..5a1ebeaf3 100644 --- a/clinicadl/interpret/config.py +++ b/clinicadl/interpret/config.py @@ -11,7 +11,7 @@ from clinicadl.experiment_manager.maps_manager import MapsManager from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp from clinicadl.predictor.validation import ValidationConfig -from clinicadl.splitter.config import SplitConfig +from clinicadl.splitter.splitter.splitter import SplitterConfig as SplitConfig from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.enum import InterpretationMethod diff --git a/clinicadl/model/clinicadl_model.py b/clinicadl/model/clinicadl_model.py index 8785ed97b..27adad432 100644 --- a/clinicadl/model/clinicadl_model.py +++ b/clinicadl/model/clinicadl_model.py @@ -3,6 +3,6 @@ class 
ClinicaDLModel: - def __init__(self, network: nn.Module, loss: nn.Module, optimizer=optim.optimizer): + def __init__(self, network: nn.Module, loss: nn.Module, optimizer): """TO COMPLETE""" pass diff --git a/clinicadl/predictor/config.py b/clinicadl/predictor/config.py index eaa3a8653..2c793045c 100644 --- a/clinicadl/predictor/config.py +++ b/clinicadl/predictor/config.py @@ -9,7 +9,7 @@ ) from clinicadl.experiment_manager.maps_manager import MapsManager from clinicadl.predictor.validation import ValidationConfig -from clinicadl.splitter.config import SplitConfig +from clinicadl.splitter.splitter.splitter import SplitterConfig as SplitConfig from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.enum import Task diff --git a/clinicadl/predictor/old_predictor.py b/clinicadl/predictor/old_predictor.py index 8314ce9d9..96d4764a5 100644 --- a/clinicadl/predictor/old_predictor.py +++ b/clinicadl/predictor/old_predictor.py @@ -48,8 +48,7 @@ class Predictor: def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None: self._config = _config - from clinicadl.splitter.config import SplitterConfig - from clinicadl.splitter.old_splitter import Splitter + from clinicadl.splitter.splitter.splitter import Splitter, SplitterConfig self.maps_manager = MapsManager(_config.maps_manager.maps_dir) self._config.adapt_with_maps_manager_info(self.maps_manager) diff --git a/clinicadl/splitter/__init__.py b/clinicadl/splitter/__init__.py index e69de29bb..40a3cf136 100644 --- a/clinicadl/splitter/__init__.py +++ b/clinicadl/splitter/__init__.py @@ -0,0 +1,3 @@ +from .make_splits import make_kfold, make_split +from .split import Split +from .splitter import KFold, SingleSplit diff --git a/clinicadl/splitter/config.py b/clinicadl/splitter/config.py deleted file mode 100644 index 050b3c6c1..000000000 --- a/clinicadl/splitter/config.py +++ /dev/null @@ -1,71 +0,0 @@ -from abc import ABC, abstractmethod -from logging import getLogger -from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -from pydantic import BaseModel, ConfigDict, field_validator -from pydantic.types import NonNegativeInt - -from clinicadl.dataset.config.data import DataConfig -from clinicadl.predictor.validation import ValidationConfig -from clinicadl.splitter.split_utils import find_splits - -logger = getLogger("clinicadl.split_config") - - -class SplitConfig(BaseModel): - """ - Abstract config class for the validation procedure. - - selection_metrics is specific to the task, thus it needs - to be specified in a subclass. - """ - - n_splits: NonNegativeInt = 0 - split: Optional[Tuple[NonNegativeInt, ...]] = None - tsv_path: Optional[Path] = None # not needed in interpret ! - - # pydantic config - model_config = ConfigDict(validate_assignment=True) - - @field_validator("split", mode="before") - def validator_split(cls, v): - if isinstance(v, list): - return tuple(v) - return v # TODO : check that split exists (and check coherence with n_splits) - - def adapt_cross_val_with_maps_manager_info( - self, maps_manager - ): # maps_manager is of type MapsManager but need to be in a MapsConfig type in the future - # TEMPORARY - if not self.split: - self.split = tuple(find_splits(maps_manager.maps_path)) - logger.debug(f"List of splits {self.split}") - - -class SplitterConfig(BaseModel, ABC): - """ - - Abstract config class for the training pipeline. - Some configurations are specific to the task (e.g. 
loss function),
-    thus they need to be specified in a subclass.
-    """
-
-    data: DataConfig
-    split: SplitConfig
-    validation: ValidationConfig
-    # pydantic config
-    model_config = ConfigDict(validate_assignment=True)
-
-    def __init__(self, **kwargs):
-        super().__init__(
-            data=kwargs,
-            split=kwargs,
-            validation=kwargs,
-        )
-
-    def _update(self, config_dict: Dict[str, Any]) -> None:
-        """Updates the configs with a dict given by the user."""
-        self.data.__dict__.update(config_dict)
-        self.split.__dict__.update(config_dict)
-        self.validation.__dict__.update(config_dict)
diff --git a/clinicadl/splitter/kfold.py b/clinicadl/splitter/kfold.py
deleted file mode 100644
index 37805dfba..000000000
--- a/clinicadl/splitter/kfold.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from typing import Optional
-
-from clinicadl.dataset.caps_dataset import CapsDataset
-from clinicadl.experiment_manager.experiment_manager import ExperimentManager
-
-
-class Split:
-    def __init__(
-        self,
-    ):
-        """TO COMPLETE"""
-        pass
-
-
-class KFolder:
-    def __init__(
-        self, n_splits: int, caps_dataset: CapsDataset, manager: ExperimentManager
-    ) -> None:
-        """TO COMPLETE"""
-
-    def split_iterator(self, split_list: Optional[list] = None) -> list[Split]:
-        """TO COMPLETE"""
-
-        return list[Split()]
diff --git a/clinicadl/splitter/make_splits/__init__.py b/clinicadl/splitter/make_splits/__init__.py
new file mode 100644
index 000000000..00578866b
--- /dev/null
+++ b/clinicadl/splitter/make_splits/__init__.py
@@ -0,0 +1,2 @@
+from .kfold import make_kfold
+from .single_split import make_split
diff --git a/clinicadl/splitter/make_splits/kfold.py b/clinicadl/splitter/make_splits/kfold.py
new file mode 100644
index 000000000..ed9dd41e2
--- /dev/null
+++ b/clinicadl/splitter/make_splits/kfold.py
@@ -0,0 +1,178 @@
+from pathlib import Path
+from typing import List, Optional, Union
+
+import numpy as np
+import pandas as pd
+from pydantic import PositiveInt
+from sklearn.model_selection import KFold, StratifiedKFold
+
+from clinicadl.dataset.utils import tsv_to_df
+from clinicadl.splitter.make_splits.utils import write_to_csv
+from clinicadl.splitter.splitter.kfold import KFoldConfig
+from clinicadl.tsvtools.tsvtools_utils import extract_baseline
+from clinicadl.utils.exceptions import ClinicaDLConfigurationError
+
+
+def _validate_stratification(
+    df: pd.DataFrame,
+    stratification: Union[str, List[str], bool],
+) -> Optional[str]:
+    """
+    Validates the stratification column.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        Input dataset.
+    stratification : Union[str, List[str], bool]
+        Column to use for stratification. If True, the column is 'sex'; if False, there is no stratification.
+
+    Returns
+    -------
+    Optional[str]
+        Validated stratification column, or None if no stratification is applied.
+
+    Raises
+    ------
+    ClinicaDLConfigurationError
+        If invalid or conflicting stratification options are provided.
+    """
+    if isinstance(stratification, bool):
+        if stratification:
+            stratification = "sex"
+        else:
+            return None
+
+    if isinstance(stratification, list):
+        if len(stratification) > 1:
+            raise ClinicaDLConfigurationError(
+                "Stratification can only be performed on a single column for K-Fold splitting."
+            )
+        else:
+            stratification = stratification[0]
+
+    if isinstance(stratification, str):
+        if stratification not in df.columns:
+            raise ClinicaDLConfigurationError(
+                f"Stratification column '{stratification}' not found in the dataset."
+            )
+
+        if pd.api.types.is_numeric_dtype(df[stratification]) and df[
+            stratification
+        ].nunique() >= (len(df) / 2):
+            raise ValueError(
+                "Continuous variables cannot be used for stratification in K-Fold splitting."
+            )
+        return stratification
+
+    raise ClinicaDLConfigurationError(
+        "Invalid or conflicting stratification options provided. Stratification must be a single column name or a boolean."
+    )
+
+
+def preprocess_stratification(
+    df: pd.DataFrame,
+    stratification: Union[str, List[str], bool],
+) -> pd.DataFrame:
+    """
+    Preprocess the stratification column by creating labels for each subject.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        Input dataset.
+    stratification : Union[str, List[str], bool]
+        Column to use for stratification. If True, the column is 'sex'; if False, there is no stratification.
+
+    Returns
+    -------
+    pd.DataFrame
+        A one-column DataFrame with the stratification labels, or the input
+        DataFrame unchanged if no stratification is applied.
+    """
+    column = _validate_stratification(df, stratification)
+
+    if column is None:
+        return df
+
+    return df[[column]]
+
+
+def make_kfold(
+    tsv_path: Path,
+    output_dir: Optional[Union[Path, str]] = None,
+    subset_name: str = "validation",
+    valid_longitudinal: bool = False,
+    n_splits: PositiveInt = 5,
+    stratification: Union[str, bool] = False,
+) -> Path:
+    """
+    Perform K-Fold splitting with optional stratification.
+
+    Parameters
+    ----------
+    tsv_path : Path
+        Path to the input TSV file.
+    output_dir : Optional[Union[Path, str]]
+        Directory to save the split files. Defaults to the parent directory of `tsv_path`.
+    subset_name : str, default="validation"
+        Name of the subset used for output files.
+    valid_longitudinal : bool, default=False
+        Whether to include longitudinal sessions in the split.
+    n_splits : PositiveInt, default=5
+        Number of splits for K-Fold.
+    stratification : Union[str, bool], default=False
+        Column to use for stratification. If True, the column is 'sex'; if False, there is no stratification.
+
+    Returns
+    -------
+    Path
+        Directory containing the generated split files.
+
+    Raises
+    ------
+    ClinicaDLConfigurationError
+        If invalid configuration options are provided.
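+
+    Examples
+    --------
+    A minimal sketch; 'labels.tsv' is a hypothetical file with at least
+    'participant_id', 'session_id' and 'sex' columns:
+
+    >>> split_dir = make_kfold(Path("labels.tsv"), n_splits=2, stratification="sex")
+    >>> sorted(p.name for p in split_dir.iterdir() if p.is_dir())
+    ['split-0', 'split-1']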
+ """ + + # Set default output directory + output_dir = output_dir or tsv_path.parent + output_dir = Path(output_dir) + + # Initialize KFold configuration + config = KFoldConfig( + split_dir=output_dir, + subset_name=subset_name, + valid_longitudinal=valid_longitudinal, + n_splits=n_splits, + stratification=stratification, + ) + + config._check_split_dir() + config._write_json() + + # Load and process dataset + df = tsv_to_df(tsv_path) + baseline_df = extract_baseline(df) + + stratify_labels = preprocess_stratification( + df=baseline_df, + stratification=config.stratification, + ) + + # Create K-Fold splits + if config.stratification: + skf = StratifiedKFold(n_splits=config.n_splits, shuffle=True, random_state=2) + else: + skf = KFold(n_splits=config.n_splits, shuffle=True, random_state=2) + + for i, (train_idx, test_idx) in enumerate(skf.split(baseline_df, stratify_labels)): + train = baseline_df.iloc[train_idx] + test = baseline_df.iloc[test_idx] + + split_dir = config.split_dir / f"split-{i}" + split_dir.mkdir(parents=True, exist_ok=True) + + write_to_csv(test, split_dir, df, config.subset_name, config.valid_longitudinal) + write_to_csv(train, split_dir, df) + + return config.split_dir diff --git a/clinicadl/splitter/make_splits/single_split.py b/clinicadl/splitter/make_splits/single_split.py new file mode 100644 index 000000000..1854b5449 --- /dev/null +++ b/clinicadl/splitter/make_splits/single_split.py @@ -0,0 +1,439 @@ +from logging import getLogger +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +from pydantic import PositiveFloat +from scipy.stats import chisquare, ks_2samp, ttest_ind +from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit + +from clinicadl.dataset.utils import tsv_to_df +from clinicadl.splitter.make_splits.utils import write_to_csv +from clinicadl.splitter.splitter.single_split import SingleSplitConfig +from clinicadl.tsvtools.tsvtools_utils import extract_baseline +from clinicadl.utils.exceptions import ClinicaDLConfigurationError, ClinicaDLTSVError + +logger = getLogger("clinicadl.splitter.single_split") + + +def _validate_stratification( + df: pd.DataFrame, + stratification: Union[List[str], bool], +) -> List[str]: + """ + Checks and validates the specified stratification columns. + + Parameters + ---------- + df : pd.DataFrame + Input dataset. + stratification : Union[List[str], bool] + Columns to use for stratification. If True, columns are 'age' and 'sex', if False, there is no stratification. + + Returns + ------- + List[str], optional + Validated list of stratification columns or None if no stratification is applied. + + Raises + ------ + ValueError + If specified stratification columns are missing or if stratification conflicts with demographic handling. + ClinicaDLTSVError + If required demographic columns ('age', 'sex') are missing when not ignored. + """ + + if isinstance(stratification, bool): + if stratification: + stratification = ["age", "sex"] + else: + return [] + + if isinstance(stratification, list): + if not set(stratification).issubset(df.columns): + raise ValueError( + f"Invalid stratification columns: {set(stratification) - set(df.columns)}" + ) + return stratification + + raise ValueError( + "Invalid stratification option. Stratification must be a list of column names or a boolean." 
+ ) + + +def _categorize_labels( + df: pd.DataFrame, + stratification: Union[List[str], bool], + n_test: int = 100, +) -> Tuple[List[str], List[str]]: + """ + Categorize stratification columns into continuous and categorical labels. + + Parameters + ---------- + df : pd.DataFrame + Input dataset. + stratification : Union[List[str], bool] + Columns to use for stratification. If True, columns are 'age' and 'sex', if False, there is no stratification. + n_test : int + Number of test samples. + + Returns + ------- + Tuple[List[str], List[str]] + Continuous and categorical labels. + """ + columns = _validate_stratification(df, stratification) + + continuous_labels, categorical_labels = [], [] + for col in columns: + if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() >= (n_test / 2): + continuous_labels.append(col) + else: + categorical_labels.append(col) + return continuous_labels, categorical_labels + + +def _chi2_test(x_test: List[int], x_train: List[int]) -> float: + """ + Perform the Chi-squared test on categorical data. + + Parameters + ---------- + x_test : np.ndarray + Test data. + x_train : np.ndarray + Train data. + + Returns + ------- + float + p-value from the Chi-squared test. + """ + unique_categories = np.unique(np.concatenate([x_test, x_train])) + + # Calculate observed (test) and expected (train) frequencies as raw counts + f_obs = np.array([(x_test == category).sum() for category in unique_categories]) + f_exp = np.array( + [ + (x_train == category).sum() / len(x_train) * len(x_test) + for category in unique_categories + ] + ) + + _, p_value = chisquare(f_obs, f_exp) + + return p_value + + +def make_split( + tsv_path: Path, + output_dir: Optional[Union[Path, str]] = None, + n_test: PositiveFloat = 100, + subset_name: str = "test", + p_categorical_threshold: float = 0.50, + p_continuous_threshold: float = 0.50, + stratification: Union[List[str], bool] = False, + valid_longitudinal=False, + n_try_max: int = 1000, +): + """ + Perform a single train-test split of the dataset with stratification. + + Parameters + ---------- + tsv_path : Path + Path to the input TSV file. + output_dir : Optional[Path] + Directory to save the split files. + n_test : PositiveFloat + If >= 1, specifies the absolute number of test samples. If < 1, treated as a proportion of the dataset. + subset_name : str + Name for the test subset. + p_categorical_threshold : float + Threshold for acceptable categorical stratification. + p_continuous_threshold : float + Threshold for acceptable continuous stratification. + stratification : Union[List[str], bool], default=False + Columns to use for stratification. If True, columns are 'age' and 'sex', if False, there is no stratification. + valid_longitudinal : bool + Include longitudinal sessions if True. + n_try_max : int + Maximum number of attempts to find a valid split. + + Returns + ------- + Path + Directory containing the split files. 
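+
+    Examples
+    --------
+    A minimal sketch; 'labels.tsv' is a hypothetical file with at least
+    'participant_id', 'session_id', 'age' and 'sex' columns:
+
+    >>> split_dir = make_split(
+    ...     Path("labels.tsv"),
+    ...     n_test=0.2,  # i.e. 20% of the baseline sessions
+    ...     stratification=["age", "sex"],
+    ... )
+    >>> (split_dir / "train_baseline.tsv").is_file()
+    True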
+ """ + + # Set default output directory + output_dir = output_dir or tsv_path.parent + output_dir = Path(output_dir) + + # Load dataset and preprocess + df = tsv_to_df(tsv_path) + baseline_df = extract_baseline(df) + + n_test = int(n_test) if n_test >= 1 else int(n_test * len(baseline_df)) + + continuous_labels, categorical_labels = _categorize_labels( + df=baseline_df, + stratification=stratification, + n_test=n_test, + ) + + # Initialize SingleSplit configuration + config = SingleSplitConfig( + split_dir=output_dir, + subset_name=subset_name, + valid_longitudinal=valid_longitudinal, + n_test=n_test, + p_continuous_threshold=p_continuous_threshold, + p_categorical_threshold=p_categorical_threshold, + stratification=stratification, + ) + + config._check_split_dir() + config._write_json() + + if config.n_test > 0: + splits = ShuffleSplit( + n_splits=n_try_max, test_size=config.n_test, random_state=2 + ) + for n_try, (train_index, test_index) in enumerate( + splits.split(baseline_df, baseline_df) + ): + p_continuous = compute_continuous_p_value( + continuous_labels, + baseline_df, + train_index.tolist(), + test_index.tolist(), + ) + + if p_continuous >= p_continuous_threshold: + p_categorical = compute_categorical_p_value( + categorical_labels, + baseline_df, + train_index.tolist(), + test_index.tolist(), + ) + + if p_categorical >= p_categorical_threshold: + logger.info(f"Valid split found after {n_try} attempts.") + + test_df = baseline_df.loc[test_index] + train_df = baseline_df.loc[train_index] + + write_continuous_stats( + config.split_dir / "split_continuous_stats.tsv", + continuous_labels, + test_df, + train_df, + subset_name, + ) + write_categorical_stats( + config.split_dir / "split_categorical_stats.tsv", + categorical_labels, + test_df, + train_df, + baseline_df, + subset_name, + ) + break + + if n_try >= n_try_max - 1: + raise ClinicaDLConfigurationError( + f"Unable to find a valid split after {n_try} attempts. " + f"Consider lowering thresholds or reducing stratification variables." + ) + + write_to_csv(test_df, config.split_dir, df, subset_name, valid_longitudinal) + else: + train_df = baseline_df + + write_to_csv(train_df, config.split_dir, df) + + return config.split_dir + + +def compute_continuous_p_value( + continuous_labels: Optional[list[str]], + baseline_df: pd.DataFrame, + train_index: list[int], + test_index: list[int], +) -> float: + """ + Compute the minimum p-value for continuous variables between train and test splits. + + Parameters + ---------- + continuous_labels : Optional[List[str]] + List of continuous variable names. + baseline_df : pd.DataFrame + Dataframe containing the baseline data. + train_index : List[int] + Indices for the training set. + test_index : List[int] + Indices for the testing set. + + Returns + ------- + float + The minimum p-value across all continuous labels. 
+ """ + + p_continuous = 1.0 + if continuous_labels: + for label in continuous_labels: + if len(baseline_df[label] != 1): + train_values = baseline_df[label].loc[train_index].values.tolist() + test_values = baseline_df[label].loc[test_index].values.tolist() + + _, new_p_continuous = ttest_ind( + test_values, train_values, nan_policy="omit" + ) # ks_2samp, or ttost_ind from statsmodels.stats.weightstats import ttost_ind + + # Track the minimum p-value + p_continuous = min(p_continuous, new_p_continuous) + + return p_continuous + + +def compute_categorical_p_value( + categorical_labels: Optional[list[str]], + baseline_df: pd.DataFrame, + train_index: list[int], + test_index: list[int], +) -> float: + """ + Compute the minimum p-value for categorical variables between train and test splits. + + Parameters + ---------- + categorical_labels : Optional[List[str]] + List of categorical variable names. + baseline_df : pd.DataFrame + Dataframe containing the baseline data. + train_index : List[int] + Indices for the training set. + test_index : List[int] + + Returns + ------- + float + The minimum p-value across all categorical labels. + """ + + p_categorical = 1 + if categorical_labels: + for label in categorical_labels: + if len(baseline_df[label] != 1): + mapping = { + val: i for i, val in enumerate(np.unique(baseline_df[label])) + } + + tmp_train_values = baseline_df[label].loc[train_index].values.tolist() + tmp_test_values = baseline_df[label].loc[test_index].values.tolist() + + train_values = [mapping[val] for val in tmp_train_values] + test_values = [mapping[val] for val in tmp_test_values] + + new_p_categorical = _chi2_test(test_values, train_values) + + # Track the minimum p-value + p_categorical = min(p_categorical, new_p_categorical) + + return p_categorical + + +def write_continuous_stats( + tsv_path: Path, + continuous_labels: Optional[list[str]], + test_df: pd.DataFrame, + train_df: pd.DataFrame, + subset_name: str, +): + """ + Write continuous statistics (mean, std) to a TSV file. + + Parameters + ---------- + tsv_path : Path + Path to save the output TSV file. + continuous_labels : Optional[List[str]] + List of continuous variable names. + test_df : pd.DataFrame + Test dataset. + train_df : pd.DataFrame + Train dataset. + subset_name : str + Name of the test subset. + """ + + if not continuous_labels: + return + + data = [ + (label, "mean", train_df[label].mean(), test_df[label].mean()) + for label in continuous_labels + ] + [ + (label, "std", train_df[label].std(), test_df[label].std()) + for label in continuous_labels + ] + + df_stats_continuous = pd.DataFrame( + data, columns=["label", "statistic", "train", subset_name] + ) + df_stats_continuous.to_csv(tsv_path, sep="\t", index=False) + + +def write_categorical_stats( + tsv_path: Path, + categorical_labels: Optional[list[str]], + test_df: pd.DataFrame, + train_df: pd.DataFrame, + baseline_df: pd.DataFrame, + subset_name: str, +): + """ + Write categorical statistics (proportion, count) to a TSV file. + + Parameters + ---------- + tsv_path : Path + Path to save the output TSV file. + categorical_labels : Optional[List[str]] + List of categorical variable names. + test_df : pd.DataFrame + Test dataset. + train_df : pd.DataFrame + Train dataset. + baseline_df : pd.DataFrame + Baseline dataset (reference for all unique values). 
+    subset_name : str
+        Name of the test subset, used as a column header.
+
+    """
+
+    if not categorical_labels:
+        return
+
+    data = []
+    for label in categorical_labels:
+        unique_values = baseline_df[label].unique()
+        for val in unique_values:
+            test_count = (test_df[label] == val).sum()
+            train_count = (train_df[label] == val).sum()
+
+            test_proportion = test_count / len(test_df)
+            train_proportion = train_count / len(train_df)
+
+            data.append((label, val, "proportion", train_proportion, test_proportion))
+            data.append((label, val, "count", train_count, test_count))
+
+    df_stats_categorical = pd.DataFrame(
+        data, columns=["label", "value", "statistic", "train", subset_name]
+    )
+    df_stats_categorical.to_csv(tsv_path, sep="\t", index=False)
diff --git a/clinicadl/splitter/make_splits/utils.py b/clinicadl/splitter/make_splits/utils.py
new file mode 100644
index 000000000..687066790
--- /dev/null
+++ b/clinicadl/splitter/make_splits/utils.py
@@ -0,0 +1,78 @@
+from pathlib import Path
+from typing import Optional
+
+import pandas as pd
+
+from clinicadl.tsvtools.tsvtools_utils import retrieve_longitudinal
+
+
+def _write_to_csv(df: pd.DataFrame, file_path: Path) -> None:
+    """
+    Save a DataFrame to a TSV file, ensuring the file does not already exist.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        DataFrame to save.
+    file_path : Path
+        Path to the destination TSV file.
+
+    Raises
+    ------
+    FileExistsError
+        If the file already exists at the specified path.
+    """
+    if file_path.exists():
+        raise FileExistsError(
+            f"File {file_path} already exists. Operation aborted to prevent overwriting."
+        )
+
+    # Reset index for consistency and save as a TSV file
+    df.reset_index(drop=True, inplace=True)
+    df.to_csv(file_path, sep="\t", index=False)
+
+
+def write_to_csv(
+    df: pd.DataFrame,
+    split_dir: Path,
+    all_df: Optional[pd.DataFrame] = None,
+    subset_name: str = "train",
+    longitudinal: bool = True,
+) -> None:
+    """
+    Save baseline and longitudinal splits of a DataFrame to TSV files.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        DataFrame containing the subset (e.g., train/test/validation) to save.
+    split_dir : Path
+        Directory where the TSV files will be saved.
+    all_df : Optional[pd.DataFrame], optional
+        Full dataset including all sessions, used to retrieve longitudinal data.
+    subset_name : str, default="train"
+        Name of the subset (e.g., "train", "test", etc.) used in the output filenames.
+    longitudinal : bool, default=True
+        Whether to generate and save the longitudinal data subset.
+
+    Raises
+    ------
+    FileExistsError
+        If any of the output files already exist in the specified directory.
+    ValueError
+        If `longitudinal` is True but `all_df` is None, as longitudinal data cannot be generated.
+    """
+    # Save the baseline data
+    baseline_file = split_dir / f"{subset_name}_baseline.tsv"
+    _write_to_csv(df, baseline_file)
+
+    if longitudinal:
+        if all_df is None:
+            raise ValueError(
+                "The full dataset (`all_df`) must be provided to generate longitudinal data."
+ ) + + # Retrieve longitudinal data and save it + longitudinal_file = split_dir / f"{subset_name}.tsv" + long_df = retrieve_longitudinal(df, all_df) + _write_to_csv(long_df, longitudinal_file) diff --git a/clinicadl/splitter/old_splitter.py b/clinicadl/splitter/old_splitter.py deleted file mode 100644 index d39b14a5b..000000000 --- a/clinicadl/splitter/old_splitter.py +++ /dev/null @@ -1,237 +0,0 @@ -import abc -import shutil -from logging import getLogger -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import pandas as pd - -from clinicadl.splitter.config import SplitterConfig -from clinicadl.utils import cluster -from clinicadl.utils.exceptions import MAPSError - -logger = getLogger("clinicadl.split_manager") - - -class Splitter: - def __init__( - self, - config: SplitterConfig, - # split_list: Optional[List[int]] = None, - ): - """_summary_ - - Parameters - ---------- - data_config : DataConfig - _description_ - validation_config : ValidationConfig - _description_ - split_list : Optional[List[int]] (optional, default=None) - _description_ - - """ - self.config = config - # self.config.split.split = split_list - - # self.caps_dict = self.config.data.caps_dict # TODO : check if useful ? - - def max_length(self) -> int: - """Maximum number of splits""" - return self.config.split.n_splits - - def __len__(self): - if not self.config.split.split: - return self.config.split.n_splits - else: - return len(self.config.split.split) - - @property - def allowed_splits_list(self): - """ - List of possible splits if no restriction was applied - - Returns: - list[int]: list of all possible splits - """ - return [i for i in range(self.config.split.n_splits)] - - def __getitem__(self, item) -> Dict: - """ - Returns a dictionary of DataFrames with train and validation data. - - Args: - item (int): Index of the split wanted. - Returns: - Dict[str:pd.DataFrame]: dictionary with two keys (train and validation). - """ - self._check_item(item) - - if self.config.data.multi_cohort: - tsv_df = pd.read_csv(self.config.split.tsv_path, sep="\t") - train_df = pd.DataFrame() - valid_df = pd.DataFrame() - found_diagnoses = set() - for idx in range(len(tsv_df)): - cohort_name = tsv_df.at[idx, "cohort"] - cohort_path = Path(tsv_df.at[idx, "path"]) - cohort_diagnoses = ( - tsv_df.at[idx, "diagnoses"].replace(" ", "").split(",") - ) - if bool(set(cohort_diagnoses) & set(self.config.data.diagnoses)): - target_diagnoses = list( - set(cohort_diagnoses) & set(self.config.data.diagnoses) - ) - - cohort_train_df, cohort_valid_df = self.concatenate_diagnoses( - item, cohort_path=cohort_path, cohort_diagnoses=target_diagnoses - ) - cohort_train_df["cohort"] = cohort_name - cohort_valid_df["cohort"] = cohort_name - train_df = pd.concat([train_df, cohort_train_df]) - valid_df = pd.concat([valid_df, cohort_valid_df]) - found_diagnoses = found_diagnoses | ( - set(cohort_diagnoses) & set(self.config.data.diagnoses) - ) - - if found_diagnoses != set(self.config.data.diagnoses): - raise ValueError( - f"The diagnoses found in the multi cohort dataset {found_diagnoses} " - f"do not correspond to the diagnoses wanted {set(self.config.data.diagnoses)}." 
- ) - train_df.reset_index(inplace=True, drop=True) - valid_df.reset_index(inplace=True, drop=True) - else: - train_df, valid_df = self.concatenate_diagnoses(item) - train_df["cohort"] = "single" - valid_df["cohort"] = "single" - - return { - "train": train_df, - "validation": valid_df, - } - - @staticmethod - def get_dataframe_from_tsv_path(tsv_path: Path) -> pd.DataFrame: - df = pd.read_csv(tsv_path, sep="\t") - list_columns = df.columns.values - - if ( - "diagnosis" not in list_columns - # or "age" not in list_columns - # or "sex" not in list_columns - ): - parents_path = tsv_path.resolve().parent - labels_path = parents_path / "labels.tsv" - while ( - not labels_path.is_file() - and ((parents_path / "kfold.json").is_file()) - or (parents_path / "split.json").is_file() - ): - parents_path = parents_path.parent - try: - labels_df = pd.read_csv(labels_path, sep="\t") - df = pd.merge( - df, - labels_df, - how="inner", - on=["participant_id", "session_id"], - ) - except Exception: - pass - return df - - @staticmethod - def load_data(tsv_path: Path, cohort_diagnoses: List[str]) -> pd.DataFrame: - df = Splitter.get_dataframe_from_tsv_path(tsv_path) - df = df[df.diagnosis.isin((cohort_diagnoses))] - df.reset_index(inplace=True, drop=True) - return df - - def concatenate_diagnoses( - self, - split, - cohort_path: Optional[Path] = None, - cohort_diagnoses: Optional[List[str]] = None, - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - """Concatenated the diagnoses needed to form the train and validation sets.""" - - if cohort_diagnoses is None: - cohort_diagnoses = list(self.config.data.diagnoses) - - tmp_cohort_path = ( - cohort_path if cohort_path is not None else self.config.split.tsv_path - ) - train_path, valid_path = self._get_tsv_paths( - tmp_cohort_path, - split, - ) - - logger.debug(f"Training data loaded at {train_path}") - if self.config.data.baseline: - train_path = train_path / "train_baseline.tsv" - else: - train_path = train_path / "train.tsv" - train_df = self.load_data(train_path, cohort_diagnoses) - - logger.debug(f"Validation data loaded at {valid_path}") - if self.config.validation.valid_longitudinal: - valid_path = valid_path / "validation.tsv" - else: - valid_path = valid_path / "validation_baseline.tsv" - valid_df = self.load_data(valid_path, cohort_diagnoses) - - return train_df, valid_df - - def _get_tsv_paths(self, cohort_path, *args) -> Tuple[Path, Path]: - """ - Computes the paths to the TSV files needed depending on the split structure. - - Args: - cohort_path (str): path to the split structure of a cohort. - split (int): Index of the split. - Returns: - train_path (str): path to the directory containing training data. - valid_path (str): path to the directory containing validation data. - """ - if args is not None: - for split in args: - train_path = cohort_path / f"split-{split}" - valid_path = cohort_path / f"split-{split}" - return train_path, valid_path - else: - train_path = cohort_path - valid_path = cohort_path - return train_path, valid_path - - def split_iterator(self): - """Returns an iterable to iterate on all splits wanted.""" - - if not self.config.split.split: - return range(self.config.split.n_splits) - else: - return self.config.split.split - - def _check_item(self, item): - if item not in self.allowed_splits_list: - raise IndexError( - f"Split index {item} out of allowed splits {self.allowed_splits_list}." 
- ) - - def check_split_list(self, maps_path, overwrite): - existing_splits = [] - for split in self.split_iterator(): - split_path = maps_path / f"split-{split}" - if split_path.is_dir(): - if overwrite: - if cluster.master: - shutil.rmtree(split_path) - else: - existing_splits.append(split) - - if len(existing_splits) > 0: - raise MAPSError( - f"Splits {existing_splits} already exist. Please " - f"specify a list of splits not intersecting the previous list, " - f"or use overwrite to erase previously trained splits." - ) diff --git a/clinicadl/splitter/split.py b/clinicadl/splitter/split.py index 72c4f9d82..ff9ec21f8 100644 --- a/clinicadl/splitter/split.py +++ b/clinicadl/splitter/split.py @@ -1,18 +1,165 @@ from pathlib import Path +from typing import Optional -from clinicadl.dataset.caps_dataset import CapsDataset -from clinicadl.experiment_manager.experiment_manager import ExperimentManager -from clinicadl.splitter.kfold import Split +from pydantic import NonNegativeInt +from torch.utils.data import DataLoader +from clinicadl.dataset.dataloader import DataLoaderConfig +from clinicadl.dataset.dataloader.defaults import ( + BATCH_SIZE, + DP_DEGREE, + DROP_LAST, + NUM_WORKERS, + PIN_MEMORY, + PREFETCH_FACTOR, + RANK, + SAMPLING_WEIGHTS, + SHUFFLE, +) +from clinicadl.dataset.datasets.caps_dataset import CapsDataset +from clinicadl.utils.config import ClinicaDLConfig -def split_tsv(sub_ses_tsv: Path) -> Path: - """TO COMPLETE""" - split_dir = Path("") - return split_dir +class Split(ClinicaDLConfig): + """Dataclass that contains all the useful info on the split.""" + index: NonNegativeInt + split_dir: Path + train_dataset: CapsDataset + val_dataset: CapsDataset + train_loader: Optional[DataLoader] = None + val_loader: Optional[DataLoader] = None + train_loader_config: Optional[DataLoaderConfig] = None + val_loader_config: Optional[DataLoaderConfig] = None -def get_single_split( - n_subject_validation: int, caps_dataset: CapsDataset, manager: ExperimentManager -) -> Split: - pass + def build_train_loader( + self, + dataloader_config: Optional[DataLoaderConfig] = None, + *, + batch_size: int = BATCH_SIZE, + sampling_weights: Optional[str] = SAMPLING_WEIGHTS, + shuffle: bool = SHUFFLE, + drop_last: bool = DROP_LAST, + num_workers: int = NUM_WORKERS, + prefetch_factor: Optional[int] = PREFETCH_FACTOR, + pin_memory: bool = PIN_MEMORY, + dp_degree: Optional[int] = DP_DEGREE, + rank: Optional[int] = RANK, + ) -> None: + """ + Build a train loader for the split. + + Parameters + ---------- + dataloader_config : Optional[DataLoaderConfig] (optional, default=None) + Pre-configured DataLoader configuration. + batch_size : int (optional, default=1) + Batch size for the DataLoader (used if `dataloader_config` is not provided). + sampling_weights : Optional[str] (optional, default=None) + Name of the column in the dataframe of the CapsDatasets where to find the sampling + weights (used if `dataloader_config` is not provided). + shuffle : bool (optional, default=False) + Whether to shuffle the data (used if `dataloader_config` is not provided). + drop_last : bool (optional, default=False) + Whether to drop the last incomplete batch (used if `dataloader_config` is not provided). + num_workers : int (optional, default=0) + Number of workers for data loading (used if `dataloader_config` is not provided). + prefetch_factor : Optional[int] (optional, default=None) + Prefetch factor if num_workers is not 0 (used if `dataloader_config` is not provided). 
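+        pin_memory : bool (optional)
+            Whether to copy fetched tensors into pinned (page-locked) memory, which can
+            speed up host-to-device transfers (used if `dataloader_config` is not provided).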
+        dp_degree : Optional[int] (optional, default=None)
+            The degree of data parallelism. None if no data parallelism.
+        rank : Optional[int] (optional, default=None)
+            Process id within the data parallelism communicator.
+            None if no data parallelism.
+
+        Raises
+        ------
+        ValueError
+            If one of 'dp_degree' and 'rank' is None but the other is not None.
+        KeyError
+            If 'sampling_weights' is passed but there is no such column in the dataframe
+            of the train dataset.
+        """
+        if dataloader_config:
+            self.train_loader_config = dataloader_config
+        else:
+            self.train_loader_config = DataLoaderConfig(
+                batch_size=batch_size,
+                sampling_weights=sampling_weights,
+                shuffle=shuffle,
+                num_workers=num_workers,
+                drop_last=drop_last,
+                prefetch_factor=prefetch_factor,
+                pin_memory=pin_memory,
+            )
+        self.train_loader = self.train_loader_config.get_dataloader(
+            dataset=self.train_dataset,
+            dp_degree=dp_degree,
+            rank=rank,
+        )
+
+    def build_val_loader(
+        self,
+        dataloader_config: Optional[DataLoaderConfig] = None,
+        *,
+        batch_size: int = BATCH_SIZE,
+        sampling_weights: Optional[str] = SAMPLING_WEIGHTS,
+        shuffle: bool = SHUFFLE,
+        drop_last: bool = DROP_LAST,
+        num_workers: int = NUM_WORKERS,
+        prefetch_factor: Optional[int] = PREFETCH_FACTOR,
+        pin_memory: bool = PIN_MEMORY,
+        dp_degree: Optional[int] = DP_DEGREE,
+        rank: Optional[int] = RANK,
+    ) -> None:
+        """
+        Build a validation loader for the split.
+
+        Parameters
+        ----------
+        dataloader_config : Optional[DataLoaderConfig] (optional, default=None)
+            Pre-configured DataLoader configuration.
+        batch_size : int (optional, default=1)
+            Batch size for the DataLoader (used if `dataloader_config` is not provided).
+        sampling_weights : Optional[str] (optional, default=None)
+            Name of the column in the dataframe of the CapsDatasets where to find the sampling
+            weights (used if `dataloader_config` is not provided).
+        shuffle : bool (optional, default=False)
+            Whether to shuffle the data (used if `dataloader_config` is not provided).
+        drop_last : bool (optional, default=False)
+            Whether to drop the last incomplete batch (used if `dataloader_config` is not provided).
+        num_workers : int (optional, default=0)
+            Number of workers for data loading (used if `dataloader_config` is not provided).
+        prefetch_factor : Optional[int] (optional, default=None)
+            Prefetch factor if num_workers is not 0 (used if `dataloader_config` is not provided).
+        pin_memory : bool (optional)
+            Whether to copy fetched tensors into pinned (page-locked) memory, which can
+            speed up host-to-device transfers (used if `dataloader_config` is not provided).
+        dp_degree : Optional[int] (optional, default=None)
+            The degree of data parallelism. None if no data parallelism.
+        rank : Optional[int] (optional, default=None)
+            Process id within the data parallelism communicator.
+            None if no data parallelism.
+
+        Raises
+        ------
+        ValueError
+            If one of 'dp_degree' and 'rank' is None but the other is not None.
+        KeyError
+            If 'sampling_weights' is passed but there is no such column in the dataframe
+            of the validation dataset.
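+
+        Examples
+        --------
+        A minimal sketch (the batch size and worker count are arbitrary)::
+
+            split.build_val_loader(batch_size=8, num_workers=2)
+            for batch in split.val_loader:
+                ...  # run validation on the batch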
+ """ + if dataloader_config: + self.val_loader_config = dataloader_config + else: + self.val_loader_config = DataLoaderConfig( + batch_size=batch_size, + sampling_weights=sampling_weights, + shuffle=shuffle, + num_workers=num_workers, + drop_last=drop_last, + prefetch_factor=prefetch_factor, + pin_memory=pin_memory, + ) + self.val_loader = self.val_loader_config.get_dataloader( + dataset=self.val_dataset, + dp_degree=dp_degree, + rank=rank, + ) diff --git a/clinicadl/splitter/split_utils.py b/clinicadl/splitter/split_utils.py index 3e0f09388..42a6c6a49 100644 --- a/clinicadl/splitter/split_utils.py +++ b/clinicadl/splitter/split_utils.py @@ -1,6 +1,119 @@ +from copy import copy from pathlib import Path from typing import List +import numpy as np +import pandas as pd + +from clinicadl.tsvtools.tsvtools_utils import first_session +from clinicadl.utils.exceptions import ClinicaDLTSVError + + +def extract_baseline(diagnosis_df, set_index=True): + from copy import deepcopy + + if set_index: + all_df = deepcopy(diagnosis_df) + all_df.set_index(["participant_id", "session_id"], inplace=True) + else: + all_df = deepcopy(diagnosis_df) + + result_df = pd.DataFrame() + for subject, subject_df in all_df.groupby(level=0): + if subject != "participant_id": + baseline = first_session(subject_df) + + subject_baseline_df = pd.DataFrame( + data=[ + [subject, baseline] + subject_df.loc[(subject, baseline)].tolist() + ], + columns=["participant_id", "session_id"] + + subject_df.columns.values.tolist(), + ) + result_df = pd.concat([result_df, subject_baseline_df]) + + result_df.reset_index(inplace=True, drop=True) + return result_df + + +def chi2(x_test, x_train): + from scipy.stats import chisquare + + # Look for chi2 computation + total_categories = np.concatenate([x_test, x_train]) + unique_categories = np.unique(total_categories) + f_obs = [(x_test == category).sum() / len(x_test) for category in unique_categories] + f_exp = [ + (x_train == category).sum() / len(x_train) for category in unique_categories + ] + T, p = chisquare(f_obs, f_exp) + + return T, p + + +def add_demographics(df, demographics_df, diagnosis) -> pd.DataFrame: + out_df = pd.DataFrame() + tmp_demo_df = copy(demographics_df) + tmp_demo_df.reset_index(inplace=True) + for idx in df.index.values: + participant = df.loc[idx, "participant_id"] + session = df.loc[idx, "session_id"] + row_df = tmp_demo_df[ + (tmp_demo_df.participant_id == participant) + & (tmp_demo_df.session_id == session) + ] + out_df = pd.concat([out_df, row_df]) + out_df.reset_index(inplace=True, drop=True) + out_df.diagnosis = [diagnosis] * len(out_df) + return out_df + + +def remove_unicity(values_list): + """Count the values of each class and label all the classes with only one label under the same label.""" + unique_classes, counts = np.unique(values_list, return_counts=True) + one_sub_classes = unique_classes[(counts == 1)] + for class_element in one_sub_classes: + values_list[values_list.index(class_element)] = unique_classes.min() + + return values_list + + +def category_conversion(values_list) -> List[int]: + values_np = np.array(values_list) + unique_classes = np.unique(values_np) + for index, unique_class in enumerate(unique_classes): + values_np[values_np == unique_class] = index + 1 + + return values_np.astype(int).tolist() + + +def find_label(labels_list, target_label): + if target_label in labels_list: + return target_label + else: + min_length = np.inf + found_label = None + for label in labels_list: + if target_label.lower() in label.lower() and min_length > 
len(label):
+                min_length = len(label)
+                found_label = label
+        if found_label is None:
+            raise ClinicaDLTSVError(
+                f"No label was found in {labels_list} for target label {target_label}."
+            )
+
+        return found_label
+
+
+def retrieve_longitudinal(df, diagnosis_df):
+    """Return all the sessions of `diagnosis_df` for the participants present in `df`."""
+    final_df = pd.DataFrame()
+    for idx in df.index.values:
+        subject = df.loc[idx, "participant_id"]
+        row_df = diagnosis_df[diagnosis_df.participant_id == subject]
+        final_df = pd.concat([final_df, row_df])
+
+    return final_df
+
 
 def find_splits(maps_path: Path) -> List[int]:
     """Find which splits that were trained in the MAPS."""
diff --git a/clinicadl/splitter/splitter/__init__.py b/clinicadl/splitter/splitter/__init__.py
new file mode 100644
index 000000000..f5eec37c5
--- /dev/null
+++ b/clinicadl/splitter/splitter/__init__.py
@@ -0,0 +1,2 @@
+from .kfold import KFold
+from .single_split import SingleSplit
diff --git a/clinicadl/splitter/splitter/kfold.py b/clinicadl/splitter/splitter/kfold.py
new file mode 100644
index 000000000..d4a306508
--- /dev/null
+++ b/clinicadl/splitter/splitter/kfold.py
@@ -0,0 +1,108 @@
+from pathlib import Path
+from typing import Generator, List, Optional, Sequence, Union
+
+from pydantic import PositiveInt
+
+from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.splitter.split import Split
+from clinicadl.splitter.splitter.splitter import (
+    Splitter,
+    SplitterConfig,
+    SubjectsSessionsSplit,
+)
+
+
+class KFoldConfig(SplitterConfig):
+    """
+    Configuration for K-Fold cross-validation splits.
+    """
+
+    json_name: str = "kfold_config.json"
+    subset_name: str = "validation"
+    n_splits: PositiveInt = 5
+    stratification: Union[str, bool] = False
+
+    @property
+    def pattern(self):
+        return f"{self.n_splits}_fold"
+
+
+class KFold(Splitter):
+    """
+    Handles K-Fold cross-validation with optional stratification and demographic balancing.
+    Allows saving, reading, and iterating over splits for reproducibility.
+    """
+
+    def __init__(self, split_dir: Path):
+        """
+        Initialize KFold from a split directory.
+
+        Parameters
+        ----------
+        split_dir : Path
+            Directory containing the splits and their configuration JSON file.
+        """
+
+        super().__init__(split_dir=split_dir)
+
+    def _init_config(self, **args):
+        self.config: KFoldConfig = KFoldConfig(**args)
+
+    def _read_splits(self) -> List[SubjectsSessionsSplit]:
+        """
+        Load all splits from the split directory.
+
+        Returns
+        -------
+        List[SubjectsSessionsSplit]
+            The train and validation sets of each fold.
+        """
+        return [
+            self._read_split(self.split_dir / f"split-{i}")
+            for i in range(self.config.n_splits)
+        ]
+
+    def get_splits(
+        self, dataset: CapsDataset, splits: Optional[Sequence[int]] = None
+    ) -> Generator[Split, None, None]:
+        """
+        Yield dataset splits by their indices.
+
+        Parameters
+        ----------
+        dataset : CapsDataset
+            Dataset to restrict to the train and validation subjects of each split.
+        splits : Optional[Sequence[int]] (optional, default=None)
+            Indices of the splits to retrieve. If None, all splits are yielded.
+
+        Yields
+        ------
+        Split
+            The train and validation datasets for each requested split.
+
+        Raises
+        ------
+        ValueError
+            If the requested split indices are out of range or no splits are available.
+        """
+
+        if not self.config or self.config.n_splits is None:
+            raise ValueError(
+                "No splits found, you must first run the function 'make_kfold' to split your dataset, "
+                "or make sure you use a valid split directory."
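+                # Hint for the reader -- a typical workflow, sketched with
+                # placeholder names (`train_tsv` and `dataset` are not defined here):
+                #     fold_dir = make_kfold(train_tsv, n_splits=5)
+                #     splitter = KFold(fold_dir)
+                #     for split in splitter.get_splits(dataset):
+                #         ...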
+            )
+
+        self.check_dataset_and_tsv_consistency(dataset)
+
+        if splits is None:
+            splits = list(range(self.config.n_splits))
+
+        for split in splits:
+            if split not in range(self.config.n_splits):
+                raise ValueError(
+                    f"Split-{split} doesn't exist. There are {self.config.n_splits} splits, numbered from 0 to {self.config.n_splits-1}."
+                )
+            yield self._get_split(split_id=split, dataset=dataset)
diff --git a/clinicadl/splitter/splitter/single_split.py b/clinicadl/splitter/splitter/single_split.py
new file mode 100644
index 000000000..daee8fd9f
--- /dev/null
+++ b/clinicadl/splitter/splitter/single_split.py
@@ -0,0 +1,96 @@
+from pathlib import Path
+from typing import List, Optional, Sequence, Union
+
+from pydantic import PositiveInt, field_validator
+
+from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.splitter.split import Split
+from clinicadl.splitter.splitter.splitter import (
+    Splitter,
+    SplitterConfig,
+    SubjectsSessionsSplit,
+)
+
+
+class SingleSplitConfig(SplitterConfig):
+    json_name: str = "single_split_config.json"
+    subset_name: str = "test"
+    stratification: Union[List[str], bool] = False
+    n_test: PositiveInt = 100
+    p_categorical_threshold: float = 0.80
+    p_continuous_threshold: float = 0.80
+
+    @property
+    def pattern(self) -> str:
+        return "split"
+
+    @field_validator("p_categorical_threshold", "p_continuous_threshold", mode="before")
+    @classmethod
+    def validate_thresholds(cls, value: Union[float, int]) -> float:
+        if not (0 <= value <= 1):
+            raise ValueError(f"Threshold must be between 0 and 1, got {value}")
+        return value
+
+
+class SingleSplit(Splitter):
+    def __init__(self, split_dir: Path):
+        """
+        Initialize SingleSplit from a split directory.
+
+        Parameters
+        ----------
+        split_dir : Path
+            Directory containing the split and its configuration JSON file.
+        """
+        super().__init__(split_dir=split_dir)
+
+    def _init_config(self, **args):
+        self.config = SingleSplitConfig(**args)
+
+    def _read_splits(self) -> List[SubjectsSessionsSplit]:
+        """
+        Load the single split from the split directory.
+
+        Returns
+        -------
+        List[SubjectsSessionsSplit]
+            A one-element list containing the train and validation sets of the split.
+        """
+        return [self._read_split(self.split_dir)]
+
+    def get_splits(
+        self, dataset: CapsDataset, splits: Optional[Sequence[int]] = None
+    ) -> Split:
+        """
+        Return the single dataset split.
+
+        Parameters
+        ----------
+        dataset : CapsDataset
+            Dataset to restrict to the train and validation subjects of the split.
+        splits : Optional[Sequence[int]] (optional, default=None)
+            Unused; kept for compatibility with the `Splitter` interface.
+
+        Returns
+        -------
+        Split
+            The train and validation datasets of the split.
+
+        Raises
+        ------
+        ValueError
+            If no split is available in the split directory.
+        """
+
+        if not self.config:
+            raise ValueError(
+                "No splits found, you must first run the function 'make_split' to split your dataset, "
+                "or make sure you use a valid split directory."
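+                # Hint for the reader -- sketched with placeholder names
+                # (`sub_ses_tsv` and `dataset` are not defined here):
+                #     split_dir = make_split(sub_ses_tsv, n_test=20)
+                #     splitter = SingleSplit(split_dir)
+                #     split = splitter.get_splits(dataset)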
+            )
+
+        self.check_dataset_and_tsv_consistency(dataset)
+
+        return self._get_split(dataset)
diff --git a/clinicadl/splitter/splitter/splitter.py b/clinicadl/splitter/splitter/splitter.py
new file mode 100644
index 000000000..ae7f65ca0
--- /dev/null
+++ b/clinicadl/splitter/splitter/splitter.py
@@ -0,0 +1,256 @@
+import json
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Generator, List, Optional, Sequence, Tuple, Union
+
+import pandas as pd
+from pydantic import (
+    computed_field,
+    field_validator,
+)
+
+from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.splitter.split import Split
+from clinicadl.utils.config import ClinicaDLConfig
+from clinicadl.utils.exceptions import ClinicaDLTSVError
+from clinicadl.utils.iotools.utils import path_encoder
+
+
+class SubjectsSessionsSplit(ClinicaDLConfig):
+    """
+    Dataclass to store train and validation splits for subjects and sessions.
+    """
+
+    train: pd.DataFrame
+    validation: pd.DataFrame
+
+    @computed_field
+    @property
+    def train_val_df(self) -> pd.DataFrame:
+        return pd.concat([self.train, self.validation], ignore_index=True)
+
+
+class SplitterConfig(ClinicaDLConfig):
+    json_name: str
+    split_dir: Path
+    subset_name: str
+    stratification: Union[str, List[str], bool] = False
+    valid_longitudinal: bool = False
+
+    @field_validator("split_dir", mode="after")
+    @classmethod
+    def validate_split_dir(cls, v):
+        if not isinstance(v, Path):
+            v = Path(v)
+        if v and not v.is_dir():
+            v.mkdir(parents=True, exist_ok=True)
+        return v
+
+    def _check_split_dir(self):
+        split_numero = 1
+        folder_name = self.pattern
+        while (self.split_dir / folder_name).is_dir():
+            split_numero += 1
+            folder_name = f"{self.pattern}_{split_numero}"
+
+        self.split_dir = self.split_dir / folder_name
+
+    @property
+    @abstractmethod
+    def pattern(self) -> str:
+        pass
+
+    def _write_json(self) -> None:
+        """
+        Save the splitter configuration to a JSON file.
+        """
+        if not self.split_dir:
+            raise ValueError(
+                "No split directory specified, use the method 'write' to save your splits."
+            )
+
+        out_json_file = self.split_dir / self.json_name
+        if out_json_file.is_file():
+            raise FileExistsError(
+                f"File {out_json_file} already exists, your splits may have already been written."
+            )
+
+        with out_json_file.open(mode="w") as json_file:
+            json.dump(
+                self.model_dump(),
+                json_file,
+                skipkeys=True,
+                indent=4,
+                default=path_encoder,
+            )
+
+
+class Splitter(ABC):
+    def __init__(self, split_dir: Path):
+        """
+        Initialize the splitter from a split directory.
+
+        Parameters
+        ----------
+        split_dir : Path
+            Directory containing the splits and their configuration JSON file.
+        """
+        split_dir = Path(split_dir)
+
+        if not split_dir.is_dir():
+            raise FileNotFoundError(f"No such directory: {split_dir}")
+
+        self.split_dir = split_dir
+        self._init_config(**self._read_json())
+        self.subjects_sessions_split = self._read_splits()
+
+    @abstractmethod
+    def _init_config(self, **args):
+        self.config: SplitterConfig
+
+    def _read_json(self):
+        """
+        Load the splitter configuration from the JSON file found in the split directory.
+
+        Returns
+        -------
+        dict
+            The configuration dictionary loaded from the JSON file.
+        """
+
+        json_files = [file for file in self.split_dir.glob("*.json")]
+
+        if len(json_files) > 1:
+            raise ValueError(
+                f"Multiple JSON files found in {self.split_dir}, please remove or rename them."
+            )
+
+        elif len(json_files) == 0:
+            raise FileNotFoundError(f"No JSON file found in {self.split_dir}")
+
+        if not json_files[0].is_file():
+            raise FileNotFoundError(f"No such file: {json_files[0]}")
+
+        with json_files[0].open(mode="r") as file:
+            dict_ = json.load(file)
+
+        return dict_
+
+    def _read_split(self, split_path: Path) -> SubjectsSessionsSplit:
+        """
+        Load a single split's train and validation sets from files.
+
+        Parameters
+        ----------
+        split_path : Path
+            Directory containing the split's TSV files.
+
+        Returns
+        -------
+        SubjectsSessionsSplit
+            Object containing train and validation sets as DataFrames.
+        """
+
+        if not split_path.is_dir():
+            raise FileNotFoundError(f"No such directory: {split_path}")
+
+        try:
+            train = pd.read_csv(split_path / "train_baseline.tsv", sep="\t")
+            validation = pd.read_csv(
+                split_path / f"{self.config.subset_name}_baseline.tsv", sep="\t"
+            )  # type: ignore
+
+        except FileNotFoundError as exc:
+            raise FileNotFoundError(
+                f"One or more of the required files are missing: 'train_baseline.tsv', '{self.config.subset_name}_baseline.tsv'"
+            ) from exc  # type: ignore
+
+        return SubjectsSessionsSplit(
+            train=train,
+            validation=validation,
+        )
+
+    @abstractmethod
+    def _read_splits(self) -> List[SubjectsSessionsSplit]:
+        """
+        Load all splits from the split directory.
+
+        Returns
+        -------
+        List[SubjectsSessionsSplit]
+            The train and validation sets of each split.
+        """
+
+    def check_dataset_and_tsv_consistency(self, dataset: CapsDataset):
+        df1 = self.subjects_sessions_split[0].train_val_df
+        df2 = dataset.df
+        pairs_df1 = set(zip(df1["participant_id"], df1["session_id"]))
+        pairs_df2 = set(zip(df2["participant_id"], df2["session_id"]))
+
+        # Check that every (participant, session) pair of the TSV files is present in the dataset
+        if not pairs_df1.issubset(pairs_df2):
+            raise ClinicaDLTSVError(
+                "Not all pairs of participants and sessions from the TSV files are present in the dataset. "
+                "Please check that the split TSV files and the dataset are consistent."
+            )
+
+    @abstractmethod
+    def get_splits(
+        self, dataset: CapsDataset, splits: Optional[Sequence[int]] = None
+    ) -> Union[Split, Generator[Split, None, None]]:
+        """
+        Return or yield dataset splits by their indices.
+
+        Parameters
+        ----------
+        dataset : CapsDataset
+            Dataset to restrict to the subjects of each split.
+        splits : Optional[Sequence[int]] (optional, default=None)
+            Indices of the splits to retrieve. If None, all splits are retrieved.
+
+        Yields
+        ------
+        Split
+            The train and validation datasets for each requested split.
+
+        Raises
+        ------
+        ValueError
+            If the requested split indices are out of range or no splits are available.
+        """
+
+    def _get_split(
+        self,
+        dataset: CapsDataset,
+        split_id: int = 0,
+    ) -> Split:
+        """
+        Retrieve a single dataset split.
+
+        Parameters
+        ----------
+        dataset : CapsDataset
+            Dataset to restrict to the train and validation subjects of the split.
+        split_id : int (optional, default=0)
+            Index of the split to retrieve.
+
+        Returns
+        -------
+        Split
+            Object containing train and validation datasets for the specified split.
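+
+        Examples
+        --------
+        A sketch of what the returned object exposes (`dataset` is a
+        placeholder for a CapsDataset; this helper is usually reached
+        through `get_splits`)::
+
+            split = splitter._get_split(dataset, split_id=0)
+            split.train_dataset  # CapsDataset restricted to the training subjects
+            split.build_train_loader(batch_size=4)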
+ """ + subjects_sessions = self.subjects_sessions_split[split_id] + return Split( + index=split_id, + split_dir=self.split_dir, + train_dataset=dataset.subset(subjects_sessions.train), + val_dataset=dataset.subset(subjects_sessions.validation), + ) diff --git a/clinicadl/splitter/test.py b/clinicadl/splitter/test.py new file mode 100644 index 000000000..83010184e --- /dev/null +++ b/clinicadl/splitter/test.py @@ -0,0 +1,168 @@ +from pathlib import Path + +import pandas as pd +import torchio.transforms as transforms + +from clinicadl.dataset.datasets.caps_dataset import CapsDataset +from clinicadl.dataset.preprocessing import PreprocessingT1 +from clinicadl.splitter import make_kfold, make_split +from clinicadl.splitter.dataloader import DataLoaderConfig +from clinicadl.splitter.splitter import KFold, SingleSplit +from clinicadl.transforms.extraction import Image, Patch, Slice +from clinicadl.transforms.transforms import Transforms +from clinicadl.tsvtools.get_metadata.get_metadata import get_metadata + +# maps_path = Path("/") +# manager = ExperimentManager(maps_path, overwrite=False) + + +sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv") +sub_ses_all = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects.tsv") + +# df = get_metadata(sub_ses_t1, sub_ses_all) + +splir_dir = make_split( + sub_ses_t1, + output_dir=Path( + "/Users/camille.brianceau/aramis/CLINICADL/clinicadl/tests/unittests/ressources/caps_example/split_test" + ), + subset_name="test", + stratification=["age", "sex", "test", "diagnosis"], + n_test=0.2, +) +print(splir_dir) + + +train_path = splir_dir / "train_baseline.tsv" +test_path = splir_dir / "test_baseline.tsv" + +train_df = pd.read_csv(train_path, sep="\t") +test_df = pd.read_csv(test_path, sep="\t") + +print("train age mean", train_df["age"].mean()) +print("test age mean", test_df["age"].mean()) +print("/n") +print("train age std", train_df["age"].std()) +print("test age std", test_df["age"].std()) +print("/n") +print("train test mean", train_df["test"].mean()) +print("test test mean", test_df["test"].mean()) +print("/n") +print("train test std", train_df["test"].std()) +print("test test std", test_df["test"].std()) + +print("/n") +print( + "train diagnosis count AD", + len(train_df[train_df["diagnosis"] == "AD"]), + "/", + len(train_df), + len(train_df[train_df["diagnosis"] == "AD"]) / len(train_df), +) +print( + "test diagnosis count AD", + len(test_df[test_df["diagnosis"] == "AD"]), + "/", + len(test_df), + len(test_df[test_df["diagnosis"] == "AD"]) / len(test_df), +) +print("/n") +print( + "train diagnosis count MCI", + len(train_df[train_df["diagnosis"] == "MCI"]), + "/", + len(train_df), + len(train_df[train_df["diagnosis"] == "MCI"]) / len(train_df), +) +print( + "test diagnosis count MCI", + len(test_df[test_df["diagnosis"] == "MCI"]), + "/", + len(test_df), + len(test_df[test_df["diagnosis"] == "MCI"]) / len(test_df), +) +print("/n") +print( + "train diagnosis count CN", + len(train_df[train_df["diagnosis"] == "CN"]), + "/", + len(train_df), + len(train_df[train_df["diagnosis"] == "CN"]) / len(train_df), +) +print( + "test diagnosis count CN", + len(test_df[test_df["diagnosis"] == "CN"]), + "/", + len(test_df), + len(test_df[test_df["diagnosis"] == "CN"]) / len(test_df), +) + +print("/n") +print( + "train sex count F", + len(train_df[train_df["sex"] == "F"]), + "/", + len(train_df), + len(train_df[train_df["sex"] == "F"]) / len(train_df), +) +print( + "test sex count F", + len(test_df[test_df["sex"] == "F"]), + "/", 
+    len(test_df),
+    len(test_df[test_df["sex"] == "F"]) / len(test_df),
+)
+print("\n")
+print(
+    "train sex count M",
+    len(train_df[train_df["sex"] == "M"]),
+    "/",
+    len(train_df),
+    len(train_df[train_df["sex"] == "M"]) / len(train_df),
+)
+print(
+    "test sex count M",
+    len(test_df[test_df["sex"] == "M"]),
+    "/",
+    len(test_df),
+    len(test_df[test_df["sex"] == "M"]) / len(test_df),
+)
+
+
+fold_dir = make_kfold(train_path, stratification="sex", n_splits=2)
+
+print(fold_dir)
+
+caps_directory = Path("/Users/camille.brianceau/aramis/CLINICADL/caps")
+preprocessing_t1 = PreprocessingT1()
+transforms_image = Transforms(
+    image_augmentation=[transforms.RandomMotion()],
+    extraction=Image(),
+    image_transforms=[transforms.Blur((0.5, 0.6, 0.3))],
+)
+dataset_t1_image = CapsDataset(
+    caps_directory=caps_directory,
+    data=train_path,
+    preprocessing=preprocessing_t1,
+    transforms=transforms_image,
+)
+print(dataset_t1_image)
+dataset_t1_image.prepare_data(n_proc=2)
+
+splitter = KFold(fold_dir)
+
+for split in splitter.get_splits(dataset=dataset_t1_image):
+    print(f"Split {split.index}:\n")
+    print(f"Train dataset: {split.train_dataset}")
+    print(f"describe: {split.train_dataset.describe()}")
+    print(f"elem per image: {split.train_dataset.elem_per_image}")
+    print("\n")
+    print(f"Validation dataset: {split.val_dataset}")
+    print(f"describe: {split.val_dataset.describe()}")
+    print(f"elem per image: {split.val_dataset.elem_per_image}")
+
+    split.build_train_loader(num_workers=2)
+    split.build_val_loader(DataLoaderConfig(batch_size=2))
+    print("\n")
+    print(f"Train loader: {split.train_loader}")
+    print(f"Validation loader: {split.val_loader}")
diff --git a/clinicadl/trainer/config/train.py b/clinicadl/trainer/config/train.py
index f89818576..538eccecd 100644
--- a/clinicadl/trainer/config/train.py
+++ b/clinicadl/trainer/config/train.py
@@ -18,7 +18,7 @@
 from clinicadl.optim.config import OptimizationConfig
 from clinicadl.optim.early_stopping import EarlyStoppingConfig
 from clinicadl.predictor.validation import ValidationConfig
-from clinicadl.splitter.config import SplitConfig
+from clinicadl.splitter.splitter.splitter import SplitterConfig as SplitConfig
 from clinicadl.trainer.transfer_learning import TransferLearningConfig
 from clinicadl.transforms.config import TransformsConfig
 from clinicadl.utils.computational.computational import ComputationalConfig
@@ -44,7 +44,6 @@ class TrainConfig(BaseModel, ABC):
     model: NetworkConfig
     optimization: OptimizationConfig
     reproducibility: ReproducibilityConfig
-    split: SplitConfig
     transfer_learning: TransferLearningConfig
     transforms: TransformsConfig
     validation: ValidationConfig
@@ -86,7 +85,6 @@ def _update(self, config_dict: Dict[str, Any]) -> None:
         self.model.__dict__.update(config_dict)
         self.optimization.__dict__.update(config_dict)
         self.reproducibility.__dict__.update(config_dict)
-        self.split.__dict__.update(config_dict)
         self.transfer_learning.__dict__.update(config_dict)
         self.transforms.__dict__.update(config_dict)
         self.validation.__dict__.update(config_dict)
diff --git a/clinicadl/trainer/old_trainer.py b/clinicadl/trainer/old_trainer.py
index 4d00c5206..aa3f82c70 100644
--- a/clinicadl/trainer/old_trainer.py
+++ b/clinicadl/trainer/old_trainer.py
@@ -34,8 +34,9 @@
 from clinicadl.trainer.tasks_utils import create_training_config
 from clinicadl.predictor.old_predictor import Predictor
 from clinicadl.predictor.config import PredictConfig
-from clinicadl.splitter.old_splitter import Splitter
-from clinicadl.splitter.config import
SplitterConfig +from clinicadl.splitter.splitter.splitter import Splitter +from clinicadl.splitter.splitter.splitter import SplitterConfig + if TYPE_CHECKING: from clinicadl.callbacks.callbacks import Callback diff --git a/clinicadl/tsvtools/get_metadata/get_metadata.py b/clinicadl/tsvtools/get_metadata/get_metadata.py index 5ba40a4f9..63cb474f9 100644 --- a/clinicadl/tsvtools/get_metadata/get_metadata.py +++ b/clinicadl/tsvtools/get_metadata/get_metadata.py @@ -67,3 +67,5 @@ def get_metadata( result_df.to_csv(data_tsv, sep="\t") logger.info(f"metadata were added in: {data_tsv}") + + return result_df diff --git a/clinicadl/tsvtools/split/split.py b/clinicadl/tsvtools/split/split.py index 6235ee2e8..836ee48a7 100644 --- a/clinicadl/tsvtools/split/split.py +++ b/clinicadl/tsvtools/split/split.py @@ -25,35 +25,6 @@ logger = getLogger("clinicadl.tsvtools.split") -def KStests(train_df, test_df, threshold=0.5): - pmin = 1 - column = "" - for col in train_df.columns: - if col == "session_id": - continue - _, pval = ks_2samp(train_df[col], test_df[col]) - if pval < pmin: - pmin = pval - column = col - return (pmin, column) - - -def shuffle_choice(df, n_shuffle=10): - p_min_max, n_col_min = 0, df.columns.size - - for i in range(n_shuffle): - train_df = df.sample(frac=0.75) - test_df = df.drop(train_df.index) - - p, col = KStests(train_df, test_df) - - if p > p_min_max: - p_min_max = p - best_train_df, best_test_df = train_df, test_df - - return (best_train_df, best_test_df, p_min_max) - - def KStests(train_df, test_df, threshold=0.5): pmin = 1 column = "" diff --git a/clinicadl/tsvtools/tsvtools_utils.py b/clinicadl/tsvtools/tsvtools_utils.py index caf842b3e..27291940e 100644 --- a/clinicadl/tsvtools/tsvtools_utils.py +++ b/clinicadl/tsvtools/tsvtools_utils.py @@ -3,6 +3,7 @@ from copy import copy from logging import getLogger from pathlib import Path +from typing import List import numpy as np import pandas as pd @@ -153,7 +154,7 @@ def remove_unicity(values_list): return values_list -def category_conversion(values_list): +def category_conversion(values_list) -> List[int]: values_np = np.array(values_list) unique_classes = np.unique(values_np) for index, unique_class in enumerate(unique_classes): diff --git a/tests/unittests/dataset/test_config.py b/tests/unittests/dataset/test_config.py index d7e51d598..4fa9588e9 100644 --- a/tests/unittests/dataset/test_config.py +++ b/tests/unittests/dataset/test_config.py @@ -1,6 +1,6 @@ import pytest -from clinicadl.dataset.config import DataConfig, FileType +from clinicadl.dataset.config import FileType from clinicadl.utils.enum import Preprocessing diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/kfold_config.json b/tests/unittests/ressources/caps_example/split_test/split/2_fold/kfold_config.json new file mode 100644 index 000000000..6bf0d9492 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/2_fold/kfold_config.json @@ -0,0 +1,8 @@ +{ + "json_name": "kfold_config.json", + "split_dir": "../ressources/caps_example/split_test/split/2_fold", + "subset_name": "validation", + "stratification": "sex", + "valid_longitudinal": false, + "n_splits": 2 +} \ No newline at end of file diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train.tsv b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train.tsv new file mode 100644 index 000000000..267118404 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train.tsv @@ -0,0 
+1,49 @@ +participant_id session_id age sex test diagnosis +sub-003 ses-M000 40 M 1 MCI +sub-004 ses-M000 56 M 2 CN +sub-007 ses-M006 41 F 7 MCI +sub-016 ses-M006 72 M 14 AD +sub-023 ses-M000 69 M 17 AD +sub-036 ses-M006 52 M 30 MCI +sub-045 ses-M006 92 F 29 MCI +sub-046 ses-M006 73 M 27 MCI +sub-054 ses-M000 38 M 23 CN +sub-056 ses-M006 72 M 19 CN +sub-067 ses-M006 94 F 10 CN +sub-074 ses-M000 34 M 7 CN +sub-075 ses-M006 75 F 5 CN +sub-077 ses-M006 23 F 2 CN +sub-103 ses-M000 40 M 1 MCI +sub-104 ses-M000 56 M 2 CN +sub-105 ses-M006 85 F 4 CN +sub-115 ses-M006 85 F 12 CN +sub-116 ses-M006 72 M 14 AD +sub-117 ses-M006 41 F 15 CN +sub-123 ses-M000 69 M 17 AD +sub-124 ses-M000 41 M 18 CN +sub-125 ses-M006 57 F 20 MCI +sub-133 ses-M000 36 M 25 AD +sub-134 ses-M000 74 M 26 CN +sub-145 ses-M006 92 F 29 MCI +sub-147 ses-M006 48 F 26 AD +sub-155 ses-M006 73 F 21 MCI +sub-166 ses-M006 63 M 11 MCI +sub-176 ses-M006 49 M 3 AD +sub-177 ses-M006 23 F 2 CN +sub-206 ses-M006 72 M 6 AD +sub-215 ses-M006 85 F 12 CN +sub-216 ses-M006 72 M 14 AD +sub-225 ses-M006 57 F 20 MCI +sub-233 ses-M000 36 M 25 AD +sub-236 ses-M006 52 M 30 MCI +sub-243 ses-M000 71 M 32 CN +sub-244 ses-M000 64 M 31 AD +sub-247 ses-M006 48 F 26 AD +sub-253 ses-M000 52 M 24 AD +sub-255 ses-M006 73 F 21 MCI +sub-263 ses-M000 83 M 16 MCI +sub-267 ses-M006 94 F 10 CN +sub-273 ses-M000 85 M 8 AD +sub-275 ses-M006 75 F 5 CN +sub-276 ses-M006 49 M 3 AD +sub-277 ses-M006 23 F 2 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train_baseline.tsv b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train_baseline.tsv new file mode 100644 index 000000000..267118404 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train_baseline.tsv @@ -0,0 +1,49 @@ +participant_id session_id age sex test diagnosis +sub-003 ses-M000 40 M 1 MCI +sub-004 ses-M000 56 M 2 CN +sub-007 ses-M006 41 F 7 MCI +sub-016 ses-M006 72 M 14 AD +sub-023 ses-M000 69 M 17 AD +sub-036 ses-M006 52 M 30 MCI +sub-045 ses-M006 92 F 29 MCI +sub-046 ses-M006 73 M 27 MCI +sub-054 ses-M000 38 M 23 CN +sub-056 ses-M006 72 M 19 CN +sub-067 ses-M006 94 F 10 CN +sub-074 ses-M000 34 M 7 CN +sub-075 ses-M006 75 F 5 CN +sub-077 ses-M006 23 F 2 CN +sub-103 ses-M000 40 M 1 MCI +sub-104 ses-M000 56 M 2 CN +sub-105 ses-M006 85 F 4 CN +sub-115 ses-M006 85 F 12 CN +sub-116 ses-M006 72 M 14 AD +sub-117 ses-M006 41 F 15 CN +sub-123 ses-M000 69 M 17 AD +sub-124 ses-M000 41 M 18 CN +sub-125 ses-M006 57 F 20 MCI +sub-133 ses-M000 36 M 25 AD +sub-134 ses-M000 74 M 26 CN +sub-145 ses-M006 92 F 29 MCI +sub-147 ses-M006 48 F 26 AD +sub-155 ses-M006 73 F 21 MCI +sub-166 ses-M006 63 M 11 MCI +sub-176 ses-M006 49 M 3 AD +sub-177 ses-M006 23 F 2 CN +sub-206 ses-M006 72 M 6 AD +sub-215 ses-M006 85 F 12 CN +sub-216 ses-M006 72 M 14 AD +sub-225 ses-M006 57 F 20 MCI +sub-233 ses-M000 36 M 25 AD +sub-236 ses-M006 52 M 30 MCI +sub-243 ses-M000 71 M 32 CN +sub-244 ses-M000 64 M 31 AD +sub-247 ses-M006 48 F 26 AD +sub-253 ses-M000 52 M 24 AD +sub-255 ses-M006 73 F 21 MCI +sub-263 ses-M000 83 M 16 MCI +sub-267 ses-M006 94 F 10 CN +sub-273 ses-M000 85 M 8 AD +sub-275 ses-M006 75 F 5 CN +sub-276 ses-M006 49 M 3 AD +sub-277 ses-M006 23 F 2 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/validation_baseline.tsv b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/validation_baseline.tsv new file mode 100644 index 000000000..eb9b6ff49 --- /dev/null +++ 
b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/validation_baseline.tsv @@ -0,0 +1,49 @@ +participant_id session_id age sex test diagnosis +sub-005 ses-M006 85 F 4 CN +sub-015 ses-M006 85 F 12 CN +sub-017 ses-M006 41 F 15 CN +sub-025 ses-M006 57 F 20 MCI +sub-027 ses-M006 90 F 23 CN +sub-033 ses-M000 36 M 25 AD +sub-043 ses-M000 71 M 32 CN +sub-044 ses-M000 64 M 31 AD +sub-047 ses-M006 48 F 26 AD +sub-053 ses-M000 52 M 24 AD +sub-055 ses-M006 73 F 21 MCI +sub-057 ses-M006 77 F 18 AD +sub-063 ses-M000 83 M 16 MCI +sub-064 ses-M000 86 M 15 AD +sub-066 ses-M006 63 M 11 MCI +sub-073 ses-M000 85 M 8 AD +sub-106 ses-M006 72 M 6 AD +sub-107 ses-M006 41 F 7 MCI +sub-126 ses-M006 60 M 22 MCI +sub-127 ses-M006 90 F 23 CN +sub-135 ses-M006 85 F 28 MCI +sub-137 ses-M006 56 F 31 MCI +sub-143 ses-M000 71 M 32 CN +sub-144 ses-M000 64 M 31 AD +sub-146 ses-M006 73 M 27 MCI +sub-153 ses-M000 52 M 24 AD +sub-154 ses-M000 38 M 23 CN +sub-164 ses-M000 86 M 15 AD +sub-167 ses-M006 94 F 10 CN +sub-173 ses-M000 85 M 8 AD +sub-174 ses-M000 34 M 7 CN +sub-175 ses-M006 75 F 5 CN +sub-205 ses-M006 85 F 4 CN +sub-207 ses-M006 41 F 7 MCI +sub-213 ses-M000 40 M 9 CN +sub-214 ses-M000 56 M 10 CN +sub-217 ses-M006 41 F 15 CN +sub-223 ses-M000 69 M 17 AD +sub-224 ses-M000 41 M 18 CN +sub-234 ses-M000 74 M 26 CN +sub-235 ses-M006 85 F 28 MCI +sub-245 ses-M006 92 F 29 MCI +sub-246 ses-M006 73 M 27 MCI +sub-254 ses-M000 38 M 23 CN +sub-256 ses-M006 72 M 19 CN +sub-264 ses-M000 86 M 15 AD +sub-266 ses-M006 63 M 11 MCI +sub-274 ses-M000 34 M 7 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train.tsv b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train.tsv new file mode 100644 index 000000000..eb9b6ff49 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train.tsv @@ -0,0 +1,49 @@ +participant_id session_id age sex test diagnosis +sub-005 ses-M006 85 F 4 CN +sub-015 ses-M006 85 F 12 CN +sub-017 ses-M006 41 F 15 CN +sub-025 ses-M006 57 F 20 MCI +sub-027 ses-M006 90 F 23 CN +sub-033 ses-M000 36 M 25 AD +sub-043 ses-M000 71 M 32 CN +sub-044 ses-M000 64 M 31 AD +sub-047 ses-M006 48 F 26 AD +sub-053 ses-M000 52 M 24 AD +sub-055 ses-M006 73 F 21 MCI +sub-057 ses-M006 77 F 18 AD +sub-063 ses-M000 83 M 16 MCI +sub-064 ses-M000 86 M 15 AD +sub-066 ses-M006 63 M 11 MCI +sub-073 ses-M000 85 M 8 AD +sub-106 ses-M006 72 M 6 AD +sub-107 ses-M006 41 F 7 MCI +sub-126 ses-M006 60 M 22 MCI +sub-127 ses-M006 90 F 23 CN +sub-135 ses-M006 85 F 28 MCI +sub-137 ses-M006 56 F 31 MCI +sub-143 ses-M000 71 M 32 CN +sub-144 ses-M000 64 M 31 AD +sub-146 ses-M006 73 M 27 MCI +sub-153 ses-M000 52 M 24 AD +sub-154 ses-M000 38 M 23 CN +sub-164 ses-M000 86 M 15 AD +sub-167 ses-M006 94 F 10 CN +sub-173 ses-M000 85 M 8 AD +sub-174 ses-M000 34 M 7 CN +sub-175 ses-M006 75 F 5 CN +sub-205 ses-M006 85 F 4 CN +sub-207 ses-M006 41 F 7 MCI +sub-213 ses-M000 40 M 9 CN +sub-214 ses-M000 56 M 10 CN +sub-217 ses-M006 41 F 15 CN +sub-223 ses-M000 69 M 17 AD +sub-224 ses-M000 41 M 18 CN +sub-234 ses-M000 74 M 26 CN +sub-235 ses-M006 85 F 28 MCI +sub-245 ses-M006 92 F 29 MCI +sub-246 ses-M006 73 M 27 MCI +sub-254 ses-M000 38 M 23 CN +sub-256 ses-M006 72 M 19 CN +sub-264 ses-M000 86 M 15 AD +sub-266 ses-M006 63 M 11 MCI +sub-274 ses-M000 34 M 7 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train_baseline.tsv b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train_baseline.tsv new 
file mode 100644 index 000000000..eb9b6ff49 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train_baseline.tsv @@ -0,0 +1,49 @@ +participant_id session_id age sex test diagnosis +sub-005 ses-M006 85 F 4 CN +sub-015 ses-M006 85 F 12 CN +sub-017 ses-M006 41 F 15 CN +sub-025 ses-M006 57 F 20 MCI +sub-027 ses-M006 90 F 23 CN +sub-033 ses-M000 36 M 25 AD +sub-043 ses-M000 71 M 32 CN +sub-044 ses-M000 64 M 31 AD +sub-047 ses-M006 48 F 26 AD +sub-053 ses-M000 52 M 24 AD +sub-055 ses-M006 73 F 21 MCI +sub-057 ses-M006 77 F 18 AD +sub-063 ses-M000 83 M 16 MCI +sub-064 ses-M000 86 M 15 AD +sub-066 ses-M006 63 M 11 MCI +sub-073 ses-M000 85 M 8 AD +sub-106 ses-M006 72 M 6 AD +sub-107 ses-M006 41 F 7 MCI +sub-126 ses-M006 60 M 22 MCI +sub-127 ses-M006 90 F 23 CN +sub-135 ses-M006 85 F 28 MCI +sub-137 ses-M006 56 F 31 MCI +sub-143 ses-M000 71 M 32 CN +sub-144 ses-M000 64 M 31 AD +sub-146 ses-M006 73 M 27 MCI +sub-153 ses-M000 52 M 24 AD +sub-154 ses-M000 38 M 23 CN +sub-164 ses-M000 86 M 15 AD +sub-167 ses-M006 94 F 10 CN +sub-173 ses-M000 85 M 8 AD +sub-174 ses-M000 34 M 7 CN +sub-175 ses-M006 75 F 5 CN +sub-205 ses-M006 85 F 4 CN +sub-207 ses-M006 41 F 7 MCI +sub-213 ses-M000 40 M 9 CN +sub-214 ses-M000 56 M 10 CN +sub-217 ses-M006 41 F 15 CN +sub-223 ses-M000 69 M 17 AD +sub-224 ses-M000 41 M 18 CN +sub-234 ses-M000 74 M 26 CN +sub-235 ses-M006 85 F 28 MCI +sub-245 ses-M006 92 F 29 MCI +sub-246 ses-M006 73 M 27 MCI +sub-254 ses-M000 38 M 23 CN +sub-256 ses-M006 72 M 19 CN +sub-264 ses-M000 86 M 15 AD +sub-266 ses-M006 63 M 11 MCI +sub-274 ses-M000 34 M 7 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/validation_baseline.tsv b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/validation_baseline.tsv new file mode 100644 index 000000000..267118404 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/validation_baseline.tsv @@ -0,0 +1,49 @@ +participant_id session_id age sex test diagnosis +sub-003 ses-M000 40 M 1 MCI +sub-004 ses-M000 56 M 2 CN +sub-007 ses-M006 41 F 7 MCI +sub-016 ses-M006 72 M 14 AD +sub-023 ses-M000 69 M 17 AD +sub-036 ses-M006 52 M 30 MCI +sub-045 ses-M006 92 F 29 MCI +sub-046 ses-M006 73 M 27 MCI +sub-054 ses-M000 38 M 23 CN +sub-056 ses-M006 72 M 19 CN +sub-067 ses-M006 94 F 10 CN +sub-074 ses-M000 34 M 7 CN +sub-075 ses-M006 75 F 5 CN +sub-077 ses-M006 23 F 2 CN +sub-103 ses-M000 40 M 1 MCI +sub-104 ses-M000 56 M 2 CN +sub-105 ses-M006 85 F 4 CN +sub-115 ses-M006 85 F 12 CN +sub-116 ses-M006 72 M 14 AD +sub-117 ses-M006 41 F 15 CN +sub-123 ses-M000 69 M 17 AD +sub-124 ses-M000 41 M 18 CN +sub-125 ses-M006 57 F 20 MCI +sub-133 ses-M000 36 M 25 AD +sub-134 ses-M000 74 M 26 CN +sub-145 ses-M006 92 F 29 MCI +sub-147 ses-M006 48 F 26 AD +sub-155 ses-M006 73 F 21 MCI +sub-166 ses-M006 63 M 11 MCI +sub-176 ses-M006 49 M 3 AD +sub-177 ses-M006 23 F 2 CN +sub-206 ses-M006 72 M 6 AD +sub-215 ses-M006 85 F 12 CN +sub-216 ses-M006 72 M 14 AD +sub-225 ses-M006 57 F 20 MCI +sub-233 ses-M000 36 M 25 AD +sub-236 ses-M006 52 M 30 MCI +sub-243 ses-M000 71 M 32 CN +sub-244 ses-M000 64 M 31 AD +sub-247 ses-M006 48 F 26 AD +sub-253 ses-M000 52 M 24 AD +sub-255 ses-M006 73 F 21 MCI +sub-263 ses-M000 83 M 16 MCI +sub-267 ses-M006 94 F 10 CN +sub-273 ses-M000 85 M 8 AD +sub-275 ses-M006 75 F 5 CN +sub-276 ses-M006 49 M 3 AD +sub-277 ses-M006 23 F 2 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/single_split_config.json 
b/tests/unittests/ressources/caps_example/split_test/split/single_split_config.json new file mode 100644 index 000000000..a21941ccb --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/single_split_config.json @@ -0,0 +1,15 @@ +{ + "json_name": "single_split_config.json", + "split_dir": "../ressources/caps_example/split_test/split", + "subset_name": "test", + "stratification": [ + "age", + "sex", + "test", + "diagnosis" + ], + "valid_longitudinal": false, + "n_test": 24, + "p_categorical_threshold": 0.5, + "p_continuous_threshold": 0.5 +} \ No newline at end of file diff --git a/tests/unittests/ressources/caps_example/split_test/split/split_categorical_stats.tsv b/tests/unittests/ressources/caps_example/split_test/split/split_categorical_stats.tsv new file mode 100644 index 000000000..2fbf24954 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/split_categorical_stats.tsv @@ -0,0 +1,11 @@ +label value statistic train test +sex M proportion 0.59375 0.625 +sex M count 57.0 15.0 +sex F proportion 0.40625 0.375 +sex F count 39.0 9.0 +diagnosis MCI proportion 0.2916666666666667 0.3333333333333333 +diagnosis MCI count 28.0 8.0 +diagnosis CN proportion 0.40625 0.375 +diagnosis CN count 39.0 9.0 +diagnosis AD proportion 0.3020833333333333 0.2916666666666667 +diagnosis AD count 29.0 7.0 diff --git a/tests/unittests/ressources/caps_example/split_test/split/split_continuous_stats.tsv b/tests/unittests/ressources/caps_example/split_test/split/split_continuous_stats.tsv new file mode 100644 index 000000000..6bb4a628d --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/split_continuous_stats.tsv @@ -0,0 +1,5 @@ +label statistic train test +age mean 63.416666666666664 61.083333333333336 +test mean 16.552083333333332 16.291666666666668 +age std 19.392709789569277 14.634434279010128 +test std 9.367210213984869 9.129403841037684 diff --git a/tests/unittests/ressources/caps_example/split_test/split/test_baseline.tsv b/tests/unittests/ressources/caps_example/split_test/split/test_baseline.tsv new file mode 100644 index 000000000..7370eeaa4 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/test_baseline.tsv @@ -0,0 +1,25 @@ +participant_id session_id age sex test diagnosis +sub-163 ses-M000 83 M 16 MCI +sub-227 ses-M006 90 F 23 CN +sub-006 ses-M006 72 M 6 AD +sub-014 ses-M000 56 M 10 CN +sub-034 ses-M000 74 M 26 CN +sub-065 ses-M006 58 F 13 AD +sub-136 ses-M006 52 M 30 MCI +sub-165 ses-M006 58 F 13 AD +sub-037 ses-M006 56 F 31 MCI +sub-226 ses-M006 60 M 22 MCI +sub-113 ses-M000 40 M 9 CN +sub-076 ses-M006 49 M 3 AD +sub-157 ses-M006 77 F 18 AD +sub-237 ses-M006 56 F 31 MCI +sub-024 ses-M000 41 M 18 CN +sub-035 ses-M006 85 F 28 MCI +sub-026 ses-M006 60 M 22 MCI +sub-204 ses-M000 56 M 2 CN +sub-257 ses-M006 77 F 18 AD +sub-114 ses-M000 56 M 10 CN +sub-156 ses-M006 72 M 19 CN +sub-203 ses-M000 40 M 1 MCI +sub-265 ses-M006 58 F 13 AD +sub-013 ses-M000 40 M 9 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/train.tsv b/tests/unittests/ressources/caps_example/split_test/split/train.tsv new file mode 100644 index 000000000..ac26a8982 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/train.tsv @@ -0,0 +1,155 @@ +participant_id session_id age sex test diagnosis +sub-275 ses-M006 75 F 5 CN +sub-275 ses-M018 96 M 4 MCI +sub-036 ses-M006 52 M 30 MCI +sub-146 ses-M006 73 M 27 MCI +sub-243 ses-M000 71 M 32 CN +sub-155 ses-M006 73 F 21 MCI +sub-155 ses-M018 48 M 20 CN 
+sub-164 ses-M000 86 M 15 AD +sub-164 ses-M054 91 F 14 MCI +sub-016 ses-M006 72 M 14 AD +sub-247 ses-M006 48 F 26 AD +sub-247 ses-M036 59 M 25 CN +sub-064 ses-M000 86 M 15 AD +sub-064 ses-M054 91 F 14 MCI +sub-206 ses-M006 72 M 6 AD +sub-053 ses-M000 52 M 24 AD +sub-043 ses-M000 71 M 32 CN +sub-007 ses-M006 41 F 7 MCI +sub-007 ses-M036 59 M 8 CN +sub-167 ses-M006 94 F 10 CN +sub-167 ses-M036 68 M 9 MCI +sub-103 ses-M000 40 M 1 MCI +sub-057 ses-M006 77 F 18 AD +sub-057 ses-M036 99 M 17 MCI +sub-177 ses-M006 23 F 2 CN +sub-177 ses-M036 45 M 1 MCI +sub-117 ses-M006 41 F 15 CN +sub-117 ses-M036 59 M 16 CN +sub-224 ses-M000 41 M 18 CN +sub-224 ses-M054 63 F 19 AD +sub-133 ses-M000 36 M 25 AD +sub-066 ses-M006 63 M 11 MCI +sub-124 ses-M000 41 M 18 CN +sub-124 ses-M054 63 F 19 AD +sub-235 ses-M006 85 F 28 MCI +sub-235 ses-M018 57 M 29 CN +sub-236 ses-M006 52 M 30 MCI +sub-253 ses-M000 52 M 24 AD +sub-143 ses-M000 71 M 32 CN +sub-017 ses-M006 41 F 15 CN +sub-017 ses-M036 59 M 16 CN +sub-213 ses-M000 40 M 9 CN +sub-154 ses-M000 38 M 23 CN +sub-154 ses-M054 47 F 22 MCI +sub-135 ses-M006 85 F 28 MCI +sub-135 ses-M018 57 M 29 CN +sub-223 ses-M000 69 M 17 AD +sub-045 ses-M006 92 F 29 MCI +sub-045 ses-M018 61 M 28 MCI +sub-123 ses-M000 69 M 17 AD +sub-125 ses-M006 57 F 20 MCI +sub-125 ses-M018 75 M 21 AD +sub-044 ses-M000 64 M 31 AD +sub-044 ses-M054 23 F 30 MCI +sub-054 ses-M000 38 M 23 CN +sub-054 ses-M054 47 F 22 MCI +sub-003 ses-M000 40 M 1 MCI +sub-126 ses-M006 60 M 22 MCI +sub-215 ses-M006 85 F 12 CN +sub-215 ses-M018 64 M 13 CN +sub-005 ses-M006 85 F 4 CN +sub-005 ses-M018 64 M 5 AD +sub-153 ses-M000 52 M 24 AD +sub-273 ses-M000 85 M 8 AD +sub-115 ses-M006 85 F 12 CN +sub-115 ses-M018 64 M 13 CN +sub-174 ses-M000 34 M 7 CN +sub-174 ses-M054 98 F 6 MCI +sub-075 ses-M006 75 F 5 CN +sub-075 ses-M018 96 M 4 MCI +sub-144 ses-M000 64 M 31 AD +sub-144 ses-M054 23 F 30 MCI +sub-233 ses-M000 36 M 25 AD +sub-077 ses-M006 23 F 2 CN +sub-077 ses-M036 45 M 1 MCI +sub-116 ses-M006 72 M 14 AD +sub-254 ses-M000 38 M 23 CN +sub-254 ses-M054 47 F 22 MCI +sub-246 ses-M006 73 M 27 MCI +sub-276 ses-M006 49 M 3 AD +sub-015 ses-M006 85 F 12 CN +sub-015 ses-M018 64 M 13 CN +sub-217 ses-M006 41 F 15 CN +sub-217 ses-M036 59 M 16 CN +sub-074 ses-M000 34 M 7 CN +sub-074 ses-M054 98 F 6 MCI +sub-137 ses-M006 56 F 31 MCI +sub-137 ses-M036 65 M 32 MCI +sub-004 ses-M000 56 M 2 CN +sub-004 ses-M054 75 F 3 MCI +sub-055 ses-M006 73 F 21 MCI +sub-055 ses-M018 48 M 20 CN +sub-176 ses-M006 49 M 3 AD +sub-234 ses-M000 74 M 26 CN +sub-234 ses-M054 72 F 27 CN +sub-267 ses-M006 94 F 10 CN +sub-267 ses-M036 68 M 9 MCI +sub-166 ses-M006 63 M 11 MCI +sub-244 ses-M000 64 M 31 AD +sub-244 ses-M054 23 F 30 MCI +sub-175 ses-M006 75 F 5 CN +sub-175 ses-M018 96 M 4 MCI +sub-127 ses-M006 90 F 23 CN +sub-127 ses-M036 59 M 24 CN +sub-214 ses-M000 56 M 10 CN +sub-214 ses-M054 75 F 11 AD +sub-173 ses-M000 85 M 8 AD +sub-255 ses-M006 73 F 21 MCI +sub-255 ses-M018 48 M 20 CN +sub-046 ses-M006 73 M 27 MCI +sub-264 ses-M000 86 M 15 AD +sub-264 ses-M054 91 F 14 MCI +sub-274 ses-M000 34 M 7 CN +sub-274 ses-M054 98 F 6 MCI +sub-256 ses-M006 72 M 19 CN +sub-205 ses-M006 85 F 4 CN +sub-205 ses-M018 64 M 5 AD +sub-145 ses-M006 92 F 29 MCI +sub-145 ses-M018 61 M 28 MCI +sub-104 ses-M000 56 M 2 CN +sub-104 ses-M054 75 F 3 MCI +sub-056 ses-M006 72 M 19 CN +sub-106 ses-M006 72 M 6 AD +sub-033 ses-M000 36 M 25 AD +sub-225 ses-M006 57 F 20 MCI +sub-225 ses-M018 75 M 21 AD +sub-263 ses-M000 83 M 16 MCI +sub-216 ses-M006 72 M 14 AD +sub-245 ses-M006 92 F 29 MCI 
+sub-245 ses-M018 61 M 28 MCI
+sub-073 ses-M000 85 M 8 AD
+sub-277 ses-M006 23 F 2 CN
+sub-277 ses-M036 45 M 1 MCI
+sub-107 ses-M006 41 F 7 MCI
+sub-107 ses-M036 59 M 8 CN
+sub-067 ses-M006 94 F 10 CN
+sub-067 ses-M036 68 M 9 MCI
+sub-266 ses-M006 63 M 11 MCI
+sub-025 ses-M006 57 F 20 MCI
+sub-025 ses-M018 75 M 21 AD
+sub-023 ses-M000 69 M 17 AD
+sub-063 ses-M000 83 M 16 MCI
+sub-047 ses-M006 48 F 26 AD
+sub-047 ses-M036 59 M 25 CN
+sub-105 ses-M006 85 F 4 CN
+sub-105 ses-M018 64 M 5 AD
+sub-027 ses-M006 90 F 23 CN
+sub-027 ses-M036 59 M 24 CN
+sub-134 ses-M000 74 M 26 CN
+sub-134 ses-M054 72 F 27 CN
+sub-207 ses-M006 41 F 7 MCI
+sub-207 ses-M036 59 M 8 CN
+sub-147 ses-M006 48 F 26 AD
+sub-147 ses-M036 59 M 25 CN
diff --git a/tests/unittests/ressources/caps_example/split_test/split/train_baseline.tsv b/tests/unittests/ressources/caps_example/split_test/split/train_baseline.tsv
new file mode 100644
index 000000000..9188b346b
--- /dev/null
+++ b/tests/unittests/ressources/caps_example/split_test/split/train_baseline.tsv
@@ -0,0 +1,97 @@
+participant_id session_id age sex test diagnosis
+sub-275 ses-M006 75 F 5 CN
+sub-036 ses-M006 52 M 30 MCI
+sub-146 ses-M006 73 M 27 MCI
+sub-243 ses-M000 71 M 32 CN
+sub-155 ses-M006 73 F 21 MCI
+sub-164 ses-M000 86 M 15 AD
+sub-016 ses-M006 72 M 14 AD
+sub-247 ses-M006 48 F 26 AD
+sub-064 ses-M000 86 M 15 AD
+sub-206 ses-M006 72 M 6 AD
+sub-053 ses-M000 52 M 24 AD
+sub-043 ses-M000 71 M 32 CN
+sub-007 ses-M006 41 F 7 MCI
+sub-167 ses-M006 94 F 10 CN
+sub-103 ses-M000 40 M 1 MCI
+sub-057 ses-M006 77 F 18 AD
+sub-177 ses-M006 23 F 2 CN
+sub-117 ses-M006 41 F 15 CN
+sub-224 ses-M000 41 M 18 CN
+sub-133 ses-M000 36 M 25 AD
+sub-066 ses-M006 63 M 11 MCI
+sub-124 ses-M000 41 M 18 CN
+sub-235 ses-M006 85 F 28 MCI
+sub-236 ses-M006 52 M 30 MCI
+sub-253 ses-M000 52 M 24 AD
+sub-143 ses-M000 71 M 32 CN
+sub-017 ses-M006 41 F 15 CN
+sub-213 ses-M000 40 M 9 CN
+sub-154 ses-M000 38 M 23 CN
+sub-135 ses-M006 85 F 28 MCI
+sub-223 ses-M000 69 M 17 AD
+sub-045 ses-M006 92 F 29 MCI
+sub-123 ses-M000 69 M 17 AD
+sub-125 ses-M006 57 F 20 MCI
+sub-044 ses-M000 64 M 31 AD
+sub-054 ses-M000 38 M 23 CN
+sub-003 ses-M000 40 M 1 MCI
+sub-126 ses-M006 60 M 22 MCI
+sub-215 ses-M006 85 F 12 CN
+sub-005 ses-M006 85 F 4 CN
+sub-153 ses-M000 52 M 24 AD
+sub-273 ses-M000 85 M 8 AD
+sub-115 ses-M006 85 F 12 CN
+sub-174 ses-M000 34 M 7 CN
+sub-075 ses-M006 75 F 5 CN
+sub-144 ses-M000 64 M 31 AD
+sub-233 ses-M000 36 M 25 AD
+sub-077 ses-M006 23 F 2 CN
+sub-116 ses-M006 72 M 14 AD
+sub-254 ses-M000 38 M 23 CN
+sub-246 ses-M006 73 M 27 MCI
+sub-276 ses-M006 49 M 3 AD
+sub-015 ses-M006 85 F 12 CN
+sub-217 ses-M006 41 F 15 CN
+sub-074 ses-M000 34 M 7 CN
+sub-137 ses-M006 56 F 31 MCI
+sub-004 ses-M000 56 M 2 CN
+sub-055 ses-M006 73 F 21 MCI
+sub-176 ses-M006 49 M 3 AD
+sub-234 ses-M000 74 M 26 CN
+sub-267 ses-M006 94 F 10 CN
+sub-166 ses-M006 63 M 11 MCI
+sub-244 ses-M000 64 M 31 AD
+sub-175 ses-M006 75 F 5 CN
+sub-127 ses-M006 90 F 23 CN
+sub-214 ses-M000 56 M 10 CN
+sub-173 ses-M000 85 M 8 AD
+sub-255 ses-M006 73 F 21 MCI
+sub-046 ses-M006 73 M 27 MCI
+sub-264 ses-M000 86 M 15 AD
+sub-274 ses-M000 34 M 7 CN
+sub-256 ses-M006 72 M 19 CN
+sub-205 ses-M006 85 F 4 CN
+sub-145 ses-M006 92 F 29 MCI
+sub-104 ses-M000 56 M 2 CN
+sub-056 ses-M006 72 M 19 CN
+sub-106 ses-M006 72 M 6 AD
+sub-033 ses-M000 36 M 25 AD
+sub-225 ses-M006 57 F 20 MCI
+sub-263 ses-M000 83 M 16 MCI
+sub-216 ses-M006 72 M 14 AD
+sub-245 ses-M006 92 F 29 MCI
+sub-073 ses-M000 85 M 8 AD
+sub-277 ses-M006 23 F 2 CN
+sub-107 ses-M006 41 F 7 MCI
+sub-067 ses-M006 94 F 10 CN
+sub-266 ses-M006 63 M 11 MCI
+sub-025 ses-M006 57 F 20 MCI
+sub-023 ses-M000 69 M 17 AD
+sub-063 ses-M000 83 M 16 MCI
+sub-047 ses-M006 48 F 26 AD
+sub-105 ses-M006 85 F 4 CN
+sub-027 ses-M006 90 F 23 CN
+sub-134 ses-M000 74 M 26 CN
+sub-207 ses-M006 41 F 7 MCI
+sub-147 ses-M006 48 F 26 AD
diff --git a/tests/unittests/ressources/caps_example/subjects_false.tsv b/tests/unittests/ressources/caps_example/subjects_false.tsv
new file mode 100644
index 000000000..136a18edc
--- /dev/null
+++ b/tests/unittests/ressources/caps_example/subjects_false.tsv
@@ -0,0 +1,2 @@
+sub ses age sex test diagnosis
+sub-003 ses-M000 40 M 1 MCI
\ No newline at end of file
diff --git a/tests/unittests/ressources/caps_example/subjects_sessions_list.tsv b/tests/unittests/ressources/caps_example/subjects_sessions_list.tsv
index cd3c1d0b2..e505e36c3 100644
--- a/tests/unittests/ressources/caps_example/subjects_sessions_list.tsv
+++ b/tests/unittests/ressources/caps_example/subjects_sessions_list.tsv
@@ -3,11 +3,3 @@ sub-000 ses-M000
 sub-000 ses-M006
 sub-001 ses-M000
 sub-001 ses-M018
-sub-OAS30010 ses-M000
-sub-OAS30011 ses-M000
-sub-OAS30011 ses-M054
-sub-OAS30012 ses-M006
-sub-OAS30012 ses-M018
-sub-OAS30013 ses-M006
-sub-OAS30014 ses-M006
-sub-OAS30014 ses-M036
diff --git a/tests/unittests/ressources/caps_example/subjects_t1.tsv b/tests/unittests/ressources/caps_example/subjects_t1.tsv
new file mode 100644
index 000000000..60dfb62dd
--- /dev/null
+++ b/tests/unittests/ressources/caps_example/subjects_t1.tsv
@@ -0,0 +1,65 @@
+participant_id session_id age sex test diagnosis
+sub-003 ses-M000 40 M 1 MCI
+sub-004 ses-M000 56 M 2 CN
+sub-004 ses-M054 75 F 3 MCI
+sub-005 ses-M006 85 F 4 CN
+sub-005 ses-M018 64 M 5 AD
+sub-006 ses-M006 72 M 6 AD
+sub-007 ses-M006 41 F 7 MCI
+sub-007 ses-M036 59 M 8 CN
+sub-013 ses-M000 40 M 9 CN
+sub-014 ses-M000 56 M 10 CN
+sub-014 ses-M054 75 F 11 AD
+sub-015 ses-M006 85 F 12 CN
+sub-015 ses-M018 64 M 13 CN
+sub-016 ses-M006 72 M 14 AD
+sub-017 ses-M006 41 F 15 CN
+sub-017 ses-M036 59 M 16 CN
+sub-023 ses-M000 69 M 17 AD
+sub-024 ses-M000 41 M 18 CN
+sub-024 ses-M054 63 F 19 AD
+sub-025 ses-M006 57 F 20 MCI
+sub-025 ses-M018 75 M 21 AD
+sub-026 ses-M006 60 M 22 MCI
+sub-027 ses-M006 90 F 23 CN
+sub-027 ses-M036 59 M 24 CN
+sub-033 ses-M000 36 M 25 AD
+sub-034 ses-M000 74 M 26 CN
+sub-034 ses-M054 72 F 27 CN
+sub-035 ses-M006 85 F 28 MCI
+sub-035 ses-M018 57 M 29 CN
+sub-036 ses-M006 52 M 30 MCI
+sub-037 ses-M006 56 F 31 MCI
+sub-037 ses-M036 65 M 32 MCI
+sub-043 ses-M000 71 M 32 CN
+sub-044 ses-M000 64 M 31 AD
+sub-044 ses-M054 23 F 30 MCI
+sub-045 ses-M006 92 F 29 MCI
+sub-045 ses-M018 61 M 28 MCI
+sub-046 ses-M006 73 M 27 MCI
+sub-047 ses-M006 48 F 26 AD
+sub-047 ses-M036 59 M 25 CN
+sub-053 ses-M000 52 M 24 AD
+sub-054 ses-M000 38 M 23 CN
+sub-054 ses-M054 47 F 22 MCI
+sub-055 ses-M006 73 F 21 MCI
+sub-055 ses-M018 48 M 20 CN
+sub-056 ses-M006 72 M 19 CN
+sub-057 ses-M006 77 F 18 AD
+sub-057 ses-M036 99 M 17 MCI
+sub-063 ses-M000 83 M 16 MCI
+sub-064 ses-M000 86 M 15 AD
+sub-064 ses-M054 91 F 14 MCI
+sub-065 ses-M006 58 F 13 AD
+sub-065 ses-M018 74 M 12 MCI
+sub-066 ses-M006 63 M 11 MCI
+sub-067 ses-M006 94 F 10 CN
+sub-067 ses-M036 68 M 9 MCI
+sub-073 ses-M000 85 M 8 AD
+sub-074 ses-M000 34 M 7 CN
+sub-074 ses-M054 98 F 6 MCI
+sub-075 ses-M006 75 F 5 CN
+sub-075 ses-M018 96 M 4 MCI
+sub-076 ses-M006 49 M 3 AD
+sub-077 ses-M006 23 F 2 CN
+sub-077 ses-M036 45 M 1 MCI
diff --git a/tests/unittests/splitter/test_make_split.py b/tests/unittests/splitter/test_make_split.py
new file mode 100644
index 000000000..3216f097f
--- /dev/null
+++ b/tests/unittests/splitter/test_make_split.py
@@ -0,0 +1,195 @@
+import json
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pytest
+from pydantic import ValidationError
+
+from clinicadl.splitter.make_splits import make_kfold, make_split
+from clinicadl.utils.exceptions import (
+    ClinicaDLConfigurationError,
+    ClinicaDLTSVError,
+)
+
+
+def remove_non_empty_dir(dir_path: Path):
+    """
+    Remove a non-empty directory using only pathlib.
+
+    Parameters
+    ----------
+    dir_path : Path
+        Path to the directory to remove.
+    """
+    if dir_path.exists() and dir_path.is_dir():
+        for item in dir_path.iterdir():  # Iterate through directory contents
+            if item.is_dir():
+                remove_non_empty_dir(item)  # Recursively remove subdirectories
+            else:
+                item.unlink()  # Remove files
+        dir_path.rmdir()  # Remove the now-empty directory
+    else:
+        print(f"{dir_path} does not exist or is not a directory.")
+
+
+caps_dir = Path(__file__).parents[1] / "ressources" / "caps_example"
+
+sub_ses_t1 = caps_dir / "subjects_t1.tsv"
+sub_ses_df = pd.read_csv(sub_ses_t1, sep="\t")
+
+split_dir = caps_dir / "split"
+train_path = split_dir / "train.tsv"
+
+
+def test_good_split():
+    n_test = 15
+    stratification = ["age", "sex", "test", "diagnosis"]
+    subset_name = "test_test"
+
+    split_dir = make_split(
+        sub_ses_t1,
+        output_dir=caps_dir / "test",
+        subset_name=subset_name,
+        stratification=stratification,
+        n_test=n_test,
+    )
+
+    train_path = split_dir / "train_baseline.tsv"
+    test_path = split_dir / f"{subset_name}_baseline.tsv"
+
+    assert train_path.exists()
+    assert test_path.exists()
+
+    assert (split_dir / "single_split_config.json").is_file()
+    with (split_dir / "single_split_config.json").open(mode="r") as file:
+        dict_ = json.load(file)
+
+    assert dict_["json_name"] == "single_split_config.json"
+    assert dict_["split_dir"] == str(split_dir)
+    assert dict_["subset_name"] == subset_name
+    assert dict_["stratification"] == stratification
+    assert dict_["valid_longitudinal"] is False
+    assert dict_["n_test"] == n_test
+    assert np.isclose(dict_["p_categorical_threshold"], 0.5, rtol=1e-09, atol=1e-09)
+    assert np.isclose(dict_["p_continuous_threshold"], 0.5, rtol=1e-09, atol=1e-09)
+
+    train_df = pd.read_csv(train_path, sep="\t")
+    test_df = pd.read_csv(test_path, sep="\t")
+
+    assert len(test_df) == n_test
+    assert set(stratification).issubset(set(test_df.columns))
+
+    assert (split_dir / "split_continuous_stats.tsv").is_file()
+    assert (split_dir / "split_categorical_stats.tsv").is_file()
+
+    split_dir_bis = make_split(sub_ses_t1, n_test=n_test)
+
+    assert split_dir_bis == sub_ses_t1.parent / "split"
+
+    split_dir_bis_bis = make_split(sub_ses_t1, n_test=n_test, stratification=False)
+
+    assert split_dir_bis_bis == sub_ses_t1.parent / "split_2"
+
+    remove_non_empty_dir(split_dir)
+    remove_non_empty_dir(split_dir_bis)
+    remove_non_empty_dir(split_dir_bis_bis)
+
+
+def test_bad_split():
+    with pytest.raises(ClinicaDLTSVError):
+        make_split(caps_dir / "test.tsv", n_test=15)
+
+    with pytest.raises(ClinicaDLTSVError):
+        make_split(caps_dir / "subjects_false.tsv", n_test=2)
+
+    with pytest.raises(ValueError):
+        make_split(sub_ses_t1, p_categorical_threshold=12, n_test=2)
+
+    with pytest.raises(ValueError):
+        make_split(sub_ses_t1, n_test=100)
+
+    split_dir = sub_ses_t1.parent / "split"
+    remove_non_empty_dir(split_dir)
+
+
+def test_good_kfold():
+    n_split = 2
+    stratification = "sex"
+    subset_name = "test_test"
+
+    split_dir = make_kfold(
+        sub_ses_t1,
+        output_dir=caps_dir / "test",
+        subset_name=subset_name,
+        stratification=stratification,
+        n_splits=n_split,
+    )
+
+    train_path = split_dir / "split-0" / "train_baseline.tsv"
+    test_path = split_dir / "split-0" / f"{subset_name}_baseline.tsv"
+
+    assert train_path.exists()
+    assert test_path.exists()
+
+    assert (split_dir / "kfold_config.json").is_file()
+    with (split_dir / "kfold_config.json").open(mode="r") as file:
+        dict_ = json.load(file)
+
+    assert dict_["json_name"] == "kfold_config.json"
+    assert dict_["split_dir"] == str(split_dir)
+    assert dict_["subset_name"] == subset_name
+    assert dict_["stratification"] == stratification
+    assert dict_["valid_longitudinal"] is False
+    assert dict_["n_splits"] == n_split
+
+    test_df = pd.read_csv(test_path, sep="\t")
+
+    assert len(test_df) == 20
+    assert stratification in test_df.columns
+
+    split_dir_bis = make_kfold(
+        sub_ses_t1, n_splits=n_split, stratification=stratification
+    )
+
+    assert split_dir_bis == sub_ses_t1.parent / "2_fold"
+    split_dir_bis_bis = make_kfold(sub_ses_t1, n_splits=n_split, stratification=False)
+
+    assert split_dir_bis_bis == sub_ses_t1.parent / "2_fold_2"
+
+    remove_non_empty_dir(split_dir)
+    remove_non_empty_dir(split_dir_bis)
+    remove_non_empty_dir(split_dir_bis_bis)
+
+
+def test_bad_kfold():
+    with pytest.raises(ClinicaDLTSVError):
+        make_kfold(caps_dir / "test.tsv", output_dir=caps_dir / "test_kfold")
+
+    with pytest.raises(ClinicaDLTSVError):
+        make_kfold(
+            caps_dir / "subjects_false.tsv",
+            n_splits=1,
+            output_dir=caps_dir / "test_kfold",
+        )
+
+    with pytest.raises(ValueError):
+        make_kfold(
+            sub_ses_t1,
+            stratification="age",
+            output_dir=caps_dir / "test_kfold",
+        )
+
+    with pytest.raises(ValidationError):
+        make_kfold(
+            sub_ses_t1,
+            stratification=["sex", "age"],  # type: ignore
+            output_dir=caps_dir / "test_kfold",
+        )
+
+    with pytest.raises(ClinicaDLConfigurationError):
+        make_kfold(
+            sub_ses_t1, stratification="column", output_dir=caps_dir / "test_kfold"
+        )
+
+    remove_non_empty_dir(caps_dir / "test_kfold")
diff --git a/tests/unittests/splitter/test_splitter.py b/tests/unittests/splitter/test_splitter.py
new file mode 100644
index 000000000..8cf260bfd
--- /dev/null
+++ b/tests/unittests/splitter/test_splitter.py
@@ -0,0 +1,105 @@
+import json
+from pathlib import Path
+
+import nibabel as nib
+import numpy as np
+import pandas as pd
+import pytest
+from pydantic import ValidationError
+
+from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.dataset.preprocessing import PreprocessingT1, PreprocessingT2
+from clinicadl.splitter.split import Split
+from clinicadl.splitter.splitter.kfold import KFold, KFoldConfig
+from clinicadl.splitter.splitter.single_split import SingleSplit, SingleSplitConfig
+from clinicadl.splitter.splitter.splitter import (
+    Splitter,
+    SplitterConfig,
+    SubjectsSessionsSplit,
+)
+from clinicadl.transforms import Transforms
+from clinicadl.utils.enum import Preprocessing
+from clinicadl.utils.exceptions import (
+    ClinicaDLArgumentError,
+    ClinicaDLCAPSError,
+    ClinicaDLConfigurationError,
+    ClinicaDLTSVError,
+)
+
+caps_dir = Path(__file__).parents[1] / "ressources" / "caps_example"
+split_dir = caps_dir / "split_test" / "split"
+fold_path = split_dir / "2_fold"
+
+
+def test_single_splitter():
+    config = SingleSplitConfig(split_dir=split_dir)
+
+    assert config.subset_name == "test"
+    assert config.stratification is False
+    assert config.valid_longitudinal is False
+    assert np.isclose(config.p_categorical_threshold, 0.8, rtol=1e-09, atol=1e-09)
+    assert np.isclose(config.p_continuous_threshold, 0.8, rtol=1e-09, atol=1e-09)
+    assert config.json_name == "single_split_config.json"
+    assert config.n_test == 100
+
+    with pytest.raises(ValidationError):
+        SingleSplitConfig(split_dir=split_dir, p_categorical_threshold=12)
+
+
+def test_single_split():
+    splitter = SingleSplit(split_dir=split_dir)
+
+    with pytest.raises(ClinicaDLTSVError):
+        splitter.get_splits(
+            dataset=CapsDataset(caps_dir, PreprocessingT1(), Transforms())
+        )
+
+    with pytest.raises(FileNotFoundError):
+        splitter._read_split(Path("doesnt_exist"))
+
+    with pytest.raises(FileNotFoundError):
+        splitter._read_split(caps_dir / "test")
+
+    with pytest.raises(FileNotFoundError):
+        SingleSplit("doesnt_exist")
+
+
+def test_kfold_splitter():
+    config = KFoldConfig(split_dir=fold_path)
+
+    assert config.subset_name == "validation"
+    assert config.stratification is False
+    assert config.valid_longitudinal is False
+    assert config.json_name == "kfold_config.json"
+    assert config.n_splits == 5
+
+
+def test_kfold():
+    kfold = KFold(split_dir=fold_path)
+    config = kfold.config
+    assert config.subset_name == "validation"
+    assert config.stratification == "sex"
+    assert config.valid_longitudinal is False
+    assert config.json_name == "kfold_config.json"
+    assert config.n_splits == 2
+
+    assert isinstance(kfold.subjects_sessions_split[0], SubjectsSessionsSplit)
+
+    with pytest.raises(ClinicaDLTSVError):
+        splits = list(
+            kfold.get_splits(
+                dataset=CapsDataset(caps_dir, PreprocessingT1(), Transforms())
+            )
+        )
+
+    with pytest.raises(FileNotFoundError):
+        kfold._read_split(Path("doesnt_exist"))
+
+    with pytest.raises(FileNotFoundError):
+        kfold._read_split(caps_dir / "test")
+
+    with pytest.raises(FileNotFoundError):
+        KFold("doesnt_exist")
+
+    with pytest.raises(FileNotFoundError):
+        KFold(caps_dir / "test")
diff --git a/tests/unittests/train/trainer/test_training_config.py b/tests/unittests/train/trainer/test_training_config.py
index e0bdef815..07b07fd8f 100644
--- a/tests/unittests/train/trainer/test_training_config.py
+++ b/tests/unittests/train/trainer/test_training_config.py
@@ -158,8 +158,6 @@ def good_inputs(dummy_arguments):
 # assert c.transforms.size_reduction_factor == 5
 # assert c.split.split == (0,)
 # assert c.early_stopping.min_delta == 0.0
-
-
 # Test config manipulation #
 def test_assignment(dummy_arguments, training_config):
     c = training_config(**dummy_arguments)
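
As a usage reference, a minimal sketch of the splitting API that the new tests exercise. The TSV path below is hypothetical, and the call signatures are assumed from the tests above, not from the library documentation:

    from pathlib import Path

    from clinicadl.splitter.make_splits import make_kfold, make_split
    from clinicadl.splitter.splitter.kfold import KFold
    from clinicadl.splitter.splitter.single_split import SingleSplit

    tsv = Path("subjects_t1.tsv")  # hypothetical subjects/sessions TSV

    # Hold out 15 subjects, stratifying on the listed columns; the split TSVs
    # and single_split_config.json land in the returned directory.
    single_dir = make_split(tsv, n_test=15, stratification=["age", "sex", "diagnosis"])
    single = SingleSplit(split_dir=single_dir)

    # 2-fold cross-validation stratified on sex; by default the folds are
    # written next to the input TSV, in a "2_fold" directory.
    fold_dir = make_kfold(tsv, n_splits=2, stratification="sex")
    kfold = KFold(split_dir=fold_dir)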