From 621cc96cf58b2bf45084c72825a0026acc8f8bba Mon Sep 17 00:00:00 2001 From: Thibault de Varax <154365476+thibaultdvx@users.noreply.github.com> Date: Wed, 18 Dec 2024 13:14:47 +0100 Subject: [PATCH 1/3] First draft for KFold (#684) * make_split * make_kfold * KFold *SingleSplit --------- Co-authored-by: camillebrianceau Co-authored-by: camillebrianceau <57992134+camillebrianceau@users.noreply.github.com> --- clinicadl/API/complicated_case.py | 137 +++--- clinicadl/API/cross_val.py | 74 +-- .../commandline/modules_options/split.py | 9 +- clinicadl/dataset/dataloader/__init__.py | 1 + clinicadl/dataset/dataloader/config.py | 129 +++++ clinicadl/dataset/dataloader/defaults.py | 14 + clinicadl/dataset/datasets/caps_dataset.py | 49 +- clinicadl/dataset/utils.py | 56 ++- .../experiment_manager/experiment_manager.py | 9 +- clinicadl/experiment_manager/maps_manager.py | 6 - clinicadl/interpret/config.py | 2 +- clinicadl/model/clinicadl_model.py | 2 +- clinicadl/predictor/config.py | 2 +- clinicadl/predictor/old_predictor.py | 3 +- clinicadl/splitter/__init__.py | 3 + clinicadl/splitter/config.py | 71 --- clinicadl/splitter/kfold.py | 24 - clinicadl/splitter/make_splits/__init__.py | 2 + clinicadl/splitter/make_splits/kfold.py | 178 +++++++ .../splitter/make_splits/single_split.py | 439 ++++++++++++++++++ clinicadl/splitter/make_splits/utils.py | 78 ++++ clinicadl/splitter/old_splitter.py | 237 ---------- clinicadl/splitter/split.py | 169 ++++++- clinicadl/splitter/split_utils.py | 113 +++++ clinicadl/splitter/splitter/__init__.py | 2 + clinicadl/splitter/splitter/kfold.py | 108 +++++ clinicadl/splitter/splitter/single_split.py | 96 ++++ clinicadl/splitter/splitter/splitter.py | 256 ++++++++++ clinicadl/splitter/test.py | 168 +++++++ clinicadl/trainer/config/train.py | 4 +- clinicadl/trainer/old_trainer.py | 5 +- .../tsvtools/get_metadata/get_metadata.py | 2 + clinicadl/tsvtools/split/split.py | 29 -- clinicadl/tsvtools/tsvtools_utils.py | 3 +- tests/unittests/dataset/test_config.py | 2 +- .../split_test/split/2_fold/kfold_config.json | 8 + .../split_test/split/2_fold/split-0/train.tsv | 49 ++ .../split/2_fold/split-0/train_baseline.tsv | 49 ++ .../2_fold/split-0/validation_baseline.tsv | 49 ++ .../split_test/split/2_fold/split-1/train.tsv | 49 ++ .../split/2_fold/split-1/train_baseline.tsv | 49 ++ .../2_fold/split-1/validation_baseline.tsv | 49 ++ .../split_test/split/single_split_config.json | 15 + .../split/split_categorical_stats.tsv | 11 + .../split/split_continuous_stats.tsv | 5 + .../split_test/split/test_baseline.tsv | 25 + .../caps_example/split_test/split/train.tsv | 155 +++++++ .../split_test/split/train_baseline.tsv | 97 ++++ .../caps_example/subjects_false.tsv | 2 + .../caps_example/subjects_sessions_list.tsv | 8 - .../ressources/caps_example/subjects_t1.tsv | 65 +++ tests/unittests/splitter/test_make_split.py | 195 ++++++++ tests/unittests/splitter/test_splitter.py | 105 +++++ .../train/trainer/test_training_config.py | 2 - 54 files changed, 2923 insertions(+), 546 deletions(-) create mode 100644 clinicadl/dataset/dataloader/__init__.py create mode 100644 clinicadl/dataset/dataloader/config.py create mode 100644 clinicadl/dataset/dataloader/defaults.py delete mode 100644 clinicadl/splitter/config.py delete mode 100644 clinicadl/splitter/kfold.py create mode 100644 clinicadl/splitter/make_splits/__init__.py create mode 100644 clinicadl/splitter/make_splits/kfold.py create mode 100644 clinicadl/splitter/make_splits/single_split.py create mode 100644 
clinicadl/splitter/make_splits/utils.py delete mode 100644 clinicadl/splitter/old_splitter.py create mode 100644 clinicadl/splitter/splitter/__init__.py create mode 100644 clinicadl/splitter/splitter/kfold.py create mode 100644 clinicadl/splitter/splitter/single_split.py create mode 100644 clinicadl/splitter/splitter/splitter.py create mode 100644 clinicadl/splitter/test.py create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/kfold_config.json create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train.tsv create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train_baseline.tsv create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/validation_baseline.tsv create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train.tsv create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train_baseline.tsv create mode 100644 tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/validation_baseline.tsv create mode 100644 tests/unittests/ressources/caps_example/split_test/split/single_split_config.json create mode 100644 tests/unittests/ressources/caps_example/split_test/split/split_categorical_stats.tsv create mode 100644 tests/unittests/ressources/caps_example/split_test/split/split_continuous_stats.tsv create mode 100644 tests/unittests/ressources/caps_example/split_test/split/test_baseline.tsv create mode 100644 tests/unittests/ressources/caps_example/split_test/split/train.tsv create mode 100644 tests/unittests/ressources/caps_example/split_test/split/train_baseline.tsv create mode 100644 tests/unittests/ressources/caps_example/subjects_false.tsv create mode 100644 tests/unittests/ressources/caps_example/subjects_t1.tsv create mode 100644 tests/unittests/splitter/test_make_split.py create mode 100644 tests/unittests/splitter/test_splitter.py diff --git a/clinicadl/API/complicated_case.py b/clinicadl/API/complicated_case.py index aa126629b..4afe5050a 100644 --- a/clinicadl/API/complicated_case.py +++ b/clinicadl/API/complicated_case.py @@ -2,19 +2,15 @@ import torchio.transforms as transforms -from clinicadl.dataset.caps_dataset import ( - CapsDatasetPatch, - CapsDatasetRoi, - CapsDatasetSlice, -) -from clinicadl.dataset.caps_reader import CapsReader -from clinicadl.dataset.concat import ConcatDataset -from clinicadl.dataset.config.extraction import ExtractionConfig -from clinicadl.dataset.config.preprocessing import ( - PreprocessingConfig, - T1PreprocessingConfig, -) from clinicadl.dataset.dataloader_config import DataLoaderConfig +from clinicadl.dataset.datasets.caps_dataset import CapsDataset +from clinicadl.dataset.datasets.concat import ConcatDataset +from clinicadl.dataset.preprocessing import ( + PreprocessingCustom, + PreprocessingPET, + PreprocessingT1, +) +from clinicadl.dataset.readers.caps_reader import CapsReader from clinicadl.experiment_manager.experiment_manager import ExperimentManager from clinicadl.losses.config import CrossEntropyLossConfig from clinicadl.losses.factory import get_loss_function @@ -31,6 +27,7 @@ from clinicadl.splitter.kfold import KFolder from clinicadl.splitter.split import get_single_split, split_tsv from clinicadl.trainer.trainer import Trainer +from clinicadl.transforms.extraction import ROI, BaseExtraction, Image, Patch, Slice from clinicadl.transforms.transforms import Transforms # Create the Maps Manager / Read/write manager / @@ 
-40,60 +37,47 @@ ) # to add in the manager: mlflow/ profiler/ etc ... caps_directory = Path("caps_directory") # output of clinica pipelines -caps_reader = CapsReader(caps_directory, manager=manager) - -preprocessing_1 = caps_reader.get_preprocessing("t1-linear") -caps_reader.prepare_data( - preprocessing=preprocessing_1, data_tsv=Path(""), n_proc=2 -) # doesn't return anything -> just extracts the image tensor and computes some information for each image - -transforms_1 = Transforms( - object_augmentation=[transforms.Crop, transforms.Transform], - image_augmentation=[transforms.Crop, transforms.Transform], - extraction=ExtractionPatchConfig(patch_size=3), - image_transforms=[transforms.Blur, transforms.Ghosting], - object_transforms=[transforms.BiasField, transforms.Motion], -) # not mandatory - -preprocessing_2 = caps_reader.get_preprocessing("pet-linear") -caps_reader.prepare_data( - preprocessing=preprocessing_1, data_tsv=Path(""), n_proc=2 -) # to extract the tensor of the PET file this time -transforms_2 = Transforms( - object_augmentation=[transforms.Crop, transforms.Transform], - image_augmentation=[transforms.Crop, transforms.Transform], - extraction=ExtractionSliceConfig(), - image_transforms=[transforms.Blur, transforms.Ghosting], - object_transforms=[transforms.BiasField, transforms.Motion], + +sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv") +preprocessing_t1 = PreprocessingT1() +transforms_image = Transforms( + image_augmentation=[transforms.RandomMotion()], + extraction=Image(), + image_transforms=[transforms.Blur((0.5, 0.6, 0.3))], +) + +print("T1 and image ") + +dataset_t1_image = CapsDataset( + caps_directory=caps_directory, + data=sub_ses_t1, + preprocessing=preprocessing_t1, + transforms=transforms_image, +) +dataset_t1_image.prepare_data(n_proc=2) # to extract the tensor of the T1 file + + +sub_ses_pet_45 = Path( + "/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_pet_18FAV45.tsv" ) +preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2") -sub_ses_tsv_1 = Path("") -split_dir_1 = split_tsv(sub_ses_tsv_1) # -> creates a test.tsv and a train.tsv - -sub_ses_tsv_2 = Path("") -split_dir_2 = split_tsv(sub_ses_tsv_2) - -dataset_t1_roi = caps_reader.get_dataset( - preprocessing=preprocessing_1, - sub_ses_tsv=split_dir_1 / "train.tsv", - transforms=transforms_1, -) # do we give config or object for transforms ?
-dataset_pet_patch = caps_reader.get_dataset( - preprocessing=preprocessing_2, - sub_ses_tsv=split_dir_2 / "train.tsv", - transforms=transforms_2, +dataset_pet_image = CapsDataset( + caps_directory=caps_directory, + data=sub_ses_pet_45, + preprocessing=preprocessing_pet_45, + transforms=transforms_image, ) +dataset_pet_image.prepare_data(n_proc=2) # to extract the tensor of the PET file this time + -dataset_multi_modality_multi_extract = ConcatDataset( +dataset_multi_modality = ConcatDataset( [ - dataset_t1_roi, - dataset_pet_patch, - caps_reader.get_dataset_from_json(json_path=Path("dataset.json")), + dataset_t1_image, + dataset_pet_image, ] ) # 3 input train.tsv files must be concatenated; the same goes for the transforms, so be careful -# TODO : think about adding transforms in extract_json - config_file = Path("config_file") trainer = Trainer.from_json(config_file=config_file, manager=manager) @@ -106,6 +90,26 @@ dataloader_config = DataLoaderConfig(n_procs=3, batch_size=10) + + +# CASE 1 + +# Prerequisite: files with the train and validation lists already exist +split_dir = make_kfold( + "dataset.tsv" +) # reads dataset.tsv => performs the k-fold => writes the output to split_dir +splitter = KFolder( + dataset_multi_modality, split_dir +) # rather an iterable of dataloaders + +# CASE 2 +splitter = KFolder(caps_dataset=dataset_t1_image) +splitter.make_splits(n_splits=3) +splitter.write(split_dir) + +# or +splitter = KFolder(caps_dataset=dataset_t1_image) +splitter.read(split_dir) + for split in splitter.get_splits(splits=(0, 3, 4), dataloader_config=dataloader_config): # clearly define what the split object contains @@ -125,19 +129,12 @@ # TEST -preprocessing_test = caps_reader.get_preprocessing("pet-linear") -transforms_test = Transforms( - object_augmentation=[transforms.Crop, transforms.Transform], - image_augmentation=[transforms.Crop, transforms.Transform], - extraction=ExtractioImageConfig(), - image_transforms=[transforms.Blur, transforms.Ghosting], - object_transforms=[transforms.BiasField, transforms.Motion], -) -dataset_test = caps_reader.get_dataset( - preprocessing=preprocessing_test, - sub_ses_tsv=split_dir_1 / "test.tsv", # test only on data from the first dataset - transforms=transforms_test, +dataset_test = CapsDataset( + caps_directory=caps_directory, + preprocessing=preprocessing_t1, + sub_ses_tsv=Path("test.tsv"), # test only on data from the first dataset + transforms=transforms_image, ) predictor = Predictor(model=model, manager=manager) diff --git a/clinicadl/API/cross_val.py b/clinicadl/API/cross_val.py index 0efa2f195..54f7b9d6b 100644 --- a/clinicadl/API/cross_val.py +++ b/clinicadl/API/cross_val.py @@ -1,80 +1,38 @@ from pathlib import Path -import torchio.transforms as transforms - -from clinicadl.dataset.caps_dataset import ( - CapsDatasetPatch, - CapsDatasetRoi, - CapsDatasetSlice, -) -from clinicadl.dataset.caps_reader import CapsReader -from clinicadl.dataset.concat import ConcatDataset -from clinicadl.dataset.config.extraction import ExtractionConfig -from clinicadl.dataset.config.preprocessing import ( - PreprocessingConfig, - T1PreprocessingConfig, -) -from clinicadl.dataset.dataloader_config import DataLoaderConfig +from clinicadl.dataset.datasets.caps_dataset import CapsDataset from clinicadl.experiment_manager.experiment_manager import ExperimentManager -from clinicadl.losses.config import CrossEntropyLossConfig -from clinicadl.losses.factory import get_loss_function -from clinicadl.model.clinicadl_model import ClinicaDLModel -from clinicadl.networks.config import
ImplementedNetworks -from clinicadl.networks.factory import ( - ConvEncoderOptions, - create_network_config, - get_network_from_config, -) -from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig -from clinicadl.optimization.optimizer.factory import get_optimizer from clinicadl.predictor.predictor import Predictor -from clinicadl.splitter.kfold import KFolder -from clinicadl.splitter.split import get_single_split, split_tsv +from clinicadl.splitter.new_splitter.dataloader import DataLoaderConfig +from clinicadl.splitter.new_splitter.splitter.kfold import KFold from clinicadl.trainer.trainer import Trainer -from clinicadl.transforms.config import TransformsConfig # SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING maps_path = Path("/") manager = ExperimentManager(maps_path, overwrite=False) -dataset_t1_image = CapsDatasetPatch.from_json( - extraction=Path("test.json"), - sub_ses_tsv=Path("split_dir") / "train.tsv", -) +dataset_t1_image = CapsDataset.from_json(Path("json_path.json")) + config_file = Path("config_file") trainer = Trainer.from_json( config_file=config_file, manager=manager ) # gpu, amp, fsdp, seed -# CROSS-VALIDATION CASE -splitter = KFolder(caps_dataset=dataset_t1_image, manager=manager) -split_dir = splitter.make_splits( - n_splits=3, - output_dir=Path(""), - data_tsv=Path("labels.tsv"), - subset_name="validation", - stratification="", -) # Optional data tsv and output_dir -# n_splits must be >1 -# for the single split case, this method outputs a path to the directory containing the train and test tsv files, so we should have the same output here +splitter = KFold(dataset=dataset_t1_image) +splitter.make_splits(n_splits=3) +split_dir = Path("") +splitter.write(split_dir) -# EXISTING CROSS-VALIDATION CASE -splitter = KFolder.from_split_dir(caps_dataset=dataset_t1_image, manager=manager) +splitter.read(split_dir) # define the needed parameters for the dataloader -dataloader_config = DataLoaderConfig(n_procs=3, batch_size=10) +dataloader_config = DataLoaderConfig(num_workers=3, batch_size=10) -for split in splitter.get_splits(splits=(0, 3, 4), dataloader_config=dataloader_config): - # clearly define what the split object contains - network_config = create_network_config(ImplementedNetworks.CNN)( - in_shape=[2, 2, 2], - num_outputs=1, - conv_args=ConvEncoderOptions(channels=[3, 2, 2]), - ) - optimizer, _ = get_optimizer(network, AdamConfig()) - model = ClinicaDLModel(network=network_config, loss=nn.MSE(), optimizer=optimizer) +for split in splitter.get_splits(splits=(0, 3, 4)): + print(split) + split.build_train_loader(dataloader_config) + split.build_val_loader(num_workers=3, batch_size=10) - trainer.train(model, split) - # the trainer will instantiate a predictor/validator in train or in __init__ + print(split) diff --git a/clinicadl/commandline/modules_options/split.py b/clinicadl/commandline/modules_options/split.py index f7c0a8882..4579a6f7f 100644 --- a/clinicadl/commandline/modules_options/split.py +++ b/clinicadl/commandline/modules_options/split.py @@ -2,13 +2,14 @@ from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.splitter.config import SplitConfig +from clinicadl.splitter.splitter.kfold import KFoldConfig +from clinicadl.splitter.splitter.single_split import SingleSplitConfig # Cross Validation n_splits = click.option( "--n_splits", - type=get_type("n_splits", SplitConfig), - default=get_default("n_splits",
SplitConfig), + type=get_type("n_splits", KFoldConfig), + default=get_default("n_splits", KFoldConfig), help="If a value is given for k will load data of a k-fold CV. " "Default value (0) will load a single split.", show_default=True, @@ -17,7 +18,7 @@ "--split", "-s", type=int, # get_type("split", config.ValidationConfig), - default=get_default("split", SplitConfig), + default=get_default("split", SingleSplitConfig), multiple=True, help="Train the list of given splits. By default, all the splits are trained.", show_default=True, diff --git a/clinicadl/dataset/dataloader/__init__.py b/clinicadl/dataset/dataloader/__init__.py new file mode 100644 index 000000000..eda7a269a --- /dev/null +++ b/clinicadl/dataset/dataloader/__init__.py @@ -0,0 +1 @@ +from .config import DataLoaderConfig diff --git a/clinicadl/dataset/dataloader/config.py b/clinicadl/dataset/dataloader/config.py new file mode 100644 index 000000000..7c4acfb6d --- /dev/null +++ b/clinicadl/dataset/dataloader/config.py @@ -0,0 +1,129 @@ +from typing import Optional + +from pydantic import NonNegativeInt, PositiveInt +from torch.utils.data import DataLoader, DistributedSampler, Sampler +from torch.utils.data import WeightedRandomSampler as BaseWeightedRandomSampler + +from clinicadl.dataset.datasets.caps_dataset import CapsDataset +from clinicadl.utils.config import ClinicaDLConfig +from clinicadl.utils.seed import pl_worker_init_function + +from .defaults import ( + BATCH_SIZE, + DP_DEGREE, + DROP_LAST, + NUM_WORKERS, + PIN_MEMORY, + PREFETCH_FACTOR, + RANK, + SAMPLING_WEIGHTS, + SHUFFLE, +) + + +class WeightedRandomSampler(BaseWeightedRandomSampler): + """ + Modifies PyTorch's WeightedRandomSampler to have a similar behavior to + PyTorch's DistributedSampler. + """ + + def set_epoch(self, epoch: int) -> None: + """ + Fake method to simulate 'set_epoch' of PyTorch's DistributedSampler. + To be able to always call sampler.set_epoch(), no matter the sampler. + """ + + +class DataLoaderConfig(ClinicaDLConfig): + """Config class to parametrize a PyTorch DataLoader.""" + + batch_size: PositiveInt = BATCH_SIZE + sampling_weights: Optional[str] = SAMPLING_WEIGHTS + shuffle: bool = SHUFFLE + drop_last: bool = DROP_LAST + num_workers: NonNegativeInt = NUM_WORKERS + prefetch_factor: Optional[NonNegativeInt] = PREFETCH_FACTOR + pin_memory: bool = PIN_MEMORY + + def _generate_sampler( + self, + dataset: CapsDataset, + dp_degree: Optional[int] = DP_DEGREE, + rank: Optional[int] = RANK, + ) -> Sampler: + """ + Returns a WeightedRandomSampler if self.sampling_weights is not None, otherwise + a DistributedSampler, even when data parallelism is not performed (in this case + the degree of data parallelism is set to 1, so it is equivalent to a simple PyTorch + RandomSampler if self.shuffle is True or no sampler if self.shuffle is False). + """ + if (rank is not None and dp_degree is None) or ( + dp_degree is not None and rank is None + ): + raise ValueError( + "For data parallelism, neither 'dp_degree' nor 'rank' can be None. " + f"Got rank={rank} and dp_degree={dp_degree}" + ) + distributed = dp_degree is not None + + if self.sampling_weights: + try: + weights = dataset.df[self.sampling_weights].values.astype(float) + except KeyError as exc: + raise KeyError( + f"Got {self.sampling_weights} for 'sampling_weights' but there is no " + "column named like that in the dataframe of the dataset."
+ ) from exc + length = ( + len(weights) // dp_degree + int(rank < len(weights) % dp_degree) + if distributed + else len(weights) + ) + sampler = WeightedRandomSampler(weights, num_samples=length) # type: ignore + else: + if not distributed: + dp_degree = 1 + rank = 0 + sampler = DistributedSampler( + dataset, + num_replicas=dp_degree, + rank=rank, + shuffle=self.shuffle, + drop_last=False, # not the same as self.drop_last + ) + + return sampler + + def get_dataloader( + self, + dataset: CapsDataset, + dp_degree: Optional[int] = DP_DEGREE, + rank: Optional[int] = RANK, + ) -> DataLoader: + """ + To get a dataloader from a dataset. The dataloader is parametrized + with the options stored in this configuration class. + + Parameters + ---------- + dataset : CapsDataset + The dataset to put in a Dataloader. + dp_degree : Optional[int] (optional, default=None) + The degree of data parallelism. None if no data parallelism. + rank : Optional[int] (optional, default=None) + Process id within the data parallelism communicator. + None if no data parallelism. + + Returns + ------- + DataLoader + The dataloader that wraps the dataset. + """ + loader = DataLoader( + dataset=dataset, + sampler=self._generate_sampler(dataset, dp_degree, rank), + worker_init_fn=pl_worker_init_function, + **self.model_dump(exclude={"sampling_weights", "shuffle"}), # both are consumed by the sampler, not by the DataLoader + ) + + return loader diff --git a/clinicadl/dataset/dataloader/defaults.py b/clinicadl/dataset/dataloader/defaults.py new file mode 100644 index 000000000..ed493d9fb --- /dev/null +++ b/clinicadl/dataset/dataloader/defaults.py @@ -0,0 +1,14 @@ +BATCH_SIZE = 1 + +SAMPLING_WEIGHTS = None # no weighted sampling + +SHUFFLE = False +DROP_LAST = False + +NUM_WORKERS = 0 # main process loads data +PREFETCH_FACTOR = None + +PIN_MEMORY = True # training is supposed to be on a GPU + +DP_DEGREE = None # no data parallelism +RANK = None diff --git a/clinicadl/dataset/datasets/caps_dataset.py b/clinicadl/dataset/datasets/caps_dataset.py index 95535d73f..5a877c8d2 100644 --- a/clinicadl/dataset/datasets/caps_dataset.py +++ b/clinicadl/dataset/datasets/caps_dataset.py @@ -1,5 +1,7 @@ # coding: utf8 +from __future__ import annotations +from copy import deepcopy from logging import getLogger from pathlib import Path from typing import List, Optional, Tuple, Union @@ -12,6 +14,7 @@ from torch import save as save_tensor from torch.utils.data import Dataset from tqdm import tqdm +from typing_extensions import Self from clinicadl.dataset.preprocessing import BasePreprocessing from clinicadl.dataset.readers.caps_reader import CapsReader @@ -190,7 +193,18 @@ def _get_df_from_input( ) logger.info(f"Creating a subject session TSV file at {data}") - elif isinstance(data, str): + df = self._check_data_instance(data) + self.df = df + + if not self._check_preprocessing_config(): + raise ClinicaDLCAPSError( + f"The DataFrame does not match the preprocessing configuration: {self.preprocessing.preprocessing.value}" + ) + + return df + + def _check_data_instance(self, data: Optional[Union[pd.DataFrame, Path]] = None): + if isinstance(data, str): data = Path(data) if isinstance(data, Path): @@ -200,15 +214,9 @@ def _get_df_from_input( "Please ensure the file path is correct and accessible."
) df = tsv_to_df(data) - elif isinstance(data, pd.DataFrame): + if isinstance(data, pd.DataFrame): df = check_df(data) - self.df = df - if not self._check_preprocessing_config(): - raise ClinicaDLCAPSError( - f"The DataFrame does not match the preprocessing configuration: {self.preprocessing.preprocessing.value}" - ) - return df def _check_preprocessing_config(self) -> bool: @@ -535,3 +543,28 @@ def prepare_image(participant, session): self._get_participants_sessions_couple(), desc="Preparing data" ) ) + + def subset(self, data: Optional[Union[pd.DataFrame, Path]] = None) -> CapsDataset: + df = self._check_data_instance(data) + + common_rows = pd.merge(df, self.df, how="inner") + all_included = len(common_rows) == len(df) + + if not all_included: + missing_rows = pd.concat( + [df, common_rows], ignore_index=True + ).drop_duplicates(keep=False) + + err_message = "Missing rows: \n" + for row in missing_rows.itertuples(index=False): + err_message += f" - {row} \n" + + raise ClinicaDLTSVError( + "Some (participant_id, session_id) pairs are not in the dataset. ", + err_message, + ) + + dataset = deepcopy(self) + dataset.df = df + + return dataset diff --git a/clinicadl/dataset/utils.py b/clinicadl/dataset/utils.py index 6316b00bd..a1f2c1889 100644 --- a/clinicadl/dataset/utils.py +++ b/clinicadl/dataset/utils.py @@ -6,6 +6,7 @@ import pandas as pd import torch +import torchio as tio from pydantic import BaseModel, ConfigDict from clinicadl.dataset import preprocessing @@ -51,6 +52,34 @@ class CapsDatasetSample(BaseModel): model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) +def df_to_tsv(name: str, results_path: Path, df, baseline: bool = False) -> None: + """ + Write a DataFrame to a TSV file, dropping duplicates + + Parameters + ---------- + name: str + Name of the tsv file + results_path: Path + Path to the output folder + df: DataFrame + DataFrame you want to write in a TSV file. + Columns must include ["participant_id", "session_id"]. + baseline: bool + If True, keep only the baseline session of each subject. + """ + + df.sort_values(by=["participant_id", "session_id"], inplace=True) + if baseline: + df.drop_duplicates(subset=["participant_id"], keep="first", inplace=True) + else: + df.drop_duplicates( + subset=["participant_id", "session_id"], keep="first", inplace=True ) + # df = df[["participant_id", "session_id"]] + df.to_csv(results_path / name, sep="\t", index=False) + + def tsv_to_df(tsv_path: Path) -> pd.DataFrame: """ Converts a TSV file to a Pandas DataFrame. @@ -94,7 +123,32 @@ def check_df(df: pd.DataFrame) -> pd.DataFrame: f"The data file is not in the correct format. " f"Columns should include {PARTICIPANT_ID, SESSION_ID}" ) - df.reset_index(inplace=True) + + return df + + +def reset_index(df: pd.DataFrame) -> pd.DataFrame: + """ + Resets the index of a DataFrame to the default index, dropping any existing index. + + Args: + df (pd.DataFrame): The DataFrame to be reset. + + Returns: + pd.DataFrame: The DataFrame with the default index. + + Note: + The index is always reset. The old index is dropped only when the DataFrame has a MultiIndex whose names are not 'participant_id' and 'session_id'; + otherwise the old index is re-inserted as columns.
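+
+    Example:
+        Editor's illustration on a hypothetical frame::
+
+            df = df.set_index(["participant_id", "session_id"])
+            df = reset_index(df)  # both index levels become columns again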
+ """ + + drop = False + if isinstance(df.index, pd.MultiIndex): + if set(df.index.names) != {PARTICIPANT_ID, SESSION_ID}: + drop = True + + df.reset_index(inplace=True, drop=drop) + return df diff --git a/clinicadl/experiment_manager/experiment_manager.py b/clinicadl/experiment_manager/experiment_manager.py index 6edad60d6..dea504e6e 100644 --- a/clinicadl/experiment_manager/experiment_manager.py +++ b/clinicadl/experiment_manager/experiment_manager.py @@ -8,15 +8,14 @@ import pandas as pd from pydantic import BaseModel -from clinicadl.dataset.caps_reader import CapsReader -from clinicadl.dataset.config.extraction import ExtractionConfig -from clinicadl.dataset.preprocessing import PreprocessingConfig +from clinicadl.dataset.preprocessing import BasePreprocessing +from clinicadl.dataset.readers import CapsReader from clinicadl.metrics.old_metrics.utils import check_selection_metric from clinicadl.model.clinicadl_model import ClinicaDLModel from clinicadl.networks.config import NetworkConfig from clinicadl.networks.factory import get_network_from_config -from clinicadl.splitter.kfold import KFolder from clinicadl.splitter.split_utils import print_description_log +from clinicadl.transforms.extraction import Extraction from clinicadl.utils.exceptions import MAPSError from clinicadl.utils.iotools.data_utils import load_data_test from clinicadl.utils.iotools.utils import path_decoder, path_encoder @@ -71,7 +70,7 @@ def information_log(self) -> Path: def get_info_from_json( self, - ) -> tuple[PreprocessingConfig, ExtractionConfig, CapsReader, ClinicaDLModel]: + ) -> tuple[PreprocessingConfig, Extraction, CapsReader, ClinicaDLModel]: """Reads the maps.json file and returns its content.""" # I don't know if this is a useful function if self.maps_json.is_file(): diff --git a/clinicadl/experiment_manager/maps_manager.py b/clinicadl/experiment_manager/maps_manager.py index c8c9a8eb4..0a3903136 100644 --- a/clinicadl/experiment_manager/maps_manager.py +++ b/clinicadl/experiment_manager/maps_manager.py @@ -9,17 +9,11 @@ import pandas as pd import torch -from clinicadl.dataset.caps_dataset import ( - return_dataset, -) -from clinicadl.dataset.caps_dataset.caps_dataset_utils import read_json from clinicadl.metrics.old_metrics.metric_module import MetricModule from clinicadl.metrics.old_metrics.utils import ( check_selection_metric, ) from clinicadl.predictor.utils import get_prediction -from clinicadl.splitter.config import SplitterConfig -from clinicadl.splitter.old_splitter import Splitter from clinicadl.trainer.tasks_utils import ( ensemble_prediction, evaluation_metrics, diff --git a/clinicadl/interpret/config.py b/clinicadl/interpret/config.py index d5cf76390..5a1ebeaf3 100644 --- a/clinicadl/interpret/config.py +++ b/clinicadl/interpret/config.py @@ -11,7 +11,7 @@ from clinicadl.experiment_manager.maps_manager import MapsManager from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp from clinicadl.predictor.validation import ValidationConfig -from clinicadl.splitter.config import SplitConfig +from clinicadl.splitter.splitter.splitter import SplitterConfig as SplitConfig from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.enum import InterpretationMethod diff --git a/clinicadl/model/clinicadl_model.py b/clinicadl/model/clinicadl_model.py index 8785ed97b..27adad432 100644 --- a/clinicadl/model/clinicadl_model.py +++ b/clinicadl/model/clinicadl_model.py @@ -3,6 +3,6 @@ class 
ClinicaDLModel: - def __init__(self, network: nn.Module, loss: nn.Module, optimizer=optim.optimizer): + def __init__(self, network: nn.Module, loss: nn.Module, optimizer): """TO COMPLETE""" pass diff --git a/clinicadl/predictor/config.py b/clinicadl/predictor/config.py index eaa3a8653..2c793045c 100644 --- a/clinicadl/predictor/config.py +++ b/clinicadl/predictor/config.py @@ -9,7 +9,7 @@ ) from clinicadl.experiment_manager.maps_manager import MapsManager from clinicadl.predictor.validation import ValidationConfig -from clinicadl.splitter.config import SplitConfig +from clinicadl.splitter.splitter.splitter import SplitterConfig as SplitConfig from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.enum import Task diff --git a/clinicadl/predictor/old_predictor.py b/clinicadl/predictor/old_predictor.py index 8314ce9d9..96d4764a5 100644 --- a/clinicadl/predictor/old_predictor.py +++ b/clinicadl/predictor/old_predictor.py @@ -48,8 +48,7 @@ class Predictor: def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None: self._config = _config - from clinicadl.splitter.config import SplitterConfig - from clinicadl.splitter.old_splitter import Splitter + from clinicadl.splitter.splitter.splitter import Splitter, SplitterConfig self.maps_manager = MapsManager(_config.maps_manager.maps_dir) self._config.adapt_with_maps_manager_info(self.maps_manager) diff --git a/clinicadl/splitter/__init__.py b/clinicadl/splitter/__init__.py index e69de29bb..40a3cf136 100644 --- a/clinicadl/splitter/__init__.py +++ b/clinicadl/splitter/__init__.py @@ -0,0 +1,3 @@ +from .make_splits import make_kfold, make_split +from .split import Split +from .splitter import KFold, SingleSplit diff --git a/clinicadl/splitter/config.py b/clinicadl/splitter/config.py deleted file mode 100644 index 050b3c6c1..000000000 --- a/clinicadl/splitter/config.py +++ /dev/null @@ -1,71 +0,0 @@ -from abc import ABC, abstractmethod -from logging import getLogger -from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -from pydantic import BaseModel, ConfigDict, field_validator -from pydantic.types import NonNegativeInt - -from clinicadl.dataset.config.data import DataConfig -from clinicadl.predictor.validation import ValidationConfig -from clinicadl.splitter.split_utils import find_splits - -logger = getLogger("clinicadl.split_config") - - -class SplitConfig(BaseModel): - """ - Abstract config class for the validation procedure. - - selection_metrics is specific to the task, thus it needs - to be specified in a subclass. - """ - - n_splits: NonNegativeInt = 0 - split: Optional[Tuple[NonNegativeInt, ...]] = None - tsv_path: Optional[Path] = None # not needed in interpret ! - - # pydantic config - model_config = ConfigDict(validate_assignment=True) - - @field_validator("split", mode="before") - def validator_split(cls, v): - if isinstance(v, list): - return tuple(v) - return v # TODO : check that split exists (and check coherence with n_splits) - - def adapt_cross_val_with_maps_manager_info( - self, maps_manager - ): # maps_manager is of type MapsManager but need to be in a MapsConfig type in the future - # TEMPORARY - if not self.split: - self.split = tuple(find_splits(maps_manager.maps_path)) - logger.debug(f"List of splits {self.split}") - - -class SplitterConfig(BaseModel, ABC): - """ - - Abstract config class for the training pipeline. - Some configurations are specific to the task (e.g. 
loss function), - thus they need to be specified in a subclass. - """ - - data: DataConfig - split: SplitConfig - validation: ValidationConfig - # pydantic config - model_config = ConfigDict(validate_assignment=True) - - def __init__(self, **kwargs): - super().__init__( - data=kwargs, - split=kwargs, - validation=kwargs, - ) - - def _update(self, config_dict: Dict[str, Any]) -> None: - """Updates the configs with a dict given by the user.""" - self.data.__dict__.update(config_dict) - self.split.__dict__.update(config_dict) - self.validation.__dict__.update(config_dict) diff --git a/clinicadl/splitter/kfold.py b/clinicadl/splitter/kfold.py deleted file mode 100644 index 37805dfba..000000000 --- a/clinicadl/splitter/kfold.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Optional - -from clinicadl.dataset.caps_dataset import CapsDataset -from clinicadl.experiment_manager.experiment_manager import ExperimentManager - - -class Split: - def __init__( - self, - ): - """TO COMPLETE""" - pass - - -class KFolder: - def __init__( - self, n_splits: int, caps_dataset: CapsDataset, manager: ExperimentManager - ) -> None: - """TO COMPLETE""" - - def split_iterator(self, split_list: Optional[list] = None) -> list[Split]: - """TO COMPLETE""" - - return list[Split()] diff --git a/clinicadl/splitter/make_splits/__init__.py b/clinicadl/splitter/make_splits/__init__.py new file mode 100644 index 000000000..00578866b --- /dev/null +++ b/clinicadl/splitter/make_splits/__init__.py @@ -0,0 +1,2 @@ +from .kfold import make_kfold +from .single_split import make_split diff --git a/clinicadl/splitter/make_splits/kfold.py b/clinicadl/splitter/make_splits/kfold.py new file mode 100644 index 000000000..ed9dd41e2 --- /dev/null +++ b/clinicadl/splitter/make_splits/kfold.py @@ -0,0 +1,178 @@ +from pathlib import Path +from typing import List, Optional, Union + +import numpy as np +import pandas as pd +from pydantic import PositiveInt +from sklearn.model_selection import KFold, StratifiedKFold + +from clinicadl.dataset.utils import tsv_to_df +from clinicadl.splitter.make_splits.utils import write_to_csv +from clinicadl.splitter.splitter.kfold import KFoldConfig +from clinicadl.tsvtools.tsvtools_utils import extract_baseline +from clinicadl.utils.exceptions import ClinicaDLConfigurationError + + +def _validate_stratification( + df: pd.DataFrame, + stratification: Union[str, bool], +) -> Optional[str]: + """ + Validates and checks the stratification columns. + + Parameters + ---------- + df : pd.DataFrame + Input dataset. + stratification : Union[str, bool] + Column to use for stratification. If True, column is 'sex', if False, there is no stratification. + + Returns + ------- + Optional[str] + Validated stratification column or None if no stratification is applied. + + Raises + ------ + ClinicaDLConfigurationError + If invalid or conflicting stratification options are provided. + """ + if isinstance(stratification, bool): + if stratification: + stratification = "sex" + else: + return None + + if isinstance(stratification, List): + if len(stratification) > 1: + raise ClinicaDLConfigurationError( + "Stratification can only be performed on a single column for K-Fold splitting." + ) + else: + stratification = stratification[0] + + if isinstance(stratification, str): + if stratification not in df.columns: + raise ClinicaDLConfigurationError( + f"Stratification column '{stratification}' not found in the dataset." 
+ ) + + if pd.api.types.is_numeric_dtype(df[stratification]) and df[ + stratification + ].nunique() >= (len(df) / 2): + raise ValueError( + "Continuous variables cannot be used for stratification in K-Fold splitting." + ) + return stratification + + raise ClinicaDLConfigurationError( + "Invalid or conflicting stratification options provided. Stratification must be a single column name or boolean." + ) + + +def preprocess_stratification( + df: pd.DataFrame, + stratification: Union[str, bool], +) -> pd.DataFrame: + """ + Preprocess stratification columns by creating labels for each subject. + + Parameters + ---------- + df : pd.DataFrame + Input dataset. + stratification : Union[str, bool] + Column to use for stratification. If True, column is 'sex', if False, there is no stratification. + + Returns + ------- + pd.DataFrame + Single-column DataFrame with the stratification labels, or the input DataFrame unchanged if no stratification is applied. + """ + column = _validate_stratification(df, stratification) + + if column is None: + return df + + return df[[column]] + + +def make_kfold( + tsv_path: Path, + output_dir: Optional[Union[Path, str]] = None, + subset_name: str = "validation", + valid_longitudinal: bool = False, + n_splits: PositiveInt = 5, + stratification: Union[str, bool] = False, +) -> Path: + """ + Perform K-Fold splitting with optional stratification. + + Parameters + ---------- + tsv_path : Path + Path to the input TSV file. + output_dir : Optional[Path] + Directory to save the split files. Defaults to the parent directory of `tsv_path`. + subset_name : str, default="validation" + Name of the subset used for output files. + valid_longitudinal : bool, default=False + Whether to include longitudinal sessions in the split. + n_splits : PositiveInt, default=5 + Number of splits for K-Fold. + stratification : Union[str, bool], default=False + Column to use for stratification. If True, column is 'sex', if False, there is no stratification. + + Returns + ------- + Path + Directory containing the generated split files. + + Raises + ------ + ClinicaDLConfigurationError + If invalid configuration options are provided.
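+
+    Examples
+    --------
+    A minimal sketch (editor's illustration; assumes 'labels.tsv' has
+    'participant_id', 'session_id' and 'sex' columns)::
+
+        split_dir = make_kfold(Path("labels.tsv"), n_splits=5, stratification="sex")
+        # split_dir now contains split-0/ ... split-4/, each with train.tsv,
+        # train_baseline.tsv and validation_baseline.tsv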
+ """ + + # Set default output directory + output_dir = output_dir or tsv_path.parent + output_dir = Path(output_dir) + + # Initialize KFold configuration + config = KFoldConfig( + split_dir=output_dir, + subset_name=subset_name, + valid_longitudinal=valid_longitudinal, + n_splits=n_splits, + stratification=stratification, + ) + + config._check_split_dir() + config._write_json() + + # Load and process dataset + df = tsv_to_df(tsv_path) + baseline_df = extract_baseline(df) + + stratify_labels = preprocess_stratification( + df=baseline_df, + stratification=config.stratification, + ) + + # Create K-Fold splits + if config.stratification: + skf = StratifiedKFold(n_splits=config.n_splits, shuffle=True, random_state=2) + else: + skf = KFold(n_splits=config.n_splits, shuffle=True, random_state=2) + + for i, (train_idx, test_idx) in enumerate(skf.split(baseline_df, stratify_labels)): + train = baseline_df.iloc[train_idx] + test = baseline_df.iloc[test_idx] + + split_dir = config.split_dir / f"split-{i}" + split_dir.mkdir(parents=True, exist_ok=True) + + write_to_csv(test, split_dir, df, config.subset_name, config.valid_longitudinal) + write_to_csv(train, split_dir, df) + + return config.split_dir diff --git a/clinicadl/splitter/make_splits/single_split.py b/clinicadl/splitter/make_splits/single_split.py new file mode 100644 index 000000000..1854b5449 --- /dev/null +++ b/clinicadl/splitter/make_splits/single_split.py @@ -0,0 +1,439 @@ +from logging import getLogger +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +from pydantic import PositiveFloat +from scipy.stats import chisquare, ks_2samp, ttest_ind +from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit + +from clinicadl.dataset.utils import tsv_to_df +from clinicadl.splitter.make_splits.utils import write_to_csv +from clinicadl.splitter.splitter.single_split import SingleSplitConfig +from clinicadl.tsvtools.tsvtools_utils import extract_baseline +from clinicadl.utils.exceptions import ClinicaDLConfigurationError, ClinicaDLTSVError + +logger = getLogger("clinicadl.splitter.single_split") + + +def _validate_stratification( + df: pd.DataFrame, + stratification: Union[List[str], bool], +) -> List[str]: + """ + Checks and validates the specified stratification columns. + + Parameters + ---------- + df : pd.DataFrame + Input dataset. + stratification : Union[List[str], bool] + Columns to use for stratification. If True, columns are 'age' and 'sex', if False, there is no stratification. + + Returns + ------- + List[str], optional + Validated list of stratification columns or None if no stratification is applied. + + Raises + ------ + ValueError + If specified stratification columns are missing or if stratification conflicts with demographic handling. + ClinicaDLTSVError + If required demographic columns ('age', 'sex') are missing when not ignored. + """ + + if isinstance(stratification, bool): + if stratification: + stratification = ["age", "sex"] + else: + return [] + + if isinstance(stratification, list): + if not set(stratification).issubset(df.columns): + raise ValueError( + f"Invalid stratification columns: {set(stratification) - set(df.columns)}" + ) + return stratification + + raise ValueError( + "Invalid stratification option. Stratification must be a list of column names or a boolean." 
+ ) + + +def _categorize_labels( + df: pd.DataFrame, + stratification: Union[List[str], bool], + n_test: int = 100, +) -> Tuple[List[str], List[str]]: + """ + Categorize stratification columns into continuous and categorical labels. + + Parameters + ---------- + df : pd.DataFrame + Input dataset. + stratification : Union[List[str], bool] + Columns to use for stratification. If True, columns are 'age' and 'sex', if False, there is no stratification. + n_test : int + Number of test samples. + + Returns + ------- + Tuple[List[str], List[str]] + Continuous and categorical labels. + """ + columns = _validate_stratification(df, stratification) + + continuous_labels, categorical_labels = [], [] + for col in columns: + if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() >= (n_test / 2): + continuous_labels.append(col) + else: + categorical_labels.append(col) + return continuous_labels, categorical_labels + + +def _chi2_test(x_test: List[int], x_train: List[int]) -> float: + """ + Perform the Chi-squared test on categorical data. + + Parameters + ---------- + x_test : np.ndarray + Test data. + x_train : np.ndarray + Train data. + + Returns + ------- + float + p-value from the Chi-squared test. + """ + unique_categories = np.unique(np.concatenate([x_test, x_train])) + + # Calculate observed (test) and expected (train) frequencies as raw counts + f_obs = np.array([(x_test == category).sum() for category in unique_categories]) + f_exp = np.array( + [ + (x_train == category).sum() / len(x_train) * len(x_test) + for category in unique_categories + ] + ) + + _, p_value = chisquare(f_obs, f_exp) + + return p_value + + +def make_split( + tsv_path: Path, + output_dir: Optional[Union[Path, str]] = None, + n_test: PositiveFloat = 100, + subset_name: str = "test", + p_categorical_threshold: float = 0.50, + p_continuous_threshold: float = 0.50, + stratification: Union[List[str], bool] = False, + valid_longitudinal=False, + n_try_max: int = 1000, +): + """ + Perform a single train-test split of the dataset with stratification. + + Parameters + ---------- + tsv_path : Path + Path to the input TSV file. + output_dir : Optional[Path] + Directory to save the split files. + n_test : PositiveFloat + If >= 1, specifies the absolute number of test samples. If < 1, treated as a proportion of the dataset. + subset_name : str + Name for the test subset. + p_categorical_threshold : float + Threshold for acceptable categorical stratification. + p_continuous_threshold : float + Threshold for acceptable continuous stratification. + stratification : Union[List[str], bool], default=False + Columns to use for stratification. If True, columns are 'age' and 'sex', if False, there is no stratification. + valid_longitudinal : bool + Include longitudinal sessions if True. + n_try_max : int + Maximum number of attempts to find a valid split. + + Returns + ------- + Path + Directory containing the split files. 
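+
+    Examples
+    --------
+    A minimal sketch (editor's illustration; assumes 'labels.tsv' has
+    'age' and 'sex' columns)::
+
+        split_dir = make_split(
+            Path("labels.tsv"), n_test=0.2, stratification=["age", "sex"]
+        )
+        # split_dir now contains train.tsv, train_baseline.tsv and
+        # test_baseline.tsv, plus split_continuous_stats.tsv and
+        # split_categorical_stats.tsv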
+ """ + + # Set default output directory + output_dir = output_dir or tsv_path.parent + output_dir = Path(output_dir) + + # Load dataset and preprocess + df = tsv_to_df(tsv_path) + baseline_df = extract_baseline(df) + + n_test = int(n_test) if n_test >= 1 else int(n_test * len(baseline_df)) + + continuous_labels, categorical_labels = _categorize_labels( + df=baseline_df, + stratification=stratification, + n_test=n_test, + ) + + # Initialize SingleSplit configuration + config = SingleSplitConfig( + split_dir=output_dir, + subset_name=subset_name, + valid_longitudinal=valid_longitudinal, + n_test=n_test, + p_continuous_threshold=p_continuous_threshold, + p_categorical_threshold=p_categorical_threshold, + stratification=stratification, + ) + + config._check_split_dir() + config._write_json() + + if config.n_test > 0: + splits = ShuffleSplit( + n_splits=n_try_max, test_size=config.n_test, random_state=2 + ) + for n_try, (train_index, test_index) in enumerate( + splits.split(baseline_df, baseline_df) + ): + p_continuous = compute_continuous_p_value( + continuous_labels, + baseline_df, + train_index.tolist(), + test_index.tolist(), + ) + + if p_continuous >= p_continuous_threshold: + p_categorical = compute_categorical_p_value( + categorical_labels, + baseline_df, + train_index.tolist(), + test_index.tolist(), + ) + + if p_categorical >= p_categorical_threshold: + logger.info(f"Valid split found after {n_try} attempts.") + + test_df = baseline_df.loc[test_index] + train_df = baseline_df.loc[train_index] + + write_continuous_stats( + config.split_dir / "split_continuous_stats.tsv", + continuous_labels, + test_df, + train_df, + subset_name, + ) + write_categorical_stats( + config.split_dir / "split_categorical_stats.tsv", + categorical_labels, + test_df, + train_df, + baseline_df, + subset_name, + ) + break + + if n_try >= n_try_max - 1: + raise ClinicaDLConfigurationError( + f"Unable to find a valid split after {n_try} attempts. " + f"Consider lowering thresholds or reducing stratification variables." + ) + + write_to_csv(test_df, config.split_dir, df, subset_name, valid_longitudinal) + else: + train_df = baseline_df + + write_to_csv(train_df, config.split_dir, df) + + return config.split_dir + + +def compute_continuous_p_value( + continuous_labels: Optional[list[str]], + baseline_df: pd.DataFrame, + train_index: list[int], + test_index: list[int], +) -> float: + """ + Compute the minimum p-value for continuous variables between train and test splits. + + Parameters + ---------- + continuous_labels : Optional[List[str]] + List of continuous variable names. + baseline_df : pd.DataFrame + Dataframe containing the baseline data. + train_index : List[int] + Indices for the training set. + test_index : List[int] + Indices for the testing set. + + Returns + ------- + float + The minimum p-value across all continuous labels. 
+ """ + + p_continuous = 1.0 + if continuous_labels: + for label in continuous_labels: + if len(baseline_df[label] != 1): + train_values = baseline_df[label].loc[train_index].values.tolist() + test_values = baseline_df[label].loc[test_index].values.tolist() + + _, new_p_continuous = ttest_ind( + test_values, train_values, nan_policy="omit" + ) # ks_2samp, or ttost_ind from statsmodels.stats.weightstats import ttost_ind + + # Track the minimum p-value + p_continuous = min(p_continuous, new_p_continuous) + + return p_continuous + + +def compute_categorical_p_value( + categorical_labels: Optional[list[str]], + baseline_df: pd.DataFrame, + train_index: list[int], + test_index: list[int], +) -> float: + """ + Compute the minimum p-value for categorical variables between train and test splits. + + Parameters + ---------- + categorical_labels : Optional[List[str]] + List of categorical variable names. + baseline_df : pd.DataFrame + Dataframe containing the baseline data. + train_index : List[int] + Indices for the training set. + test_index : List[int] + + Returns + ------- + float + The minimum p-value across all categorical labels. + """ + + p_categorical = 1 + if categorical_labels: + for label in categorical_labels: + if len(baseline_df[label] != 1): + mapping = { + val: i for i, val in enumerate(np.unique(baseline_df[label])) + } + + tmp_train_values = baseline_df[label].loc[train_index].values.tolist() + tmp_test_values = baseline_df[label].loc[test_index].values.tolist() + + train_values = [mapping[val] for val in tmp_train_values] + test_values = [mapping[val] for val in tmp_test_values] + + new_p_categorical = _chi2_test(test_values, train_values) + + # Track the minimum p-value + p_categorical = min(p_categorical, new_p_categorical) + + return p_categorical + + +def write_continuous_stats( + tsv_path: Path, + continuous_labels: Optional[list[str]], + test_df: pd.DataFrame, + train_df: pd.DataFrame, + subset_name: str, +): + """ + Write continuous statistics (mean, std) to a TSV file. + + Parameters + ---------- + tsv_path : Path + Path to save the output TSV file. + continuous_labels : Optional[List[str]] + List of continuous variable names. + test_df : pd.DataFrame + Test dataset. + train_df : pd.DataFrame + Train dataset. + subset_name : str + Name of the test subset. + """ + + if not continuous_labels: + return + + data = [ + (label, "mean", train_df[label].mean(), test_df[label].mean()) + for label in continuous_labels + ] + [ + (label, "std", train_df[label].std(), test_df[label].std()) + for label in continuous_labels + ] + + df_stats_continuous = pd.DataFrame( + data, columns=["label", "statistic", "train", subset_name] + ) + df_stats_continuous.to_csv(tsv_path, sep="\t", index=False) + + +def write_categorical_stats( + tsv_path: Path, + categorical_labels: Optional[list[str]], + test_df: pd.DataFrame, + train_df: pd.DataFrame, + baseline_df: pd.DataFrame, + subset_name: str, +): + """ + Write categorical statistics (proportion, count) to a TSV file. + + Parameters + ---------- + tsv_path : Path + Path to save the output TSV file. + categorical_labels : Optional[List[str]] + List of categorical variable names. + test_df : pd.DataFrame + Test dataset. + train_df : pd.DataFrame + Train dataset. + baseline_df : pd.DataFrame + Baseline dataset (reference for all unique values). 
+ subset_name : str + Name of the test subset. + + """ + + if not categorical_labels: + return + + data = [] + for label in categorical_labels: + unique_values = baseline_df[label].unique() + for val in unique_values: + test_count = (test_df[label] == val).sum() + train_count = (train_df[label] == val).sum() + + test_proportion = test_count / len(test_df) + train_proportion = train_count / len(train_df) + + data.append((label, val, "proportion", train_proportion, test_proportion)) + data.append((label, val, "count", train_count, test_count)) + + df_stats_categorical = pd.DataFrame( + data, columns=["label", "value", "statistic", "train", subset_name] + ) + df_stats_categorical.to_csv(tsv_path, sep="\t", index=False) diff --git a/clinicadl/splitter/make_splits/utils.py b/clinicadl/splitter/make_splits/utils.py new file mode 100644 index 000000000..687066790 --- /dev/null +++ b/clinicadl/splitter/make_splits/utils.py @@ -0,0 +1,78 @@ +from pathlib import Path +from typing import Optional + +import pandas as pd + +from clinicadl.tsvtools.tsvtools_utils import retrieve_longitudinal + + +def _write_to_csv(df: pd.DataFrame, file_path: Path) -> None: + """ + Save a DataFrame to a TSV file, ensuring the file does not already exist. + + Parameters + ---------- + df : pd.DataFrame + DataFrame to save. + file_path : Path + Path to the destination TSV file. + + Raises + ------ + FileExistsError + If the file already exists at the specified path. + """ + if file_path.exists(): + raise FileExistsError( + f"File {file_path} already exists. Operation aborted to prevent overwriting." + ) + + # Reset index for consistency and save as a TSV file + df.reset_index(drop=True, inplace=True) + df.to_csv(file_path, sep="\t", index=False) + + +def write_to_csv( + df: pd.DataFrame, + split_dir: Path, + all_df: Optional[pd.DataFrame] = None, + subset_name: str = "train", + longitudinal: bool = True, +) -> None: + """ + Save baseline and longitudinal splits of a DataFrame to TSV files. + + Parameters + ---------- + df : pd.DataFrame + DataFrame containing the subset (e.g., train/test/validation) to save. + split_dir : Path + Directory where the TSV files will be saved. + all_df : Optional[pd.DataFrame], optional + Full dataset including all sessions, used to retrieve longitudinal data. + subset_name : str, default="train" + Name of the subset (e.g., "train", "test", etc.) used in the output filenames. + longitudinal : bool, default=True + Whether to generate and save the longitudinal data subset. + + Raises + ------ + FileExistsError + If any of the output files already exist in the specified directory. + ValueError + If `longitudinal` is True but `all_df` is None, as longitudinal data cannot be generated. + """ + # Save the baseline data + baseline_file = split_dir / f"{subset_name}_baseline.tsv" + _write_to_csv(df, baseline_file) + + if longitudinal: + if all_df is None: + raise ValueError( + "The full dataset (`all_df`) must be provided to generate longitudinal data."
+ ) + + # Retrieve longitudinal data and save it + longitudinal_file = split_dir / f"{subset_name}.tsv" + long_df = retrieve_longitudinal(df, all_df) + _write_to_csv(long_df, longitudinal_file) diff --git a/clinicadl/splitter/old_splitter.py b/clinicadl/splitter/old_splitter.py deleted file mode 100644 index d39b14a5b..000000000 --- a/clinicadl/splitter/old_splitter.py +++ /dev/null @@ -1,237 +0,0 @@ -import abc -import shutil -from logging import getLogger -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import pandas as pd - -from clinicadl.splitter.config import SplitterConfig -from clinicadl.utils import cluster -from clinicadl.utils.exceptions import MAPSError - -logger = getLogger("clinicadl.split_manager") - - -class Splitter: - def __init__( - self, - config: SplitterConfig, - # split_list: Optional[List[int]] = None, - ): - """_summary_ - - Parameters - ---------- - data_config : DataConfig - _description_ - validation_config : ValidationConfig - _description_ - split_list : Optional[List[int]] (optional, default=None) - _description_ - - """ - self.config = config - # self.config.split.split = split_list - - # self.caps_dict = self.config.data.caps_dict # TODO : check if useful ? - - def max_length(self) -> int: - """Maximum number of splits""" - return self.config.split.n_splits - - def __len__(self): - if not self.config.split.split: - return self.config.split.n_splits - else: - return len(self.config.split.split) - - @property - def allowed_splits_list(self): - """ - List of possible splits if no restriction was applied - - Returns: - list[int]: list of all possible splits - """ - return [i for i in range(self.config.split.n_splits)] - - def __getitem__(self, item) -> Dict: - """ - Returns a dictionary of DataFrames with train and validation data. - - Args: - item (int): Index of the split wanted. - Returns: - Dict[str:pd.DataFrame]: dictionary with two keys (train and validation). - """ - self._check_item(item) - - if self.config.data.multi_cohort: - tsv_df = pd.read_csv(self.config.split.tsv_path, sep="\t") - train_df = pd.DataFrame() - valid_df = pd.DataFrame() - found_diagnoses = set() - for idx in range(len(tsv_df)): - cohort_name = tsv_df.at[idx, "cohort"] - cohort_path = Path(tsv_df.at[idx, "path"]) - cohort_diagnoses = ( - tsv_df.at[idx, "diagnoses"].replace(" ", "").split(",") - ) - if bool(set(cohort_diagnoses) & set(self.config.data.diagnoses)): - target_diagnoses = list( - set(cohort_diagnoses) & set(self.config.data.diagnoses) - ) - - cohort_train_df, cohort_valid_df = self.concatenate_diagnoses( - item, cohort_path=cohort_path, cohort_diagnoses=target_diagnoses - ) - cohort_train_df["cohort"] = cohort_name - cohort_valid_df["cohort"] = cohort_name - train_df = pd.concat([train_df, cohort_train_df]) - valid_df = pd.concat([valid_df, cohort_valid_df]) - found_diagnoses = found_diagnoses | ( - set(cohort_diagnoses) & set(self.config.data.diagnoses) - ) - - if found_diagnoses != set(self.config.data.diagnoses): - raise ValueError( - f"The diagnoses found in the multi cohort dataset {found_diagnoses} " - f"do not correspond to the diagnoses wanted {set(self.config.data.diagnoses)}." 
- ) - train_df.reset_index(inplace=True, drop=True) - valid_df.reset_index(inplace=True, drop=True) - else: - train_df, valid_df = self.concatenate_diagnoses(item) - train_df["cohort"] = "single" - valid_df["cohort"] = "single" - - return { - "train": train_df, - "validation": valid_df, - } - - @staticmethod - def get_dataframe_from_tsv_path(tsv_path: Path) -> pd.DataFrame: - df = pd.read_csv(tsv_path, sep="\t") - list_columns = df.columns.values - - if ( - "diagnosis" not in list_columns - # or "age" not in list_columns - # or "sex" not in list_columns - ): - parents_path = tsv_path.resolve().parent - labels_path = parents_path / "labels.tsv" - while ( - not labels_path.is_file() - and ((parents_path / "kfold.json").is_file()) - or (parents_path / "split.json").is_file() - ): - parents_path = parents_path.parent - try: - labels_df = pd.read_csv(labels_path, sep="\t") - df = pd.merge( - df, - labels_df, - how="inner", - on=["participant_id", "session_id"], - ) - except Exception: - pass - return df - - @staticmethod - def load_data(tsv_path: Path, cohort_diagnoses: List[str]) -> pd.DataFrame: - df = Splitter.get_dataframe_from_tsv_path(tsv_path) - df = df[df.diagnosis.isin((cohort_diagnoses))] - df.reset_index(inplace=True, drop=True) - return df - - def concatenate_diagnoses( - self, - split, - cohort_path: Optional[Path] = None, - cohort_diagnoses: Optional[List[str]] = None, - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - """Concatenated the diagnoses needed to form the train and validation sets.""" - - if cohort_diagnoses is None: - cohort_diagnoses = list(self.config.data.diagnoses) - - tmp_cohort_path = ( - cohort_path if cohort_path is not None else self.config.split.tsv_path - ) - train_path, valid_path = self._get_tsv_paths( - tmp_cohort_path, - split, - ) - - logger.debug(f"Training data loaded at {train_path}") - if self.config.data.baseline: - train_path = train_path / "train_baseline.tsv" - else: - train_path = train_path / "train.tsv" - train_df = self.load_data(train_path, cohort_diagnoses) - - logger.debug(f"Validation data loaded at {valid_path}") - if self.config.validation.valid_longitudinal: - valid_path = valid_path / "validation.tsv" - else: - valid_path = valid_path / "validation_baseline.tsv" - valid_df = self.load_data(valid_path, cohort_diagnoses) - - return train_df, valid_df - - def _get_tsv_paths(self, cohort_path, *args) -> Tuple[Path, Path]: - """ - Computes the paths to the TSV files needed depending on the split structure. - - Args: - cohort_path (str): path to the split structure of a cohort. - split (int): Index of the split. - Returns: - train_path (str): path to the directory containing training data. - valid_path (str): path to the directory containing validation data. - """ - if args is not None: - for split in args: - train_path = cohort_path / f"split-{split}" - valid_path = cohort_path / f"split-{split}" - return train_path, valid_path - else: - train_path = cohort_path - valid_path = cohort_path - return train_path, valid_path - - def split_iterator(self): - """Returns an iterable to iterate on all splits wanted.""" - - if not self.config.split.split: - return range(self.config.split.n_splits) - else: - return self.config.split.split - - def _check_item(self, item): - if item not in self.allowed_splits_list: - raise IndexError( - f"Split index {item} out of allowed splits {self.allowed_splits_list}." 
- ) - - def check_split_list(self, maps_path, overwrite): - existing_splits = [] - for split in self.split_iterator(): - split_path = maps_path / f"split-{split}" - if split_path.is_dir(): - if overwrite: - if cluster.master: - shutil.rmtree(split_path) - else: - existing_splits.append(split) - - if len(existing_splits) > 0: - raise MAPSError( - f"Splits {existing_splits} already exist. Please " - f"specify a list of splits not intersecting the previous list, " - f"or use overwrite to erase previously trained splits." - ) diff --git a/clinicadl/splitter/split.py b/clinicadl/splitter/split.py index 72c4f9d82..ff9ec21f8 100644 --- a/clinicadl/splitter/split.py +++ b/clinicadl/splitter/split.py @@ -1,18 +1,165 @@ from pathlib import Path +from typing import Optional -from clinicadl.dataset.caps_dataset import CapsDataset -from clinicadl.experiment_manager.experiment_manager import ExperimentManager -from clinicadl.splitter.kfold import Split +from pydantic import NonNegativeInt +from torch.utils.data import DataLoader +from clinicadl.dataset.dataloader import DataLoaderConfig +from clinicadl.dataset.dataloader.defaults import ( + BATCH_SIZE, + DP_DEGREE, + DROP_LAST, + NUM_WORKERS, + PIN_MEMORY, + PREFETCH_FACTOR, + RANK, + SAMPLING_WEIGHTS, + SHUFFLE, +) +from clinicadl.dataset.datasets.caps_dataset import CapsDataset +from clinicadl.utils.config import ClinicaDLConfig -def split_tsv(sub_ses_tsv: Path) -> Path: - """TO COMPLETE""" - split_dir = Path("") - return split_dir +class Split(ClinicaDLConfig): + """Dataclass that contains all the useful info on the split.""" + index: NonNegativeInt + split_dir: Path + train_dataset: CapsDataset + val_dataset: CapsDataset + train_loader: Optional[DataLoader] = None + val_loader: Optional[DataLoader] = None + train_loader_config: Optional[DataLoaderConfig] = None + val_loader_config: Optional[DataLoaderConfig] = None -def get_single_split( - n_subject_validation: int, caps_dataset: CapsDataset, manager: ExperimentManager -) -> Split: - pass + def build_train_loader( + self, + dataloader_config: Optional[DataLoaderConfig] = None, + *, + batch_size: int = BATCH_SIZE, + sampling_weights: Optional[str] = SAMPLING_WEIGHTS, + shuffle: bool = SHUFFLE, + drop_last: bool = DROP_LAST, + num_workers: int = NUM_WORKERS, + prefetch_factor: Optional[int] = PREFETCH_FACTOR, + pin_memory: bool = PIN_MEMORY, + dp_degree: Optional[int] = DP_DEGREE, + rank: Optional[int] = RANK, + ) -> None: + """ + Build a train loader for the split. + + Parameters + ---------- + dataloader_config : Optional[DataLoaderConfig] (optional, default=None) + Pre-configured DataLoader configuration. + batch_size : int (optional, default=1) + Batch size for the DataLoader (used if `dataloader_config` is not provided). + sampling_weights : Optional[str] (optional, default=None) + Name of the column in the dataframe of the CapsDatasets where to find the sampling + weights (used if `dataloader_config` is not provided). + shuffle : bool (optional, default=False) + Whether to shuffle the data (used if `dataloader_config` is not provided). + drop_last : bool (optional, default=False) + Whether to drop the last incomplete batch (used if `dataloader_config` is not provided). + num_workers : int (optional, default=0) + Number of workers for data loading (used if `dataloader_config` is not provided). + prefetch_factor : Optional[int] (optional, default=None) + Prefetch factor if num_workers is not 0 (used if `dataloader_config` is not provided). 
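+        pin_memory : bool (optional)
+            Whether to copy batches into pinned memory (used if `dataloader_config` is not provided).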
+        dp_degree : Optional[int] (optional, default=None)
+            The degree of data parallelism. None if no data parallelism.
+        rank : Optional[int] (optional, default=None)
+            Process id within the data parallelism communicator.
+            None if no data parallelism.
+
+        Raises
+        ------
+        ValueError
+            If one of 'dp_degree' and 'rank' is None but the other is not None.
+        KeyError
+            If 'sampling_weights' is passed but there is no such column in the dataframe
+            of the train dataset.
+        """
+        if dataloader_config:
+            self.train_loader_config = dataloader_config
+        else:
+            self.train_loader_config = DataLoaderConfig(
+                batch_size=batch_size,
+                sampling_weights=sampling_weights,
+                shuffle=shuffle,
+                num_workers=num_workers,
+                drop_last=drop_last,
+                prefetch_factor=prefetch_factor,
+                pin_memory=pin_memory,
+            )
+        self.train_loader = self.train_loader_config.get_dataloader(
+            dataset=self.train_dataset,
+            dp_degree=dp_degree,
+            rank=rank,
+        )
+
+    def build_val_loader(
+        self,
+        dataloader_config: Optional[DataLoaderConfig] = None,
+        *,
+        batch_size: int = BATCH_SIZE,
+        sampling_weights: Optional[str] = SAMPLING_WEIGHTS,
+        shuffle: bool = SHUFFLE,
+        drop_last: bool = DROP_LAST,
+        num_workers: int = NUM_WORKERS,
+        prefetch_factor: Optional[int] = PREFETCH_FACTOR,
+        pin_memory: bool = PIN_MEMORY,
+        dp_degree: Optional[int] = DP_DEGREE,
+        rank: Optional[int] = RANK,
+    ) -> None:
+        """
+        Build a validation loader for the split.
+
+        Parameters
+        ----------
+        dataloader_config : Optional[DataLoaderConfig] (optional, default=None)
+            Pre-configured DataLoader configuration.
+        batch_size : int (optional, default=1)
+            Batch size for the DataLoader (used if `dataloader_config` is not provided).
+        sampling_weights : Optional[str] (optional, default=None)
+            Name of the column in the dataframe of the CapsDatasets where to find the sampling
+            weights (used if `dataloader_config` is not provided).
+        shuffle : bool (optional, default=False)
+            Whether to shuffle the data (used if `dataloader_config` is not provided).
+        drop_last : bool (optional, default=False)
+            Whether to drop the last incomplete batch (used if `dataloader_config` is not provided).
+        num_workers : int (optional, default=0)
+            Number of workers for data loading (used if `dataloader_config` is not provided).
+        prefetch_factor : Optional[int] (optional, default=None)
+            Prefetch factor if num_workers is not 0 (used if `dataloader_config` is not provided).
+        pin_memory : bool (optional)
+            Whether to copy batches into pinned memory (used if `dataloader_config` is not provided).
+        dp_degree : Optional[int] (optional, default=None)
+            The degree of data parallelism. None if no data parallelism.
+        rank : Optional[int] (optional, default=None)
+            Process id within the data parallelism communicator.
+            None if no data parallelism.
+
+        Raises
+        ------
+        ValueError
+            If one of 'dp_degree' and 'rank' is None but the other is not None.
+        KeyError
+            If 'sampling_weights' is passed but there is no such column in the dataframe
+            of the validation dataset.
+ """ + if dataloader_config: + self.val_loader_config = dataloader_config + else: + self.val_loader_config = DataLoaderConfig( + batch_size=batch_size, + sampling_weights=sampling_weights, + shuffle=shuffle, + num_workers=num_workers, + drop_last=drop_last, + prefetch_factor=prefetch_factor, + pin_memory=pin_memory, + ) + self.val_loader = self.val_loader_config.get_dataloader( + dataset=self.val_dataset, + dp_degree=dp_degree, + rank=rank, + ) diff --git a/clinicadl/splitter/split_utils.py b/clinicadl/splitter/split_utils.py index 3e0f09388..42a6c6a49 100644 --- a/clinicadl/splitter/split_utils.py +++ b/clinicadl/splitter/split_utils.py @@ -1,6 +1,119 @@ +from copy import copy from pathlib import Path from typing import List +import numpy as np +import pandas as pd + +from clinicadl.tsvtools.tsvtools_utils import first_session +from clinicadl.utils.exceptions import ClinicaDLTSVError + + +def extract_baseline(diagnosis_df, set_index=True): + from copy import deepcopy + + if set_index: + all_df = deepcopy(diagnosis_df) + all_df.set_index(["participant_id", "session_id"], inplace=True) + else: + all_df = deepcopy(diagnosis_df) + + result_df = pd.DataFrame() + for subject, subject_df in all_df.groupby(level=0): + if subject != "participant_id": + baseline = first_session(subject_df) + + subject_baseline_df = pd.DataFrame( + data=[ + [subject, baseline] + subject_df.loc[(subject, baseline)].tolist() + ], + columns=["participant_id", "session_id"] + + subject_df.columns.values.tolist(), + ) + result_df = pd.concat([result_df, subject_baseline_df]) + + result_df.reset_index(inplace=True, drop=True) + return result_df + + +def chi2(x_test, x_train): + from scipy.stats import chisquare + + # Look for chi2 computation + total_categories = np.concatenate([x_test, x_train]) + unique_categories = np.unique(total_categories) + f_obs = [(x_test == category).sum() / len(x_test) for category in unique_categories] + f_exp = [ + (x_train == category).sum() / len(x_train) for category in unique_categories + ] + T, p = chisquare(f_obs, f_exp) + + return T, p + + +def add_demographics(df, demographics_df, diagnosis) -> pd.DataFrame: + out_df = pd.DataFrame() + tmp_demo_df = copy(demographics_df) + tmp_demo_df.reset_index(inplace=True) + for idx in df.index.values: + participant = df.loc[idx, "participant_id"] + session = df.loc[idx, "session_id"] + row_df = tmp_demo_df[ + (tmp_demo_df.participant_id == participant) + & (tmp_demo_df.session_id == session) + ] + out_df = pd.concat([out_df, row_df]) + out_df.reset_index(inplace=True, drop=True) + out_df.diagnosis = [diagnosis] * len(out_df) + return out_df + + +def remove_unicity(values_list): + """Count the values of each class and label all the classes with only one label under the same label.""" + unique_classes, counts = np.unique(values_list, return_counts=True) + one_sub_classes = unique_classes[(counts == 1)] + for class_element in one_sub_classes: + values_list[values_list.index(class_element)] = unique_classes.min() + + return values_list + + +def category_conversion(values_list) -> List[int]: + values_np = np.array(values_list) + unique_classes = np.unique(values_np) + for index, unique_class in enumerate(unique_classes): + values_np[values_np == unique_class] = index + 1 + + return values_np.astype(int).tolist() + + +def find_label(labels_list, target_label): + if target_label in labels_list: + return target_label + else: + min_length = np.inf + found_label = None + for label in labels_list: + if target_label.lower() in label.lower() and min_length > 
len(label):
+            min_length = len(label)
+            found_label = label
+    if found_label is None:
+        raise ClinicaDLTSVError(
+            f"No label was found in {labels_list} for target label {target_label}."
+        )
+
+    return found_label
+
+
+def retrieve_longitudinal(df, diagnosis_df):
+    final_df = pd.DataFrame()
+    for idx in df.index.values:
+        subject = df.loc[idx, "participant_id"]
+        row_df = diagnosis_df[diagnosis_df.participant_id == subject]
+        final_df = pd.concat([final_df, row_df])
+
+    return final_df
+
 
 def find_splits(maps_path: Path) -> List[int]:
     """Find which splits that were trained in the MAPS."""
diff --git a/clinicadl/splitter/splitter/__init__.py b/clinicadl/splitter/splitter/__init__.py
new file mode 100644
index 000000000..f5eec37c5
--- /dev/null
+++ b/clinicadl/splitter/splitter/__init__.py
@@ -0,0 +1,2 @@
+from .kfold import KFold
+from .single_split import SingleSplit
diff --git a/clinicadl/splitter/splitter/kfold.py b/clinicadl/splitter/splitter/kfold.py
new file mode 100644
index 000000000..d4a306508
--- /dev/null
+++ b/clinicadl/splitter/splitter/kfold.py
@@ -0,0 +1,108 @@
+from pathlib import Path
+from typing import Generator, List, Optional, Sequence, Union
+
+from pydantic import PositiveInt
+
+from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.splitter.split import Split
+from clinicadl.splitter.splitter.splitter import (
+    Splitter,
+    SplitterConfig,
+    SubjectsSessionsSplit,
+)
+
+
+class KFoldConfig(SplitterConfig):
+    """
+    Configuration for K-Fold cross-validation splits.
+    """
+
+    json_name: str = "kfold_config.json"
+    subset_name: str = "validation"
+    n_splits: PositiveInt = 5
+    stratification: Union[str, bool] = False
+
+    @property
+    def pattern(self):
+        return f"{self.n_splits}_fold"
+
+
+class KFold(Splitter):
+    """
+    Handles K-Fold cross-validation with optional stratification and demographic balancing.
+    Allows saving, reading, and iterating over splits for reproducibility.
+    """
+
+    def __init__(self, split_dir: Path):
+        """
+        Initialize KFold from a directory containing previously written splits.
+
+        Parameters
+        ----------
+        split_dir : Path
+            Directory containing the splits and the configuration JSON file.
+        """
+
+        super().__init__(split_dir=split_dir)
+
+    def _init_config(self, **args):
+        self.config: KFoldConfig = KFoldConfig(**args)
+
+    def _read_splits(self) -> List[SubjectsSessionsSplit]:
+        """
+        Load all splits from the split directory.
+
+        Returns
+        -------
+        List[SubjectsSessionsSplit]
+            The train and validation subjects/sessions of each split.
+        """
+        return [
+            self._read_split(self.split_dir / f"split-{i}")
+            for i in range(self.config.n_splits)
+        ]
+
+    def get_splits(
+        self, dataset: CapsDataset, splits: Optional[Sequence[int]] = None
+    ) -> Generator[Split, None, None]:
+        """
+        Yield dataset splits by their indices.
+
+        Parameters
+        ----------
+        dataset : CapsDataset
+            Dataset to split according to the subjects/sessions read in the split directory.
+        splits : Optional[Sequence[int]] (optional, default=None)
+            Indices of the splits to retrieve. If None, all splits are yielded.
+
+        Yields
+        ------
+        Split
+            The train and validation datasets for each requested split.
+
+        Raises
+        ------
+        ValueError
+            If the requested split indices are out of range or no splits are available.
+        """
+
+        if not self.config or self.config.n_splits is None:
+            raise ValueError(
+                "No splits found, you must first run 'make_kfold' to split your dataset, "
+                "or make sure you use a valid split directory."
+            )
+
+        self.check_dataset_and_tsv_consistency(dataset)
+
+        if splits is None:
+            splits = list(range(self.config.n_splits))
+
+        for split in splits:
+            if split not in range(self.config.n_splits):
+                raise ValueError(
+                    f"Split-{split} doesn't exist. There are {self.config.n_splits} splits, numbered from 0 to {self.config.n_splits-1}."
+                )
+            yield self._get_split(split_id=split, dataset=dataset)
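For orientation, here is a minimal usage sketch of the `KFold` splitter added above. It mirrors the calls exercised in `clinicadl/splitter/test.py` later in this patch; the paths are placeholders, and the exact `CapsDataset`/`Transforms` arguments are assumed from that script:

    from pathlib import Path

    from clinicadl.dataset.datasets.caps_dataset import CapsDataset
    from clinicadl.dataset.preprocessing import PreprocessingT1
    from clinicadl.splitter.make_splits import make_kfold
    from clinicadl.splitter.splitter import KFold
    from clinicadl.transforms.extraction import Image
    from clinicadl.transforms.transforms import Transforms

    # Write two stratified folds to disk once...
    fold_dir = make_kfold(Path("train_baseline.tsv"), stratification="sex", n_splits=2)

    # ...then re-read them and apply them to a dataset.
    dataset = CapsDataset(
        caps_directory=Path("caps"),  # placeholder CAPS directory
        data=Path("train_baseline.tsv"),
        preprocessing=PreprocessingT1(),
        transforms=Transforms(extraction=Image()),
    )
    splitter = KFold(fold_dir)
    for split in splitter.get_splits(dataset=dataset):  # one Split per fold
        print(split.index, split.train_dataset, split.val_dataset)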
diff --git a/clinicadl/splitter/splitter/single_split.py b/clinicadl/splitter/splitter/single_split.py
new file mode 100644
index 000000000..daee8fd9f
--- /dev/null
+++ b/clinicadl/splitter/splitter/single_split.py
@@ -0,0 +1,96 @@
+from pathlib import Path
+from typing import List, Optional, Sequence, Union
+
+from pydantic import PositiveInt, field_validator
+
+from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.splitter.split import Split
+from clinicadl.splitter.splitter.splitter import (
+    Splitter,
+    SplitterConfig,
+    SubjectsSessionsSplit,
+)
+
+
+class SingleSplitConfig(SplitterConfig):
+    json_name: str = "single_split_config.json"
+    subset_name: str = "test"
+    stratification: Union[List[str], bool] = False
+    n_test: PositiveInt = 100
+    p_categorical_threshold: float = 0.80
+    p_continuous_threshold: float = 0.80
+
+    @property
+    def pattern(self) -> str:
+        return "split"
+
+    @field_validator("p_categorical_threshold", "p_continuous_threshold", mode="before")
+    @classmethod
+    def validate_thresholds(cls, value: Union[float, int]) -> float:
+        if not (0 <= value <= 1):
+            raise ValueError(f"Threshold must be between 0 and 1, got {value}")
+        return value
+
+
+class SingleSplit(Splitter):
+    def __init__(self, split_dir: Path):
+        """
+        Initialize SingleSplit from a directory containing a previously written split.
+
+        Parameters
+        ----------
+        split_dir : Path
+            Directory containing the split and the configuration JSON file.
+        """
+        super().__init__(split_dir=split_dir)
+
+    def _init_config(self, **args):
+        self.config = SingleSplitConfig(**args)
+
+    def _read_splits(self) -> List[SubjectsSessionsSplit]:
+        """
+        Load the single split from the split directory.
+
+        Returns
+        -------
+        List[SubjectsSessionsSplit]
+            A one-element list with the train and test subjects/sessions.
+        """
+        return [self._read_split(self.split_dir)]
+
+    def get_splits(
+        self, dataset: CapsDataset, splits: Optional[Sequence[int]] = None
+    ) -> Split:
+        """
+        Return the single dataset split.
+
+        Parameters
+        ----------
+        dataset : CapsDataset
+            Dataset to split according to the subjects/sessions read in the split directory.
+        splits : Optional[Sequence[int]] (optional, default=None)
+            Ignored, kept for compatibility with the `Splitter` interface.
+
+        Returns
+        -------
+        Split
+            The train and validation datasets of the split.
+
+        Raises
+        ------
+        ValueError
+            If no configuration was found in the split directory.
+        """
+
+        if not self.config:
+            raise ValueError(
+                "No splits found, you must first run 'make_split' to split your dataset, "
+                "or make sure you use a valid split directory."
+            )
+
+        self.check_dataset_and_tsv_consistency(dataset)
+
+        return self._get_split(dataset)
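Similarly, a minimal sketch of the `SingleSplit` counterpart (paths are placeholders; `dataset` is a `CapsDataset` built as in the previous sketch; note that `n_test` may be given as a fraction of the dataset, as in `test.py` below, or as an absolute count, as in the fixture JSON):

    from pathlib import Path

    from clinicadl.splitter.make_splits import make_split
    from clinicadl.splitter.splitter import SingleSplit

    # Hold out 20 % of the subjects as a test set, stratifying on sex and diagnosis.
    split_dir = make_split(
        Path("subjects.tsv"),  # placeholder subjects/sessions TSV
        subset_name="test",
        stratification=["sex", "diagnosis"],
        n_test=0.2,
    )

    # Later, possibly in another process, re-read the split and apply it.
    splitter = SingleSplit(split_dir)
    split = splitter.get_splits(dataset)  # returns a single Split
    print(split.train_dataset, split.val_dataset)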
diff --git a/clinicadl/splitter/splitter/splitter.py b/clinicadl/splitter/splitter/splitter.py
new file mode 100644
index 000000000..ae7f65ca0
--- /dev/null
+++ b/clinicadl/splitter/splitter/splitter.py
@@ -0,0 +1,256 @@
+import json
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Generator, List, Optional, Sequence, Tuple, Union
+
+import pandas as pd
+from pydantic import (
+    computed_field,
+    field_validator,
+)
+
+from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.splitter.split import Split
+from clinicadl.utils.config import ClinicaDLConfig
+from clinicadl.utils.exceptions import ClinicaDLTSVError
+from clinicadl.utils.iotools.utils import path_encoder
+
+
+class SubjectsSessionsSplit(ClinicaDLConfig):
+    """
+    Dataclass to store train and validation splits for subjects and sessions.
+    """
+
+    train: pd.DataFrame
+    validation: pd.DataFrame
+
+    @computed_field
+    @property
+    def train_val_df(self) -> pd.DataFrame:
+        return pd.concat([self.train, self.validation], ignore_index=True)
+
+
+class SplitterConfig(ClinicaDLConfig):
+    json_name: str
+    split_dir: Path
+    subset_name: str
+    stratification: Union[str, List[str], bool] = False
+    valid_longitudinal: bool = False
+
+    @field_validator("split_dir", mode="after")
+    @classmethod
+    def validate_split_dir(cls, v):
+        if not isinstance(v, Path):
+            v = Path(v)
+        if v and not v.is_dir():
+            v.mkdir(parents=True, exist_ok=True)
+        return v
+
+    def _check_split_dir(self):
+        split_numero = 1
+        folder_name = self.pattern
+        while (self.split_dir / folder_name).is_dir():
+            split_numero += 1
+            folder_name = f"{self.pattern}_{split_numero}"
+
+        self.split_dir = self.split_dir / folder_name
+
+    @property
+    @abstractmethod
+    def pattern(self) -> str:
+        pass
+
+    def _write_json(self) -> None:
+        """
+        Save the splitter configuration to a JSON file.
+        """
+        if not self.split_dir:
+            raise ValueError(
+                "No split directory specified, use the method 'write' to save your splits."
+            )
+
+        out_json_file = self.split_dir / self.json_name
+        if out_json_file.is_file():
+            raise FileExistsError(
+                f"File {out_json_file} already exists, your splits may have already been written."
+            )
+
+        with out_json_file.open(mode="w") as json_file:
+            json.dump(
+                self.model_dump(),
+                json_file,
+                skipkeys=True,
+                indent=4,
+                default=path_encoder,
+            )
+
+
+class Splitter(ABC):
+    def __init__(self, split_dir: Path):
+        """
+        Initialize the splitter from a directory containing previously written splits.
+
+        Parameters
+        ----------
+        split_dir : Path
+            Directory containing the splits and the configuration JSON file.
+        """
+        split_dir = Path(split_dir)
+
+        if not split_dir.is_dir():
+            raise FileNotFoundError(f"No such directory: {split_dir}")
+
+        self.split_dir = split_dir
+        self._init_config(**self._read_json())
+        self.subjects_sessions_split = self._read_splits()
+
+    @abstractmethod
+    def _init_config(self, **args):
+        self.config: SplitterConfig
+
+    def _read_json(self):
+        """
+        Load the splitter configuration from the JSON file found in `split_dir`.
+
+        Returns
+        -------
+        dict
+            The configuration read from the JSON file.
+        """
+
+        json_file = [json for json in self.split_dir.glob("*.json")]
+
+        if len(json_file) > 1:
+            raise ValueError(
+                f"Multiple JSON files found in {self.split_dir}, please remove or rename them."
+            )
+
+        elif len(json_file) == 0:
+            raise FileNotFoundError(f"No JSON file found in {self.split_dir}")
+
+        if not json_file[0].is_file():
+            raise FileNotFoundError(f"No such file: {json_file}")
+
+        with json_file[0].open(mode="r") as file:
+            dict_ = json.load(file)
+
+        return dict_
+
+    def _read_split(self, split_path: Path) -> SubjectsSessionsSplit:
+        """
+        Load a single split's train and validation sets from files.
+
+        Parameters
+        ----------
+        split_path : Path
+            Directory containing the split data.
+
+        Returns
+        -------
+        SubjectsSessionsSplit
+            Object containing train and validation sets as DataFrames.
+        """
+
+        if not split_path.is_dir():
+            raise FileNotFoundError(f"No such directory: {split_path}")
+
+        try:
+            train = pd.read_csv(split_path / "train_baseline.tsv", sep="\t")
+            validation = pd.read_csv(
+                split_path / f"{self.config.subset_name}_baseline.tsv", sep="\t"
+            )  # type: ignore
+
+        except FileNotFoundError as exc:
+            raise FileNotFoundError(
+                f"One or more of the required files are missing: 'train_baseline.tsv', '{self.config.subset_name}_baseline.tsv'"
+            ) from exc  # type: ignore
+
+        return SubjectsSessionsSplit(
+            train=train,
+            validation=validation,
+        )
+
+    @abstractmethod
+    def _read_splits(self) -> List[SubjectsSessionsSplit]:
+        """
+        Load all splits from the split directory.
+
+        Returns
+        -------
+        List[SubjectsSessionsSplit]
+            The train and validation subjects/sessions of each split.
+        """
+
+    def check_dataset_and_tsv_consistency(self, dataset: CapsDataset):
+        df1 = self.subjects_sessions_split[0].train_val_df
+        df2 = dataset.df
+        pairs_df1 = set(zip(df1["participant_id"], df1["session_id"]))
+        pairs_df2 = set(zip(df2["participant_id"], df2["session_id"]))
+
+        # Check that every (participant, session) pair of the split is present in the dataset
+        if not pairs_df1.issubset(pairs_df2):
+            raise ClinicaDLTSVError(
+                "Not all pairs of participants and sessions from the TSV file are present in the dataset. "
+                "Please check that the TSV file and the dataset are consistent."
+            )
+
+    @abstractmethod
+    def get_splits(
+        self, dataset: CapsDataset, splits: Optional[Sequence[int]] = None
+    ) -> Union[Split, Generator[Split, None, None]]:
+        """
+        Return dataset splits by their indices.
+
+        Parameters
+        ----------
+        dataset : CapsDataset
+            Dataset to split according to the subjects/sessions read in the split directory.
+        splits : Optional[Sequence[int]] (optional, default=None)
+            Indices of the splits to retrieve.
+
+        Returns
+        -------
+        Split or Generator[Split, None, None]
+            The train and validation datasets for each requested split.
+
+        Raises
+        ------
+        ValueError
+            If the requested split indices are out of range or no splits are available.
+        """
+
+    def _get_split(
+        self,
+        dataset: CapsDataset,
+        split_id: int = 0,
+    ) -> Split:
+        """
+        Retrieve a single dataset split.
+
+        Parameters
+        ----------
+        dataset : CapsDataset
+            Dataset to restrict to the split's subjects and sessions.
+        split_id : int
+            Index of the split to retrieve.
+
+        Returns
+        -------
+        Split
+            Object containing train and validation datasets for the specified split.
+        """
+        subjects_sessions = self.subjects_sessions_split[split_id]
+        return Split(
+            index=split_id,
+            split_dir=self.split_dir,
+            train_dataset=dataset.subset(subjects_sessions.train),
+            val_dataset=dataset.subset(subjects_sessions.validation),
+        )
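Once a `Split` has been obtained from a splitter, its train and validation loaders can be built either from keyword arguments or from a shared `DataLoaderConfig`; a minimal sketch, assuming `split` comes from one of the splitters above:

    from clinicadl.dataset.dataloader import DataLoaderConfig

    # Option 1: configure the loader directly with keyword arguments.
    split.build_train_loader(batch_size=8, shuffle=True, num_workers=2)

    # Option 2: reuse a single config object.
    split.build_val_loader(DataLoaderConfig(batch_size=8))

    for batch in split.train_loader:  # a regular torch.utils.data.DataLoader
        ...  # training step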
+ """ + subjects_sessions = self.subjects_sessions_split[split_id] + return Split( + index=split_id, + split_dir=self.split_dir, + train_dataset=dataset.subset(subjects_sessions.train), + val_dataset=dataset.subset(subjects_sessions.validation), + ) diff --git a/clinicadl/splitter/test.py b/clinicadl/splitter/test.py new file mode 100644 index 000000000..83010184e --- /dev/null +++ b/clinicadl/splitter/test.py @@ -0,0 +1,168 @@ +from pathlib import Path + +import pandas as pd +import torchio.transforms as transforms + +from clinicadl.dataset.datasets.caps_dataset import CapsDataset +from clinicadl.dataset.preprocessing import PreprocessingT1 +from clinicadl.splitter import make_kfold, make_split +from clinicadl.splitter.dataloader import DataLoaderConfig +from clinicadl.splitter.splitter import KFold, SingleSplit +from clinicadl.transforms.extraction import Image, Patch, Slice +from clinicadl.transforms.transforms import Transforms +from clinicadl.tsvtools.get_metadata.get_metadata import get_metadata + +# maps_path = Path("/") +# manager = ExperimentManager(maps_path, overwrite=False) + + +sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv") +sub_ses_all = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects.tsv") + +# df = get_metadata(sub_ses_t1, sub_ses_all) + +splir_dir = make_split( + sub_ses_t1, + output_dir=Path( + "/Users/camille.brianceau/aramis/CLINICADL/clinicadl/tests/unittests/ressources/caps_example/split_test" + ), + subset_name="test", + stratification=["age", "sex", "test", "diagnosis"], + n_test=0.2, +) +print(splir_dir) + + +train_path = splir_dir / "train_baseline.tsv" +test_path = splir_dir / "test_baseline.tsv" + +train_df = pd.read_csv(train_path, sep="\t") +test_df = pd.read_csv(test_path, sep="\t") + +print("train age mean", train_df["age"].mean()) +print("test age mean", test_df["age"].mean()) +print("/n") +print("train age std", train_df["age"].std()) +print("test age std", test_df["age"].std()) +print("/n") +print("train test mean", train_df["test"].mean()) +print("test test mean", test_df["test"].mean()) +print("/n") +print("train test std", train_df["test"].std()) +print("test test std", test_df["test"].std()) + +print("/n") +print( + "train diagnosis count AD", + len(train_df[train_df["diagnosis"] == "AD"]), + "/", + len(train_df), + len(train_df[train_df["diagnosis"] == "AD"]) / len(train_df), +) +print( + "test diagnosis count AD", + len(test_df[test_df["diagnosis"] == "AD"]), + "/", + len(test_df), + len(test_df[test_df["diagnosis"] == "AD"]) / len(test_df), +) +print("/n") +print( + "train diagnosis count MCI", + len(train_df[train_df["diagnosis"] == "MCI"]), + "/", + len(train_df), + len(train_df[train_df["diagnosis"] == "MCI"]) / len(train_df), +) +print( + "test diagnosis count MCI", + len(test_df[test_df["diagnosis"] == "MCI"]), + "/", + len(test_df), + len(test_df[test_df["diagnosis"] == "MCI"]) / len(test_df), +) +print("/n") +print( + "train diagnosis count CN", + len(train_df[train_df["diagnosis"] == "CN"]), + "/", + len(train_df), + len(train_df[train_df["diagnosis"] == "CN"]) / len(train_df), +) +print( + "test diagnosis count CN", + len(test_df[test_df["diagnosis"] == "CN"]), + "/", + len(test_df), + len(test_df[test_df["diagnosis"] == "CN"]) / len(test_df), +) + +print("/n") +print( + "train sex count F", + len(train_df[train_df["sex"] == "F"]), + "/", + len(train_df), + len(train_df[train_df["sex"] == "F"]) / len(train_df), +) +print( + "test sex count F", + len(test_df[test_df["sex"] == "F"]), + "/", 
+ len(test_df), + len(test_df[test_df["sex"] == "F"]) / len(test_df), +) +print("/n") +print( + "train sex count M", + len(train_df[train_df["sex"] == "M"]), + "/", + len(train_df), + len(train_df[train_df["sex"] == "M"]) / len(train_df), +) +print( + "test sex count M", + len(test_df[test_df["sex"] == "M"]), + "/", + len(test_df), + len(test_df[test_df["sex"] == "M"]) / len(test_df), +) + + +fold_dir = make_kfold(train_path, stratification="sex", n_splits=2) + +print(fold_dir) + +caps_directory = Path("/Users/camille.brianceau/aramis/CLINICADL/caps") +preprocessing_t1 = PreprocessingT1() +transforms_image = Transforms( + image_augmentation=[transforms.RandomMotion()], + extraction=Image(), + image_transforms=[transforms.Blur((0.5, 0.6, 0.3))], +) +dataset_t1_image = CapsDataset( + caps_directory=caps_directory, + data=train_path, + preprocessing=preprocessing_t1, + transforms=transforms_image, +) +print(dataset_t1_image.__str__()) +dataset_t1_image.prepare_data(n_proc=2) + +splitter = KFold(fold_dir) + +for split in splitter.get_splits(dataset=dataset_t1_image): + print(f"Split {split.index}:\n") + print(f"Train dataset: {split.train_dataset}") + print(f"describe: {split.train_dataset.describe()}") + print(f"elem per image: {split.train_dataset.elem_per_image}") + print("\n") + print(f"Validation dataset: {split.val_dataset}") + print(f"describe: {split.val_dataset.describe()}") + print(f"elem per image: {split.val_dataset.elem_per_image}") + + split.build_train_loader(num_workers=2) + split.build_val_loader(DataLoaderConfig(batch_size=2)) + print("/n") + print(f"Train loader: {split.train_loader}") + print(f"Validation loader: {split.val_loader}") diff --git a/clinicadl/trainer/config/train.py b/clinicadl/trainer/config/train.py index f89818576..538eccecd 100644 --- a/clinicadl/trainer/config/train.py +++ b/clinicadl/trainer/config/train.py @@ -18,7 +18,7 @@ from clinicadl.optim.config import OptimizationConfig from clinicadl.optim.early_stopping import EarlyStoppingConfig from clinicadl.predictor.validation import ValidationConfig -from clinicadl.splitter.config import SplitConfig +from clinicadl.splitter.splitter.splitter import SplitterConfig as SplitConfig from clinicadl.trainer.transfer_learning import TransferLearningConfig from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.computational.computational import ComputationalConfig @@ -44,7 +44,6 @@ class TrainConfig(BaseModel, ABC): model: NetworkConfig optimization: OptimizationConfig reproducibility: ReproducibilityConfig - split: SplitConfig transfer_learning: TransferLearningConfig transforms: TransformsConfig validation: ValidationConfig @@ -86,7 +85,6 @@ def _update(self, config_dict: Dict[str, Any]) -> None: self.model.__dict__.update(config_dict) self.optimization.__dict__.update(config_dict) self.reproducibility.__dict__.update(config_dict) - self.split.__dict__.update(config_dict) self.transfer_learning.__dict__.update(config_dict) self.transforms.__dict__.update(config_dict) self.validation.__dict__.update(config_dict) diff --git a/clinicadl/trainer/old_trainer.py b/clinicadl/trainer/old_trainer.py index 4d00c5206..aa3f82c70 100644 --- a/clinicadl/trainer/old_trainer.py +++ b/clinicadl/trainer/old_trainer.py @@ -34,8 +34,9 @@ from clinicadl.trainer.tasks_utils import create_training_config from clinicadl.predictor.old_predictor import Predictor from clinicadl.predictor.config import PredictConfig -from clinicadl.splitter.old_splitter import Splitter -from clinicadl.splitter.config import 
SplitterConfig +from clinicadl.splitter.splitter.splitter import Splitter +from clinicadl.splitter.splitter.splitter import SplitterConfig + if TYPE_CHECKING: from clinicadl.callbacks.callbacks import Callback diff --git a/clinicadl/tsvtools/get_metadata/get_metadata.py b/clinicadl/tsvtools/get_metadata/get_metadata.py index 5ba40a4f9..63cb474f9 100644 --- a/clinicadl/tsvtools/get_metadata/get_metadata.py +++ b/clinicadl/tsvtools/get_metadata/get_metadata.py @@ -67,3 +67,5 @@ def get_metadata( result_df.to_csv(data_tsv, sep="\t") logger.info(f"metadata were added in: {data_tsv}") + + return result_df diff --git a/clinicadl/tsvtools/split/split.py b/clinicadl/tsvtools/split/split.py index 6235ee2e8..836ee48a7 100644 --- a/clinicadl/tsvtools/split/split.py +++ b/clinicadl/tsvtools/split/split.py @@ -25,35 +25,6 @@ logger = getLogger("clinicadl.tsvtools.split") -def KStests(train_df, test_df, threshold=0.5): - pmin = 1 - column = "" - for col in train_df.columns: - if col == "session_id": - continue - _, pval = ks_2samp(train_df[col], test_df[col]) - if pval < pmin: - pmin = pval - column = col - return (pmin, column) - - -def shuffle_choice(df, n_shuffle=10): - p_min_max, n_col_min = 0, df.columns.size - - for i in range(n_shuffle): - train_df = df.sample(frac=0.75) - test_df = df.drop(train_df.index) - - p, col = KStests(train_df, test_df) - - if p > p_min_max: - p_min_max = p - best_train_df, best_test_df = train_df, test_df - - return (best_train_df, best_test_df, p_min_max) - - def KStests(train_df, test_df, threshold=0.5): pmin = 1 column = "" diff --git a/clinicadl/tsvtools/tsvtools_utils.py b/clinicadl/tsvtools/tsvtools_utils.py index caf842b3e..27291940e 100644 --- a/clinicadl/tsvtools/tsvtools_utils.py +++ b/clinicadl/tsvtools/tsvtools_utils.py @@ -3,6 +3,7 @@ from copy import copy from logging import getLogger from pathlib import Path +from typing import List import numpy as np import pandas as pd @@ -153,7 +154,7 @@ def remove_unicity(values_list): return values_list -def category_conversion(values_list): +def category_conversion(values_list) -> List[int]: values_np = np.array(values_list) unique_classes = np.unique(values_np) for index, unique_class in enumerate(unique_classes): diff --git a/tests/unittests/dataset/test_config.py b/tests/unittests/dataset/test_config.py index d7e51d598..4fa9588e9 100644 --- a/tests/unittests/dataset/test_config.py +++ b/tests/unittests/dataset/test_config.py @@ -1,6 +1,6 @@ import pytest -from clinicadl.dataset.config import DataConfig, FileType +from clinicadl.dataset.config import FileType from clinicadl.utils.enum import Preprocessing diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/kfold_config.json b/tests/unittests/ressources/caps_example/split_test/split/2_fold/kfold_config.json new file mode 100644 index 000000000..6bf0d9492 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/2_fold/kfold_config.json @@ -0,0 +1,8 @@ +{ + "json_name": "kfold_config.json", + "split_dir": "../ressources/caps_example/split_test/split/2_fold", + "subset_name": "validation", + "stratification": "sex", + "valid_longitudinal": false, + "n_splits": 2 +} \ No newline at end of file diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train.tsv b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train.tsv new file mode 100644 index 000000000..267118404 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train.tsv @@ -0,0 
+1,49 @@ +participant_id session_id age sex test diagnosis +sub-003 ses-M000 40 M 1 MCI +sub-004 ses-M000 56 M 2 CN +sub-007 ses-M006 41 F 7 MCI +sub-016 ses-M006 72 M 14 AD +sub-023 ses-M000 69 M 17 AD +sub-036 ses-M006 52 M 30 MCI +sub-045 ses-M006 92 F 29 MCI +sub-046 ses-M006 73 M 27 MCI +sub-054 ses-M000 38 M 23 CN +sub-056 ses-M006 72 M 19 CN +sub-067 ses-M006 94 F 10 CN +sub-074 ses-M000 34 M 7 CN +sub-075 ses-M006 75 F 5 CN +sub-077 ses-M006 23 F 2 CN +sub-103 ses-M000 40 M 1 MCI +sub-104 ses-M000 56 M 2 CN +sub-105 ses-M006 85 F 4 CN +sub-115 ses-M006 85 F 12 CN +sub-116 ses-M006 72 M 14 AD +sub-117 ses-M006 41 F 15 CN +sub-123 ses-M000 69 M 17 AD +sub-124 ses-M000 41 M 18 CN +sub-125 ses-M006 57 F 20 MCI +sub-133 ses-M000 36 M 25 AD +sub-134 ses-M000 74 M 26 CN +sub-145 ses-M006 92 F 29 MCI +sub-147 ses-M006 48 F 26 AD +sub-155 ses-M006 73 F 21 MCI +sub-166 ses-M006 63 M 11 MCI +sub-176 ses-M006 49 M 3 AD +sub-177 ses-M006 23 F 2 CN +sub-206 ses-M006 72 M 6 AD +sub-215 ses-M006 85 F 12 CN +sub-216 ses-M006 72 M 14 AD +sub-225 ses-M006 57 F 20 MCI +sub-233 ses-M000 36 M 25 AD +sub-236 ses-M006 52 M 30 MCI +sub-243 ses-M000 71 M 32 CN +sub-244 ses-M000 64 M 31 AD +sub-247 ses-M006 48 F 26 AD +sub-253 ses-M000 52 M 24 AD +sub-255 ses-M006 73 F 21 MCI +sub-263 ses-M000 83 M 16 MCI +sub-267 ses-M006 94 F 10 CN +sub-273 ses-M000 85 M 8 AD +sub-275 ses-M006 75 F 5 CN +sub-276 ses-M006 49 M 3 AD +sub-277 ses-M006 23 F 2 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train_baseline.tsv b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train_baseline.tsv new file mode 100644 index 000000000..267118404 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/train_baseline.tsv @@ -0,0 +1,49 @@ +participant_id session_id age sex test diagnosis +sub-003 ses-M000 40 M 1 MCI +sub-004 ses-M000 56 M 2 CN +sub-007 ses-M006 41 F 7 MCI +sub-016 ses-M006 72 M 14 AD +sub-023 ses-M000 69 M 17 AD +sub-036 ses-M006 52 M 30 MCI +sub-045 ses-M006 92 F 29 MCI +sub-046 ses-M006 73 M 27 MCI +sub-054 ses-M000 38 M 23 CN +sub-056 ses-M006 72 M 19 CN +sub-067 ses-M006 94 F 10 CN +sub-074 ses-M000 34 M 7 CN +sub-075 ses-M006 75 F 5 CN +sub-077 ses-M006 23 F 2 CN +sub-103 ses-M000 40 M 1 MCI +sub-104 ses-M000 56 M 2 CN +sub-105 ses-M006 85 F 4 CN +sub-115 ses-M006 85 F 12 CN +sub-116 ses-M006 72 M 14 AD +sub-117 ses-M006 41 F 15 CN +sub-123 ses-M000 69 M 17 AD +sub-124 ses-M000 41 M 18 CN +sub-125 ses-M006 57 F 20 MCI +sub-133 ses-M000 36 M 25 AD +sub-134 ses-M000 74 M 26 CN +sub-145 ses-M006 92 F 29 MCI +sub-147 ses-M006 48 F 26 AD +sub-155 ses-M006 73 F 21 MCI +sub-166 ses-M006 63 M 11 MCI +sub-176 ses-M006 49 M 3 AD +sub-177 ses-M006 23 F 2 CN +sub-206 ses-M006 72 M 6 AD +sub-215 ses-M006 85 F 12 CN +sub-216 ses-M006 72 M 14 AD +sub-225 ses-M006 57 F 20 MCI +sub-233 ses-M000 36 M 25 AD +sub-236 ses-M006 52 M 30 MCI +sub-243 ses-M000 71 M 32 CN +sub-244 ses-M000 64 M 31 AD +sub-247 ses-M006 48 F 26 AD +sub-253 ses-M000 52 M 24 AD +sub-255 ses-M006 73 F 21 MCI +sub-263 ses-M000 83 M 16 MCI +sub-267 ses-M006 94 F 10 CN +sub-273 ses-M000 85 M 8 AD +sub-275 ses-M006 75 F 5 CN +sub-276 ses-M006 49 M 3 AD +sub-277 ses-M006 23 F 2 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/validation_baseline.tsv b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/validation_baseline.tsv new file mode 100644 index 000000000..eb9b6ff49 --- /dev/null +++ 
b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-0/validation_baseline.tsv @@ -0,0 +1,49 @@ +participant_id session_id age sex test diagnosis +sub-005 ses-M006 85 F 4 CN +sub-015 ses-M006 85 F 12 CN +sub-017 ses-M006 41 F 15 CN +sub-025 ses-M006 57 F 20 MCI +sub-027 ses-M006 90 F 23 CN +sub-033 ses-M000 36 M 25 AD +sub-043 ses-M000 71 M 32 CN +sub-044 ses-M000 64 M 31 AD +sub-047 ses-M006 48 F 26 AD +sub-053 ses-M000 52 M 24 AD +sub-055 ses-M006 73 F 21 MCI +sub-057 ses-M006 77 F 18 AD +sub-063 ses-M000 83 M 16 MCI +sub-064 ses-M000 86 M 15 AD +sub-066 ses-M006 63 M 11 MCI +sub-073 ses-M000 85 M 8 AD +sub-106 ses-M006 72 M 6 AD +sub-107 ses-M006 41 F 7 MCI +sub-126 ses-M006 60 M 22 MCI +sub-127 ses-M006 90 F 23 CN +sub-135 ses-M006 85 F 28 MCI +sub-137 ses-M006 56 F 31 MCI +sub-143 ses-M000 71 M 32 CN +sub-144 ses-M000 64 M 31 AD +sub-146 ses-M006 73 M 27 MCI +sub-153 ses-M000 52 M 24 AD +sub-154 ses-M000 38 M 23 CN +sub-164 ses-M000 86 M 15 AD +sub-167 ses-M006 94 F 10 CN +sub-173 ses-M000 85 M 8 AD +sub-174 ses-M000 34 M 7 CN +sub-175 ses-M006 75 F 5 CN +sub-205 ses-M006 85 F 4 CN +sub-207 ses-M006 41 F 7 MCI +sub-213 ses-M000 40 M 9 CN +sub-214 ses-M000 56 M 10 CN +sub-217 ses-M006 41 F 15 CN +sub-223 ses-M000 69 M 17 AD +sub-224 ses-M000 41 M 18 CN +sub-234 ses-M000 74 M 26 CN +sub-235 ses-M006 85 F 28 MCI +sub-245 ses-M006 92 F 29 MCI +sub-246 ses-M006 73 M 27 MCI +sub-254 ses-M000 38 M 23 CN +sub-256 ses-M006 72 M 19 CN +sub-264 ses-M000 86 M 15 AD +sub-266 ses-M006 63 M 11 MCI +sub-274 ses-M000 34 M 7 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train.tsv b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train.tsv new file mode 100644 index 000000000..eb9b6ff49 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train.tsv @@ -0,0 +1,49 @@ +participant_id session_id age sex test diagnosis +sub-005 ses-M006 85 F 4 CN +sub-015 ses-M006 85 F 12 CN +sub-017 ses-M006 41 F 15 CN +sub-025 ses-M006 57 F 20 MCI +sub-027 ses-M006 90 F 23 CN +sub-033 ses-M000 36 M 25 AD +sub-043 ses-M000 71 M 32 CN +sub-044 ses-M000 64 M 31 AD +sub-047 ses-M006 48 F 26 AD +sub-053 ses-M000 52 M 24 AD +sub-055 ses-M006 73 F 21 MCI +sub-057 ses-M006 77 F 18 AD +sub-063 ses-M000 83 M 16 MCI +sub-064 ses-M000 86 M 15 AD +sub-066 ses-M006 63 M 11 MCI +sub-073 ses-M000 85 M 8 AD +sub-106 ses-M006 72 M 6 AD +sub-107 ses-M006 41 F 7 MCI +sub-126 ses-M006 60 M 22 MCI +sub-127 ses-M006 90 F 23 CN +sub-135 ses-M006 85 F 28 MCI +sub-137 ses-M006 56 F 31 MCI +sub-143 ses-M000 71 M 32 CN +sub-144 ses-M000 64 M 31 AD +sub-146 ses-M006 73 M 27 MCI +sub-153 ses-M000 52 M 24 AD +sub-154 ses-M000 38 M 23 CN +sub-164 ses-M000 86 M 15 AD +sub-167 ses-M006 94 F 10 CN +sub-173 ses-M000 85 M 8 AD +sub-174 ses-M000 34 M 7 CN +sub-175 ses-M006 75 F 5 CN +sub-205 ses-M006 85 F 4 CN +sub-207 ses-M006 41 F 7 MCI +sub-213 ses-M000 40 M 9 CN +sub-214 ses-M000 56 M 10 CN +sub-217 ses-M006 41 F 15 CN +sub-223 ses-M000 69 M 17 AD +sub-224 ses-M000 41 M 18 CN +sub-234 ses-M000 74 M 26 CN +sub-235 ses-M006 85 F 28 MCI +sub-245 ses-M006 92 F 29 MCI +sub-246 ses-M006 73 M 27 MCI +sub-254 ses-M000 38 M 23 CN +sub-256 ses-M006 72 M 19 CN +sub-264 ses-M000 86 M 15 AD +sub-266 ses-M006 63 M 11 MCI +sub-274 ses-M000 34 M 7 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train_baseline.tsv b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train_baseline.tsv new 
file mode 100644 index 000000000..eb9b6ff49 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/train_baseline.tsv @@ -0,0 +1,49 @@ +participant_id session_id age sex test diagnosis +sub-005 ses-M006 85 F 4 CN +sub-015 ses-M006 85 F 12 CN +sub-017 ses-M006 41 F 15 CN +sub-025 ses-M006 57 F 20 MCI +sub-027 ses-M006 90 F 23 CN +sub-033 ses-M000 36 M 25 AD +sub-043 ses-M000 71 M 32 CN +sub-044 ses-M000 64 M 31 AD +sub-047 ses-M006 48 F 26 AD +sub-053 ses-M000 52 M 24 AD +sub-055 ses-M006 73 F 21 MCI +sub-057 ses-M006 77 F 18 AD +sub-063 ses-M000 83 M 16 MCI +sub-064 ses-M000 86 M 15 AD +sub-066 ses-M006 63 M 11 MCI +sub-073 ses-M000 85 M 8 AD +sub-106 ses-M006 72 M 6 AD +sub-107 ses-M006 41 F 7 MCI +sub-126 ses-M006 60 M 22 MCI +sub-127 ses-M006 90 F 23 CN +sub-135 ses-M006 85 F 28 MCI +sub-137 ses-M006 56 F 31 MCI +sub-143 ses-M000 71 M 32 CN +sub-144 ses-M000 64 M 31 AD +sub-146 ses-M006 73 M 27 MCI +sub-153 ses-M000 52 M 24 AD +sub-154 ses-M000 38 M 23 CN +sub-164 ses-M000 86 M 15 AD +sub-167 ses-M006 94 F 10 CN +sub-173 ses-M000 85 M 8 AD +sub-174 ses-M000 34 M 7 CN +sub-175 ses-M006 75 F 5 CN +sub-205 ses-M006 85 F 4 CN +sub-207 ses-M006 41 F 7 MCI +sub-213 ses-M000 40 M 9 CN +sub-214 ses-M000 56 M 10 CN +sub-217 ses-M006 41 F 15 CN +sub-223 ses-M000 69 M 17 AD +sub-224 ses-M000 41 M 18 CN +sub-234 ses-M000 74 M 26 CN +sub-235 ses-M006 85 F 28 MCI +sub-245 ses-M006 92 F 29 MCI +sub-246 ses-M006 73 M 27 MCI +sub-254 ses-M000 38 M 23 CN +sub-256 ses-M006 72 M 19 CN +sub-264 ses-M000 86 M 15 AD +sub-266 ses-M006 63 M 11 MCI +sub-274 ses-M000 34 M 7 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/validation_baseline.tsv b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/validation_baseline.tsv new file mode 100644 index 000000000..267118404 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/2_fold/split-1/validation_baseline.tsv @@ -0,0 +1,49 @@ +participant_id session_id age sex test diagnosis +sub-003 ses-M000 40 M 1 MCI +sub-004 ses-M000 56 M 2 CN +sub-007 ses-M006 41 F 7 MCI +sub-016 ses-M006 72 M 14 AD +sub-023 ses-M000 69 M 17 AD +sub-036 ses-M006 52 M 30 MCI +sub-045 ses-M006 92 F 29 MCI +sub-046 ses-M006 73 M 27 MCI +sub-054 ses-M000 38 M 23 CN +sub-056 ses-M006 72 M 19 CN +sub-067 ses-M006 94 F 10 CN +sub-074 ses-M000 34 M 7 CN +sub-075 ses-M006 75 F 5 CN +sub-077 ses-M006 23 F 2 CN +sub-103 ses-M000 40 M 1 MCI +sub-104 ses-M000 56 M 2 CN +sub-105 ses-M006 85 F 4 CN +sub-115 ses-M006 85 F 12 CN +sub-116 ses-M006 72 M 14 AD +sub-117 ses-M006 41 F 15 CN +sub-123 ses-M000 69 M 17 AD +sub-124 ses-M000 41 M 18 CN +sub-125 ses-M006 57 F 20 MCI +sub-133 ses-M000 36 M 25 AD +sub-134 ses-M000 74 M 26 CN +sub-145 ses-M006 92 F 29 MCI +sub-147 ses-M006 48 F 26 AD +sub-155 ses-M006 73 F 21 MCI +sub-166 ses-M006 63 M 11 MCI +sub-176 ses-M006 49 M 3 AD +sub-177 ses-M006 23 F 2 CN +sub-206 ses-M006 72 M 6 AD +sub-215 ses-M006 85 F 12 CN +sub-216 ses-M006 72 M 14 AD +sub-225 ses-M006 57 F 20 MCI +sub-233 ses-M000 36 M 25 AD +sub-236 ses-M006 52 M 30 MCI +sub-243 ses-M000 71 M 32 CN +sub-244 ses-M000 64 M 31 AD +sub-247 ses-M006 48 F 26 AD +sub-253 ses-M000 52 M 24 AD +sub-255 ses-M006 73 F 21 MCI +sub-263 ses-M000 83 M 16 MCI +sub-267 ses-M006 94 F 10 CN +sub-273 ses-M000 85 M 8 AD +sub-275 ses-M006 75 F 5 CN +sub-276 ses-M006 49 M 3 AD +sub-277 ses-M006 23 F 2 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/single_split_config.json 
b/tests/unittests/ressources/caps_example/split_test/split/single_split_config.json new file mode 100644 index 000000000..a21941ccb --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/single_split_config.json @@ -0,0 +1,15 @@ +{ + "json_name": "single_split_config.json", + "split_dir": "../ressources/caps_example/split_test/split", + "subset_name": "test", + "stratification": [ + "age", + "sex", + "test", + "diagnosis" + ], + "valid_longitudinal": false, + "n_test": 24, + "p_categorical_threshold": 0.5, + "p_continuous_threshold": 0.5 +} \ No newline at end of file diff --git a/tests/unittests/ressources/caps_example/split_test/split/split_categorical_stats.tsv b/tests/unittests/ressources/caps_example/split_test/split/split_categorical_stats.tsv new file mode 100644 index 000000000..2fbf24954 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/split_categorical_stats.tsv @@ -0,0 +1,11 @@ +label value statistic train test +sex M proportion 0.59375 0.625 +sex M count 57.0 15.0 +sex F proportion 0.40625 0.375 +sex F count 39.0 9.0 +diagnosis MCI proportion 0.2916666666666667 0.3333333333333333 +diagnosis MCI count 28.0 8.0 +diagnosis CN proportion 0.40625 0.375 +diagnosis CN count 39.0 9.0 +diagnosis AD proportion 0.3020833333333333 0.2916666666666667 +diagnosis AD count 29.0 7.0 diff --git a/tests/unittests/ressources/caps_example/split_test/split/split_continuous_stats.tsv b/tests/unittests/ressources/caps_example/split_test/split/split_continuous_stats.tsv new file mode 100644 index 000000000..6bb4a628d --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/split_continuous_stats.tsv @@ -0,0 +1,5 @@ +label statistic train test +age mean 63.416666666666664 61.083333333333336 +test mean 16.552083333333332 16.291666666666668 +age std 19.392709789569277 14.634434279010128 +test std 9.367210213984869 9.129403841037684 diff --git a/tests/unittests/ressources/caps_example/split_test/split/test_baseline.tsv b/tests/unittests/ressources/caps_example/split_test/split/test_baseline.tsv new file mode 100644 index 000000000..7370eeaa4 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/test_baseline.tsv @@ -0,0 +1,25 @@ +participant_id session_id age sex test diagnosis +sub-163 ses-M000 83 M 16 MCI +sub-227 ses-M006 90 F 23 CN +sub-006 ses-M006 72 M 6 AD +sub-014 ses-M000 56 M 10 CN +sub-034 ses-M000 74 M 26 CN +sub-065 ses-M006 58 F 13 AD +sub-136 ses-M006 52 M 30 MCI +sub-165 ses-M006 58 F 13 AD +sub-037 ses-M006 56 F 31 MCI +sub-226 ses-M006 60 M 22 MCI +sub-113 ses-M000 40 M 9 CN +sub-076 ses-M006 49 M 3 AD +sub-157 ses-M006 77 F 18 AD +sub-237 ses-M006 56 F 31 MCI +sub-024 ses-M000 41 M 18 CN +sub-035 ses-M006 85 F 28 MCI +sub-026 ses-M006 60 M 22 MCI +sub-204 ses-M000 56 M 2 CN +sub-257 ses-M006 77 F 18 AD +sub-114 ses-M000 56 M 10 CN +sub-156 ses-M006 72 M 19 CN +sub-203 ses-M000 40 M 1 MCI +sub-265 ses-M006 58 F 13 AD +sub-013 ses-M000 40 M 9 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/train.tsv b/tests/unittests/ressources/caps_example/split_test/split/train.tsv new file mode 100644 index 000000000..ac26a8982 --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/train.tsv @@ -0,0 +1,155 @@ +participant_id session_id age sex test diagnosis +sub-275 ses-M006 75 F 5 CN +sub-275 ses-M018 96 M 4 MCI +sub-036 ses-M006 52 M 30 MCI +sub-146 ses-M006 73 M 27 MCI +sub-243 ses-M000 71 M 32 CN +sub-155 ses-M006 73 F 21 MCI +sub-155 ses-M018 48 M 20 CN 
+sub-164 ses-M000 86 M 15 AD +sub-164 ses-M054 91 F 14 MCI +sub-016 ses-M006 72 M 14 AD +sub-247 ses-M006 48 F 26 AD +sub-247 ses-M036 59 M 25 CN +sub-064 ses-M000 86 M 15 AD +sub-064 ses-M054 91 F 14 MCI +sub-206 ses-M006 72 M 6 AD +sub-053 ses-M000 52 M 24 AD +sub-043 ses-M000 71 M 32 CN +sub-007 ses-M006 41 F 7 MCI +sub-007 ses-M036 59 M 8 CN +sub-167 ses-M006 94 F 10 CN +sub-167 ses-M036 68 M 9 MCI +sub-103 ses-M000 40 M 1 MCI +sub-057 ses-M006 77 F 18 AD +sub-057 ses-M036 99 M 17 MCI +sub-177 ses-M006 23 F 2 CN +sub-177 ses-M036 45 M 1 MCI +sub-117 ses-M006 41 F 15 CN +sub-117 ses-M036 59 M 16 CN +sub-224 ses-M000 41 M 18 CN +sub-224 ses-M054 63 F 19 AD +sub-133 ses-M000 36 M 25 AD +sub-066 ses-M006 63 M 11 MCI +sub-124 ses-M000 41 M 18 CN +sub-124 ses-M054 63 F 19 AD +sub-235 ses-M006 85 F 28 MCI +sub-235 ses-M018 57 M 29 CN +sub-236 ses-M006 52 M 30 MCI +sub-253 ses-M000 52 M 24 AD +sub-143 ses-M000 71 M 32 CN +sub-017 ses-M006 41 F 15 CN +sub-017 ses-M036 59 M 16 CN +sub-213 ses-M000 40 M 9 CN +sub-154 ses-M000 38 M 23 CN +sub-154 ses-M054 47 F 22 MCI +sub-135 ses-M006 85 F 28 MCI +sub-135 ses-M018 57 M 29 CN +sub-223 ses-M000 69 M 17 AD +sub-045 ses-M006 92 F 29 MCI +sub-045 ses-M018 61 M 28 MCI +sub-123 ses-M000 69 M 17 AD +sub-125 ses-M006 57 F 20 MCI +sub-125 ses-M018 75 M 21 AD +sub-044 ses-M000 64 M 31 AD +sub-044 ses-M054 23 F 30 MCI +sub-054 ses-M000 38 M 23 CN +sub-054 ses-M054 47 F 22 MCI +sub-003 ses-M000 40 M 1 MCI +sub-126 ses-M006 60 M 22 MCI +sub-215 ses-M006 85 F 12 CN +sub-215 ses-M018 64 M 13 CN +sub-005 ses-M006 85 F 4 CN +sub-005 ses-M018 64 M 5 AD +sub-153 ses-M000 52 M 24 AD +sub-273 ses-M000 85 M 8 AD +sub-115 ses-M006 85 F 12 CN +sub-115 ses-M018 64 M 13 CN +sub-174 ses-M000 34 M 7 CN +sub-174 ses-M054 98 F 6 MCI +sub-075 ses-M006 75 F 5 CN +sub-075 ses-M018 96 M 4 MCI +sub-144 ses-M000 64 M 31 AD +sub-144 ses-M054 23 F 30 MCI +sub-233 ses-M000 36 M 25 AD +sub-077 ses-M006 23 F 2 CN +sub-077 ses-M036 45 M 1 MCI +sub-116 ses-M006 72 M 14 AD +sub-254 ses-M000 38 M 23 CN +sub-254 ses-M054 47 F 22 MCI +sub-246 ses-M006 73 M 27 MCI +sub-276 ses-M006 49 M 3 AD +sub-015 ses-M006 85 F 12 CN +sub-015 ses-M018 64 M 13 CN +sub-217 ses-M006 41 F 15 CN +sub-217 ses-M036 59 M 16 CN +sub-074 ses-M000 34 M 7 CN +sub-074 ses-M054 98 F 6 MCI +sub-137 ses-M006 56 F 31 MCI +sub-137 ses-M036 65 M 32 MCI +sub-004 ses-M000 56 M 2 CN +sub-004 ses-M054 75 F 3 MCI +sub-055 ses-M006 73 F 21 MCI +sub-055 ses-M018 48 M 20 CN +sub-176 ses-M006 49 M 3 AD +sub-234 ses-M000 74 M 26 CN +sub-234 ses-M054 72 F 27 CN +sub-267 ses-M006 94 F 10 CN +sub-267 ses-M036 68 M 9 MCI +sub-166 ses-M006 63 M 11 MCI +sub-244 ses-M000 64 M 31 AD +sub-244 ses-M054 23 F 30 MCI +sub-175 ses-M006 75 F 5 CN +sub-175 ses-M018 96 M 4 MCI +sub-127 ses-M006 90 F 23 CN +sub-127 ses-M036 59 M 24 CN +sub-214 ses-M000 56 M 10 CN +sub-214 ses-M054 75 F 11 AD +sub-173 ses-M000 85 M 8 AD +sub-255 ses-M006 73 F 21 MCI +sub-255 ses-M018 48 M 20 CN +sub-046 ses-M006 73 M 27 MCI +sub-264 ses-M000 86 M 15 AD +sub-264 ses-M054 91 F 14 MCI +sub-274 ses-M000 34 M 7 CN +sub-274 ses-M054 98 F 6 MCI +sub-256 ses-M006 72 M 19 CN +sub-205 ses-M006 85 F 4 CN +sub-205 ses-M018 64 M 5 AD +sub-145 ses-M006 92 F 29 MCI +sub-145 ses-M018 61 M 28 MCI +sub-104 ses-M000 56 M 2 CN +sub-104 ses-M054 75 F 3 MCI +sub-056 ses-M006 72 M 19 CN +sub-106 ses-M006 72 M 6 AD +sub-033 ses-M000 36 M 25 AD +sub-225 ses-M006 57 F 20 MCI +sub-225 ses-M018 75 M 21 AD +sub-263 ses-M000 83 M 16 MCI +sub-216 ses-M006 72 M 14 AD +sub-245 ses-M006 92 F 29 MCI 
+sub-245 ses-M018 61 M 28 MCI +sub-073 ses-M000 85 M 8 AD +sub-277 ses-M006 23 F 2 CN +sub-277 ses-M036 45 M 1 MCI +sub-107 ses-M006 41 F 7 MCI +sub-107 ses-M036 59 M 8 CN +sub-067 ses-M006 94 F 10 CN +sub-067 ses-M036 68 M 9 MCI +sub-266 ses-M006 63 M 11 MCI +sub-025 ses-M006 57 F 20 MCI +sub-025 ses-M018 75 M 21 AD +sub-023 ses-M000 69 M 17 AD +sub-063 ses-M000 83 M 16 MCI +sub-047 ses-M006 48 F 26 AD +sub-047 ses-M036 59 M 25 CN +sub-105 ses-M006 85 F 4 CN +sub-105 ses-M018 64 M 5 AD +sub-027 ses-M006 90 F 23 CN +sub-027 ses-M036 59 M 24 CN +sub-134 ses-M000 74 M 26 CN +sub-134 ses-M054 72 F 27 CN +sub-207 ses-M006 41 F 7 MCI +sub-207 ses-M036 59 M 8 CN +sub-147 ses-M006 48 F 26 AD +sub-147 ses-M036 59 M 25 CN diff --git a/tests/unittests/ressources/caps_example/split_test/split/train_baseline.tsv b/tests/unittests/ressources/caps_example/split_test/split/train_baseline.tsv new file mode 100644 index 000000000..9188b346b --- /dev/null +++ b/tests/unittests/ressources/caps_example/split_test/split/train_baseline.tsv @@ -0,0 +1,97 @@ +participant_id session_id age sex test diagnosis +sub-275 ses-M006 75 F 5 CN +sub-036 ses-M006 52 M 30 MCI +sub-146 ses-M006 73 M 27 MCI +sub-243 ses-M000 71 M 32 CN +sub-155 ses-M006 73 F 21 MCI +sub-164 ses-M000 86 M 15 AD +sub-016 ses-M006 72 M 14 AD +sub-247 ses-M006 48 F 26 AD +sub-064 ses-M000 86 M 15 AD +sub-206 ses-M006 72 M 6 AD +sub-053 ses-M000 52 M 24 AD +sub-043 ses-M000 71 M 32 CN +sub-007 ses-M006 41 F 7 MCI +sub-167 ses-M006 94 F 10 CN +sub-103 ses-M000 40 M 1 MCI +sub-057 ses-M006 77 F 18 AD +sub-177 ses-M006 23 F 2 CN +sub-117 ses-M006 41 F 15 CN +sub-224 ses-M000 41 M 18 CN +sub-133 ses-M000 36 M 25 AD +sub-066 ses-M006 63 M 11 MCI +sub-124 ses-M000 41 M 18 CN +sub-235 ses-M006 85 F 28 MCI +sub-236 ses-M006 52 M 30 MCI +sub-253 ses-M000 52 M 24 AD +sub-143 ses-M000 71 M 32 CN +sub-017 ses-M006 41 F 15 CN +sub-213 ses-M000 40 M 9 CN +sub-154 ses-M000 38 M 23 CN +sub-135 ses-M006 85 F 28 MCI +sub-223 ses-M000 69 M 17 AD +sub-045 ses-M006 92 F 29 MCI +sub-123 ses-M000 69 M 17 AD +sub-125 ses-M006 57 F 20 MCI +sub-044 ses-M000 64 M 31 AD +sub-054 ses-M000 38 M 23 CN +sub-003 ses-M000 40 M 1 MCI +sub-126 ses-M006 60 M 22 MCI +sub-215 ses-M006 85 F 12 CN +sub-005 ses-M006 85 F 4 CN +sub-153 ses-M000 52 M 24 AD +sub-273 ses-M000 85 M 8 AD +sub-115 ses-M006 85 F 12 CN +sub-174 ses-M000 34 M 7 CN +sub-075 ses-M006 75 F 5 CN +sub-144 ses-M000 64 M 31 AD +sub-233 ses-M000 36 M 25 AD +sub-077 ses-M006 23 F 2 CN +sub-116 ses-M006 72 M 14 AD +sub-254 ses-M000 38 M 23 CN +sub-246 ses-M006 73 M 27 MCI +sub-276 ses-M006 49 M 3 AD +sub-015 ses-M006 85 F 12 CN +sub-217 ses-M006 41 F 15 CN +sub-074 ses-M000 34 M 7 CN +sub-137 ses-M006 56 F 31 MCI +sub-004 ses-M000 56 M 2 CN +sub-055 ses-M006 73 F 21 MCI +sub-176 ses-M006 49 M 3 AD +sub-234 ses-M000 74 M 26 CN +sub-267 ses-M006 94 F 10 CN +sub-166 ses-M006 63 M 11 MCI +sub-244 ses-M000 64 M 31 AD +sub-175 ses-M006 75 F 5 CN +sub-127 ses-M006 90 F 23 CN +sub-214 ses-M000 56 M 10 CN +sub-173 ses-M000 85 M 8 AD +sub-255 ses-M006 73 F 21 MCI +sub-046 ses-M006 73 M 27 MCI +sub-264 ses-M000 86 M 15 AD +sub-274 ses-M000 34 M 7 CN +sub-256 ses-M006 72 M 19 CN +sub-205 ses-M006 85 F 4 CN +sub-145 ses-M006 92 F 29 MCI +sub-104 ses-M000 56 M 2 CN +sub-056 ses-M006 72 M 19 CN +sub-106 ses-M006 72 M 6 AD +sub-033 ses-M000 36 M 25 AD +sub-225 ses-M006 57 F 20 MCI +sub-263 ses-M000 83 M 16 MCI +sub-216 ses-M006 72 M 14 AD +sub-245 ses-M006 92 F 29 MCI +sub-073 ses-M000 85 M 8 AD +sub-277 ses-M006 23 F 2 CN +sub-107 ses-M006 
41 F 7 MCI +sub-067 ses-M006 94 F 10 CN +sub-266 ses-M006 63 M 11 MCI +sub-025 ses-M006 57 F 20 MCI +sub-023 ses-M000 69 M 17 AD +sub-063 ses-M000 83 M 16 MCI +sub-047 ses-M006 48 F 26 AD +sub-105 ses-M006 85 F 4 CN +sub-027 ses-M006 90 F 23 CN +sub-134 ses-M000 74 M 26 CN +sub-207 ses-M006 41 F 7 MCI +sub-147 ses-M006 48 F 26 AD diff --git a/tests/unittests/ressources/caps_example/subjects_false.tsv b/tests/unittests/ressources/caps_example/subjects_false.tsv new file mode 100644 index 000000000..136a18edc --- /dev/null +++ b/tests/unittests/ressources/caps_example/subjects_false.tsv @@ -0,0 +1,2 @@ +sub ses age sex test diagnosis +sub-003 ses-M000 40 M 1 MCI \ No newline at end of file diff --git a/tests/unittests/ressources/caps_example/subjects_sessions_list.tsv b/tests/unittests/ressources/caps_example/subjects_sessions_list.tsv index cd3c1d0b2..e505e36c3 100644 --- a/tests/unittests/ressources/caps_example/subjects_sessions_list.tsv +++ b/tests/unittests/ressources/caps_example/subjects_sessions_list.tsv @@ -3,11 +3,3 @@ sub-000 ses-M000 sub-000 ses-M006 sub-001 ses-M000 sub-001 ses-M018 -sub-OAS30010 ses-M000 -sub-OAS30011 ses-M000 -sub-OAS30011 ses-M054 -sub-OAS30012 ses-M006 -sub-OAS30012 ses-M018 -sub-OAS30013 ses-M006 -sub-OAS30014 ses-M006 -sub-OAS30014 ses-M036 diff --git a/tests/unittests/ressources/caps_example/subjects_t1.tsv b/tests/unittests/ressources/caps_example/subjects_t1.tsv new file mode 100644 index 000000000..60dfb62dd --- /dev/null +++ b/tests/unittests/ressources/caps_example/subjects_t1.tsv @@ -0,0 +1,65 @@ +participant_id session_id age sex test diagnosis +sub-003 ses-M000 40 M 1 MCI +sub-004 ses-M000 56 M 2 CN +sub-004 ses-M054 75 F 3 MCI +sub-005 ses-M006 85 F 4 CN +sub-005 ses-M018 64 M 5 AD +sub-006 ses-M006 72 M 6 AD +sub-007 ses-M006 41 F 7 MCI +sub-007 ses-M036 59 M 8 CN +sub-013 ses-M000 40 M 9 CN +sub-014 ses-M000 56 M 10 CN +sub-014 ses-M054 75 F 11 AD +sub-015 ses-M006 85 F 12 CN +sub-015 ses-M018 64 M 13 CN +sub-016 ses-M006 72 M 14 AD +sub-017 ses-M006 41 F 15 CN +sub-017 ses-M036 59 M 16 CN +sub-023 ses-M000 69 M 17 AD +sub-024 ses-M000 41 M 18 CN +sub-024 ses-M054 63 F 19 AD +sub-025 ses-M006 57 F 20 MCI +sub-025 ses-M018 75 M 21 AD +sub-026 ses-M006 60 M 22 MCI +sub-027 ses-M006 90 F 23 CN +sub-027 ses-M036 59 M 24 CN +sub-033 ses-M000 36 M 25 AD +sub-034 ses-M000 74 M 26 CN +sub-034 ses-M054 72 F 27 CN +sub-035 ses-M006 85 F 28 MCI +sub-035 ses-M018 57 M 29 CN +sub-036 ses-M006 52 M 30 MCI +sub-037 ses-M006 56 F 31 MCI +sub-037 ses-M036 65 M 32 MCI +sub-043 ses-M000 71 M 32 CN +sub-044 ses-M000 64 M 31 AD +sub-044 ses-M054 23 F 30 MCI +sub-045 ses-M006 92 F 29 MCI +sub-045 ses-M018 61 M 28 MCI +sub-046 ses-M006 73 M 27 MCI +sub-047 ses-M006 48 F 26 AD +sub-047 ses-M036 59 M 25 CN +sub-053 ses-M000 52 M 24 AD +sub-054 ses-M000 38 M 23 CN +sub-054 ses-M054 47 F 22 MCI +sub-055 ses-M006 73 F 21 MCI +sub-055 ses-M018 48 M 20 CN +sub-056 ses-M006 72 M 19 CN +sub-057 ses-M006 77 F 18 AD +sub-057 ses-M036 99 M 17 MCI +sub-063 ses-M000 83 M 16 MCI +sub-064 ses-M000 86 M 15 AD +sub-064 ses-M054 91 F 14 MCI +sub-065 ses-M006 58 F 13 AD +sub-065 ses-M018 74 M 12 MCI +sub-066 ses-M006 63 M 11 MCI +sub-067 ses-M006 94 F 10 CN +sub-067 ses-M036 68 M 9 MCI +sub-073 ses-M000 85 M 8 AD +sub-074 ses-M000 34 M 7 CN +sub-074 ses-M054 98 F 6 MCI +sub-075 ses-M006 75 F 5 CN +sub-075 ses-M018 96 M 4 MCI +sub-076 ses-M006 49 M 3 AD +sub-077 ses-M006 23 F 2 CN +sub-077 ses-M036 45 M 1 MCI diff --git a/tests/unittests/splitter/test_make_split.py 
b/tests/unittests/splitter/test_make_split.py
new file mode 100644
index 000000000..3216f097f
--- /dev/null
+++ b/tests/unittests/splitter/test_make_split.py
@@ -0,0 +1,195 @@
+import json
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pytest
+from pydantic import ValidationError
+
+from clinicadl.splitter.make_splits import make_kfold, make_split
+from clinicadl.utils.exceptions import (
+    ClinicaDLConfigurationError,
+    ClinicaDLTSVError,
+)
+
+
+def remove_non_empty_dir(dir_path: Path):
+    """
+    Remove a non-empty directory using only pathlib.
+
+    Parameters
+    ----------
+    dir_path : Path
+        Path to the directory to remove.
+    """
+    if dir_path.exists() and dir_path.is_dir():
+        for item in dir_path.iterdir():  # Iterate through directory contents
+            if item.is_dir():
+                remove_non_empty_dir(item)  # Recursively remove subdirectories
+            else:
+                item.unlink()  # Remove files
+        dir_path.rmdir()  # Remove the now-empty directory
+    else:
+        print(f"{dir_path} does not exist or is not a directory.")
+
+
+caps_dir = Path(__file__).parents[1] / "ressources" / "caps_example"
+
+sub_ses_t1 = caps_dir / "subjects_t1.tsv"
+sub_ses_df = pd.read_csv(sub_ses_t1, sep="\t")
+
+split_dir = caps_dir / "split"
+train_path = split_dir / "train.tsv"
+
+
+def test_good_split():
+    n_test = 15
+    stratification = ["age", "sex", "test", "diagnosis"]
+    subset_name = "test_test"
+
+    split_dir = make_split(
+        sub_ses_t1,
+        output_dir=caps_dir / "test",
+        subset_name=subset_name,
+        stratification=stratification,
+        n_test=n_test,
+    )
+
+    train_path = split_dir / "train_baseline.tsv"
+    test_path = split_dir / f"{subset_name}_baseline.tsv"
+
+    assert train_path.exists()
+    assert test_path.exists()
+
+    assert (split_dir / "single_split_config.json").is_file()
+    with (split_dir / "single_split_config.json").open(mode="r") as file:
+        dict_ = json.load(file)
+
+    assert dict_["json_name"] == "single_split_config.json"
+    assert dict_["split_dir"] == str(split_dir)
+    assert dict_["subset_name"] == subset_name
+    assert dict_["stratification"] == stratification
+    assert dict_["valid_longitudinal"] is False
+    assert dict_["n_test"] == n_test
+    assert np.isclose(dict_["p_categorical_threshold"], 0.5, rtol=1e-09, atol=1e-09)
+    assert np.isclose(dict_["p_continuous_threshold"], 0.5, rtol=1e-09, atol=1e-09)
+
+    train_df = pd.read_csv(train_path, sep="\t")
+    test_df = pd.read_csv(test_path, sep="\t")
+
+    assert len(test_df) == 15
+    assert set(stratification).issubset(set(test_df.columns))
+
+    assert (split_dir / "split_continuous_stats.tsv").is_file()
+    assert (split_dir / "split_categorical_stats.tsv").is_file()
+
+    split_dir_bis = make_split(sub_ses_t1, n_test=n_test)
+
+    assert split_dir_bis == sub_ses_t1.parent / "split"
+
+    split_dir_bis_bis = make_split(sub_ses_t1, n_test=n_test, stratification=False)
+
+    assert split_dir_bis_bis == sub_ses_t1.parent / "split_2"
+
+    remove_non_empty_dir(split_dir)
+    remove_non_empty_dir(split_dir_bis)
+    remove_non_empty_dir(split_dir_bis_bis)
+
+
+def test_bad_split():
+    with pytest.raises(ClinicaDLTSVError):
+        make_split(caps_dir / "test.tsv", n_test=15)
+
+    with pytest.raises(ClinicaDLTSVError):
+        make_split(caps_dir / "subjects_false.tsv", n_test=2)
+
+    with pytest.raises(ValueError):
+        make_split(sub_ses_t1, p_categorical_threshold=12, n_test=2)
+
+    with pytest.raises(ValueError):
+        make_split(sub_ses_t1, n_test=100)
+
+    split_dir = sub_ses_t1.parent / "split"
+    remove_non_empty_dir(split_dir)
+
+
+def test_good_kfold():
+    n_split = 2
+    stratification = "sex"
+    subset_name = "test_test"
+
+    split_dir = make_kfold(
+        sub_ses_t1,
+        output_dir=caps_dir / "test",
+        subset_name=subset_name,
+        stratification=stratification,
+        n_splits=n_split,
+    )
+
+    train_path = split_dir / "split-0" / "train_baseline.tsv"
+    test_path = split_dir / "split-0" / f"{subset_name}_baseline.tsv"
+
+    assert train_path.exists()
+    assert test_path.exists()
+
+    assert (split_dir / "kfold_config.json").is_file()
+    with (split_dir / "kfold_config.json").open(mode="r") as file:
+        dict_ = json.load(file)
+
+    assert dict_["json_name"] == "kfold_config.json"
+    assert dict_["split_dir"] == str(split_dir)
+    assert dict_["subset_name"] == subset_name
+    assert dict_["stratification"] == stratification
+    assert dict_["valid_longitudinal"] is False
+    assert dict_["n_splits"] == n_split
+
+    test_df = pd.read_csv(test_path, sep="\t")
+
+    assert len(test_df) == 20
+    assert set([stratification]).issubset(set(test_df.columns))
+
+    split_dir_bis = make_kfold(
+        sub_ses_t1, n_splits=n_split, stratification=stratification
+    )
+
+    assert split_dir_bis == sub_ses_t1.parent / "2_fold"
+    split_dir_bis_bis = make_kfold(sub_ses_t1, n_splits=n_split, stratification=False)
+
+    assert split_dir_bis_bis == sub_ses_t1.parent / "2_fold_2"
+
+    remove_non_empty_dir(split_dir)
+    remove_non_empty_dir(split_dir_bis)
+    remove_non_empty_dir(split_dir_bis_bis)
+
+
+def test_bad_kfold():
+    with pytest.raises(ClinicaDLTSVError):
+        make_kfold(caps_dir / "test.tsv", output_dir=caps_dir / "test_kfold")
+
+    with pytest.raises(ClinicaDLTSVError):
+        make_kfold(
+            caps_dir / "subjects_false.tsv",
+            n_splits=1,
+            output_dir=caps_dir / "test_kfold",
+        )
+
+    with pytest.raises(ValueError):
+        make_kfold(
+            sub_ses_t1,
+            stratification="age",
+            output_dir=caps_dir / "test_kfold",
+        )
+
+    with pytest.raises(ValidationError):
+        make_kfold(
+            sub_ses_t1,
+            stratification=["sex", "age"],
+            output_dir=caps_dir / "test_kfold",
+        )  # type: ignore
+
+    with pytest.raises(ClinicaDLConfigurationError):
+        make_kfold(
+            sub_ses_t1, stratification="column", output_dir=caps_dir / "test_kfold"
+        )
+
+    remove_non_empty_dir(caps_dir / "test_kfold")
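These tests exercise the write side of the new splitter API: make_split holds out a test subset from a subjects/sessions TSV, and make_kfold cuts a training list into folds, each call writing baseline TSVs plus a JSON config into its output directory. The intended call pattern is roughly the following minimal sketch (the TSV path and stratification columns are placeholders modelled on the fixtures above, not a prescribed setup):

    from pathlib import Path

    from clinicadl.splitter.make_splits import make_kfold, make_split

    data_tsv = Path("subjects_t1.tsv")  # placeholder; needs participant_id/session_id columns

    # Hold out 15 subjects, stratifying on demographic columns; writes
    # train_baseline.tsv, test_baseline.tsv and single_split_config.json.
    split_dir = make_split(data_tsv, n_test=15, stratification=["age", "sex", "diagnosis"])

    # Cut the remaining training subjects into 2 stratified folds; writes
    # split-<i>/train_baseline.tsv and split-<i>/validation_baseline.tsv.
    fold_dir = make_kfold(split_dir / "train.tsv", n_splits=2, stratification="sex")

When no output_dir is given, make_split defaults to a split/ folder (then split_2/, and so on) next to the input TSV, and make_kfold to <n_splits>_fold/, which is what the assertions on split_dir_bis and split_dir_bis_bis check.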
"test_test" + + split_dir = make_kfold( + sub_ses_t1, + output_dir=caps_dir / "test", + subset_name=subset_name, + stratification=stratification, + n_splits=n_split, + ) + + train_path = split_dir / "split-0" / "train_baseline.tsv" + test_path = split_dir / "split-0" / f"{subset_name}_baseline.tsv" + + assert train_path.exists() + assert test_path.exists() + + assert (split_dir / "kfold_config.json").is_file + with (split_dir / "kfold_config.json").open(mode="r") as file: + dict_ = json.load(file) + + assert dict_["json_name"] == "kfold_config.json" + assert dict_["split_dir"] == str(split_dir) + assert dict_["subset_name"] == subset_name + assert dict_["stratification"] == stratification + assert dict_["valid_longitudinal"] is False + assert dict_["n_splits"] == n_split + + test_df = pd.read_csv(test_path, sep="\t") + + assert len(test_df) == 20 + assert set([stratification]).issubset(set(test_df.columns)) + + split_dir_bis = make_kfold( + sub_ses_t1, n_splits=n_split, stratification=stratification + ) + + assert split_dir_bis == sub_ses_t1.parent / "2_fold" + split_dir_bis_bis = make_kfold(sub_ses_t1, n_splits=n_split, stratification=False) + + assert split_dir_bis_bis == sub_ses_t1.parent / "2_fold_2" + + remove_non_empty_dir(split_dir) + remove_non_empty_dir(split_dir_bis) + remove_non_empty_dir(split_dir_bis_bis) + + +def test_bad_kfold(): + with pytest.raises(ClinicaDLTSVError): + make_kfold(caps_dir / "test.tsv", output_dir=caps_dir / "test_kfold") + + with pytest.raises(ClinicaDLTSVError): + make_kfold( + caps_dir / "subject_false.tsv", + n_splits=1, + output_dir=caps_dir / "test_kfold", + ) + + with pytest.raises(ValueError): + make_kfold( + sub_ses_t1, + stratification="age", + output_dir=caps_dir / "test_kfold", + ) + + with pytest.raises(ValidationError): + make_kfold( + sub_ses_t1, + stratification=["sex", "age"], + output_dir=caps_dir / "test_kfold", + ) # type: ignore + + with pytest.raises(ClinicaDLConfigurationError): + make_kfold( + sub_ses_t1, stratification="column", output_dir=caps_dir / "test_kfold" + ) + + remove_non_empty_dir(caps_dir / "test_kfold") diff --git a/tests/unittests/splitter/test_splitter.py b/tests/unittests/splitter/test_splitter.py new file mode 100644 index 000000000..8cf260bfd --- /dev/null +++ b/tests/unittests/splitter/test_splitter.py @@ -0,0 +1,105 @@ +import json +from pathlib import Path + +import nibabel as nib +import numpy as np +import pandas as pd +import pytest +from pydantic import ValidationError + +from clinicadl.dataset.datasets.caps_dataset import CapsDataset +from clinicadl.dataset.preprocessing import PreprocessingT1, PreprocessingT2 +from clinicadl.splitter.split import Split +from clinicadl.splitter.splitter.kfold import KFold, KFoldConfig +from clinicadl.splitter.splitter.single_split import SingleSplit, SingleSplitConfig +from clinicadl.splitter.splitter.splitter import ( + Splitter, + SplitterConfig, + SubjectsSessionsSplit, +) +from clinicadl.transforms import Transforms +from clinicadl.utils.enum import Preprocessing +from clinicadl.utils.exceptions import ( + ClinicaDLArgumentError, + ClinicaDLCAPSError, + ClinicaDLConfigurationError, + ClinicaDLTSVError, +) + +caps_dir = Path(__file__).parents[1] / "ressources" / "caps_example" +split_dir = caps_dir / "split_test" / "split" +fold_path = split_dir / "2_fold" + + +def test_single_splitter(): + config = SingleSplitConfig(split_dir=split_dir) + + assert config.subset_name == "test" + assert config.stratification is False + assert config.valid_longitudinal is False + assert 
diff --git a/tests/unittests/train/trainer/test_training_config.py b/tests/unittests/train/trainer/test_training_config.py
index e0bdef815..07b07fd8f 100644
--- a/tests/unittests/train/trainer/test_training_config.py
+++ b/tests/unittests/train/trainer/test_training_config.py
@@ -158,8 +158,6 @@ def good_inputs(dummy_arguments):
 #     assert c.transforms.size_reduction_factor == 5
 #     assert c.split.split == (0,)
 #     assert c.early_stopping.min_delta == 0.0
-
-
 # Test config manipulation
 # def test_assignment(dummy_arguments, training_config):
 #     c = training_config(**dummy_arguments)

From 4d40ba6c3112e5263ea976778818861ef301045f Mon Sep 17 00:00:00 2001
From: camillebrianceau <57992134+camillebrianceau@users.noreply.github.com>
Date: Wed, 18 Dec 2024 13:24:52 +0100
Subject: [PATCH 2/3] first try for a dictionary with examples for the transforms module (#688)

---
 clinicadl/dictionary/__init__.py         |  0
 clinicadl/dictionary/suffixes.py         |  7 ++++++
 clinicadl/dictionary/utils.py            |  1 +
 clinicadl/dictionary/words.py            | 32 ++++++++++++++++++++++++
 clinicadl/transforms/extraction/base.py  |  9 ++++---
 clinicadl/transforms/extraction/image.py |  3 +--
 clinicadl/transforms/extraction/patch.py |  3 +--
 clinicadl/transforms/extraction/slice.py |  3 +--
 clinicadl/transforms/transforms.py       | 13 +++++-----
 clinicadl/transforms/utils.py            |  6 +++--
 10 files changed, 59 insertions(+), 18 deletions(-)
 create mode 100644 clinicadl/dictionary/__init__.py
 create mode 100644 clinicadl/dictionary/suffixes.py
 create mode 100644 clinicadl/dictionary/utils.py
 create mode 100644 clinicadl/dictionary/words.py

diff --git a/clinicadl/dictionary/__init__.py
b/clinicadl/dictionary/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/clinicadl/dictionary/suffixes.py b/clinicadl/dictionary/suffixes.py new file mode 100644 index 000000000..77e7da05a --- /dev/null +++ b/clinicadl/dictionary/suffixes.py @@ -0,0 +1,7 @@ +TSV = ".tsv" +JSON = ".json" +LOG = ".log" +PT = ".pt" +NII = ".nii" +GZ = ".gz" +NII_GZ = NII + GZ diff --git a/clinicadl/dictionary/utils.py b/clinicadl/dictionary/utils.py new file mode 100644 index 000000000..54caa46d3 --- /dev/null +++ b/clinicadl/dictionary/utils.py @@ -0,0 +1 @@ +SEP = "\t" diff --git a/clinicadl/dictionary/words.py b/clinicadl/dictionary/words.py new file mode 100644 index 000000000..9c287c677 --- /dev/null +++ b/clinicadl/dictionary/words.py @@ -0,0 +1,32 @@ +AGE = "age" +AUGMENTATION = "augmentation" +BASELINE = "baseline" +BEST = "best" +CONFIG = "config" +COUNT = "count" +CUDA = "cuda" +DESCRIPTION = "description" +FOLD = "fold" +ID = "id" +IMAGE = "image" +KFOLD = "k" + FOLD +LABEL = "label" +MEAN = "mean" +OBJECT = "object" +PARTICIPANT = "participant" +PARTICIPANT_ID = PARTICIPANT + "_" + ID +PROPORTION = "proportion" +SAMPLE = "sample" +SEX = "sex" +SESSION = "session" +SESSION_ID = SESSION + "_" + ID +SINGLE = "single" +SPLIT = "split" +STATISTIC = "statistic" +STD = "std" +TEST = "test" +TMP = "tmp" +TRAIN = "train" +TRANSFORMATION = "transformation" +VALIDATION = "validation" +VALUE = "value" diff --git a/clinicadl/transforms/extraction/base.py b/clinicadl/transforms/extraction/base.py index bc15a0a92..34e1ba16f 100644 --- a/clinicadl/transforms/extraction/base.py +++ b/clinicadl/transforms/extraction/base.py @@ -9,6 +9,7 @@ import torchio as tio from pydantic import computed_field +from clinicadl.dictionary.words import IMAGE, LABEL, SAMPLE from clinicadl.utils.config import ClinicaDLConfig from clinicadl.utils.enum import ExtractionMethod @@ -244,7 +245,7 @@ def extract_tio_sample( IndexError If 'sample_index' is greater or equal to the number of samples in the image. """ - if not hasattr(tio_image, "image") or not isinstance( + if not hasattr(tio_image, IMAGE) or not isinstance( tio_image.image, tio.ScalarImage ): raise AttributeError( @@ -268,7 +269,7 @@ def extract_tio_sample( ) tio_sample.sample = tio_sample.image - delattr(tio_sample, "image") + delattr(tio_sample, IMAGE) return tio_sample @@ -278,14 +279,14 @@ def _check_tio_sample(tio_sample: tio.Subject): Checks that a TorchIO Subject is a valid sample, i.e. a sample with a TorchIO ScalarImage named 'sample', a label named 'label' and a description named 'description'. """ - if not hasattr(tio_sample, "sample") or not isinstance( + if not hasattr(tio_sample, SAMPLE) or not isinstance( tio_sample.sample, tio.ScalarImage ): raise AttributeError( "'tio_sample' must contain ScalarImage named 'image'. Got only the following images: " f"{tio_sample.get_images_names()}" ) - if not hasattr(tio_sample, "label"): + if not hasattr(tio_sample, LABEL): raise AttributeError( "'tio_sample' must contain an attribute named 'label'." 
) diff --git a/clinicadl/transforms/extraction/image.py b/clinicadl/transforms/extraction/image.py index 87b301490..1aa08096c 100644 --- a/clinicadl/transforms/extraction/image.py +++ b/clinicadl/transforms/extraction/image.py @@ -6,14 +6,13 @@ import torchio as tio from pydantic import PositiveInt, computed_field +from clinicadl.dictionary.suffixes import PT from clinicadl.utils.enum import ExtractionMethod from .base import Extraction, Sample logger = getLogger("clinicadl.extraction.image") -PT = ".pt" - class ImageSample(Sample): """ diff --git a/clinicadl/transforms/extraction/patch.py b/clinicadl/transforms/extraction/patch.py index 9473f6a7d..9b7e70e0c 100644 --- a/clinicadl/transforms/extraction/patch.py +++ b/clinicadl/transforms/extraction/patch.py @@ -6,14 +6,13 @@ import torchio as tio from pydantic import NonNegativeInt, PositiveInt, computed_field, field_validator +from clinicadl.dictionary.suffixes import PT from clinicadl.utils.enum import ExtractionMethod from .base import Extraction, Sample logger = getLogger("clinicadl.extraction.patch") -PT = ".pt" - class PatchSample(Sample): """ diff --git a/clinicadl/transforms/extraction/slice.py b/clinicadl/transforms/extraction/slice.py index 59bd2c314..8396d82cc 100644 --- a/clinicadl/transforms/extraction/slice.py +++ b/clinicadl/transforms/extraction/slice.py @@ -14,6 +14,7 @@ ) from typing_extensions import Self +from clinicadl.dictionary.suffixes import PT from clinicadl.utils.enum import ( ExtractionMethod, SliceDirection, @@ -24,8 +25,6 @@ logger = getLogger("clinicadl.extraction.slice") -PT = ".pt" - class SliceSample(Sample): """ diff --git a/clinicadl/transforms/transforms.py b/clinicadl/transforms/transforms.py index 4055f4218..b8c3565d4 100644 --- a/clinicadl/transforms/transforms.py +++ b/clinicadl/transforms/transforms.py @@ -4,6 +4,7 @@ import torchvision.transforms as torch_transforms from pydantic import model_validator +from clinicadl.dictionary.words import AUGMENTATION, IMAGE, OBJECT, TRANSFORMATION from clinicadl.transforms.extraction import Extraction, Image from clinicadl.transforms.factory import ( MinMaxNormalization, @@ -115,8 +116,8 @@ def __str__(self) -> str: def _to_str( list_: list[Callable] = [], - object_: str = "object", - transfo_: str = "transformation", + object_: str = OBJECT, + transfo_: str = TRANSFORMATION, ): str_ = "" if list_: @@ -128,13 +129,13 @@ def _to_str( return str_ - transform_str += _to_str(self.image_transforms, object_="image") - transform_str += _to_str(self.object_transforms, object_="object") + transform_str += _to_str(self.image_transforms, object_=IMAGE) + transform_str += _to_str(self.object_transforms, object_=OBJECT) transform_str += _to_str( - self.image_augmentation, object_="image", transfo_="augmentation" + self.image_augmentation, object_=IMAGE, transfo_=AUGMENTATION ) transform_str += _to_str( - self.object_augmentation, object_="object", transfo_="augmentation" + self.object_augmentation, object_=OBJECT, transfo_=AUGMENTATION ) return transform_str diff --git a/clinicadl/transforms/utils.py b/clinicadl/transforms/utils.py index 884edb620..f40a32e8a 100644 --- a/clinicadl/transforms/utils.py +++ b/clinicadl/transforms/utils.py @@ -3,6 +3,8 @@ import torch import torchio as tio +from clinicadl.dictionary.words import LABEL + def get_tio_image( image: torch.Tensor, @@ -32,9 +34,9 @@ def get_tio_image( tio_image = tio.Subject(image=tio.ScalarImage(tensor=image)) if isinstance(label, torch.Tensor): - tio_image.add_image(tio.LabelMap(tensor=label), "label") + 
tio_image.add_image(tio.LabelMap(tensor=label), LABEL) else: - setattr(tio_image, "label", label) + setattr(tio_image, LABEL, label) for name, mask in masks.items(): tio_image.add_image(tio.LabelMap(tensor=mask), name) From dc3443ee7ddf50aa82cec62c52ee9a9c9d075d09 Mon Sep 17 00:00:00 2001 From: camillebrianceau <57992134+camillebrianceau@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:06:54 +0100 Subject: [PATCH 3/3] Cleaning (#689) --- clinicadl/API/complicated_case.py | 88 ++--- clinicadl/API/cross_val.py | 27 +- clinicadl/API/dataset_test.py | 61 ++-- clinicadl/API/single_split.py | 100 +++--- clinicadl/API_test.py | 225 ------------ clinicadl/commandline/arguments.py | 54 --- .../commandline/modules_options/callbacks.py | 20 -- .../modules_options/computational.py | 26 -- clinicadl/commandline/modules_options/data.py | 64 ---- .../commandline/modules_options/dataloader.py | 30 -- .../modules_options/early_stopping.py | 21 -- .../commandline/modules_options/extraction.py | 109 ------ .../modules_options/lr_scheduler.py | 9 - .../modules_options/maps_manager.py | 21 -- .../commandline/modules_options/network.py | 20 -- .../modules_options/optimization.py | 30 -- .../commandline/modules_options/optimizer.py | 30 -- .../modules_options/preprocessing.py | 73 ---- .../modules_options/reproducibility.py | 43 --- .../commandline/modules_options/split.py | 25 -- .../commandline/modules_options/transforms.py | 22 -- .../commandline/modules_options/validation.py | 39 --- .../pipelines/generate/artifacts/cli.py | 4 +- .../pipelines/generate/artifacts/options.py | 65 ---- .../pipelines/generate/hypometabolic/cli.py | 4 +- .../generate/hypometabolic/options.py | 29 -- .../pipelines/generate/random/cli.py | 4 +- .../pipelines/generate/random/options.py | 20 -- .../pipelines/generate/shepplogan/options.py | 50 --- .../pipelines/generate/trivial/cli.py | 4 +- .../pipelines/generate/trivial/options.py | 15 - .../pipelines/interpret/__init__.py | 0 .../commandline/pipelines/interpret/cli.py | 56 --- .../pipelines/interpret/options.py | 40 --- .../commandline/pipelines/predict/__init__.py | 0 .../commandline/pipelines/predict/cli.py | 69 ---- .../commandline/pipelines/predict/options.py | 20 -- .../pipelines/prepare_data/__init__.py | 0 .../prepare_data/prepare_data_cli.py | 162 --------- .../prepare_data_from_bids_cli.py | 148 -------- .../pipelines/quality_check/pet_linear/cli.py | 2 +- .../pipelines/quality_check/t1_linear/cli.py | 2 +- .../commandline/pipelines/train/__init__.py | 3 - .../train/classification/__init__.py | 0 .../pipelines/train/classification/cli.py | 112 ------ .../pipelines/train/classification/options.py | 52 --- clinicadl/commandline/pipelines/train/cli.py | 26 -- .../pipelines/train/from_json/__init__.py | 1 - .../pipelines/train/from_json/cli.py | 34 -- .../pipelines/train/list_models/__init__.py | 1 - .../pipelines/train/list_models/cli.py | 33 -- .../train/reconstruction/__init__.py | 0 .../pipelines/train/reconstruction/cli.py | 108 ------ .../pipelines/train/reconstruction/options.py | 36 -- .../pipelines/train/regression/__init__.py | 0 .../pipelines/train/regression/cli.py | 107 ------ .../pipelines/train/regression/options.py | 45 --- .../pipelines/train/resume/__init__.py | 0 .../commandline/pipelines/train/resume/cli.py | 19 - .../pipelines/transfer_learning/__init__.py | 0 .../pipelines/transfer_learning/options.py | 30 -- clinicadl/config/config/__init__.py | 0 clinicadl/config/config/lr_scheduler.py | 13 - clinicadl/config/config/reproducibility.py | 21 -- 
.../modules_options => data}/__init__.py | 0 .../{dataset => data}/config/__init__.py | 0 clinicadl/{dataset => data}/config/data.py | 0 .../{dataset => data}/config/file_type.py | 0 .../{dataset => data}/dataloader/__init__.py | 0 .../{dataset => data}/dataloader/config.py | 2 +- .../{dataset => data}/dataloader/defaults.py | 0 .../datasets/__init__.py} | 0 .../datasets/caps_dataset.py | 8 +- .../{dataset => data}/datasets/concat.py | 4 +- .../preprocessing/__init__.py | 0 .../{dataset => data}/preprocessing/base.py | 0 .../{dataset => data}/preprocessing/custom.py | 2 +- .../{dataset => data}/preprocessing/dti.py | 2 +- .../{dataset => data}/preprocessing/flair.py | 2 +- .../{dataset => data}/preprocessing/pet.py | 2 +- .../{dataset => data}/preprocessing/t1.py | 2 +- .../{dataset => data}/preprocessing/t2.py | 2 +- .../{dataset => data}/readers/__init__.py | 0 .../{dataset => data}/readers/bids_reader.py | 4 +- .../{dataset => data}/readers/caps_reader.py | 6 +- .../readers/multi_caps_reader.py | 0 clinicadl/{dataset => data}/readers/reader.py | 0 clinicadl/{dataset => data}/utils.py | 4 +- clinicadl/dataset/__init__.py | 0 clinicadl/experiment_manager/__init__.py | 1 + .../experiment_manager/experiment_manager.py | 4 +- clinicadl/experiment_manager/maps_manager.py | 2 +- clinicadl/hugging_face/hugging_face.py | 2 +- clinicadl/interpret/config.py | 2 +- clinicadl/model/clinicadl_model.py | 57 ++- clinicadl/predictor/config.py | 2 +- clinicadl/predictor/old_predictor.py | 2 +- clinicadl/predictor/predictor.py | 2 +- .../quality_check/pet_linear/quality_check.py | 4 +- .../quality_check/t1_linear/quality_check.py | 2 +- clinicadl/quality_check/t1_linear/utils.py | 6 +- clinicadl/splitter/make_splits/kfold.py | 38 +- .../splitter/make_splits/single_split.py | 31 +- clinicadl/splitter/split.py | 6 +- clinicadl/splitter/splitter/kfold.py | 2 +- clinicadl/splitter/splitter/single_split.py | 2 +- clinicadl/splitter/splitter/splitter.py | 2 +- clinicadl/splitter/test.py | 4 +- clinicadl/tmp_config.py | 4 +- clinicadl/trainer/__init__.py | 1 + clinicadl/trainer/config/classification.py | 2 +- clinicadl/trainer/config/regression.py | 2 +- clinicadl/trainer/config/train.py | 2 +- clinicadl/trainer/old_trainer.py | 4 +- clinicadl/trainer/tasks_utils.py | 2 +- clinicadl/trainer/trainer.py | 221 +++++++++++- clinicadl/{config => utils}/config_utils.py | 0 clinicadl/utils/iotools/train_utils.py | 4 +- tests/unittests/dataset/test_config.py | 2 +- tests/unittests/dataset/test_datasets.py | 4 +- tests/unittests/dataset/test_reader.py | 4 +- .../test_random_search_config.py | 120 +++---- tests/unittests/splitter/test_make_split.py | 21 ++ tests/unittests/splitter/test_splitter.py | 4 +- .../test_classification_config.py | 172 ++++----- .../test_reconstruction_config.py | 114 +++--- .../regression/test_regression_config.py | 110 +++--- .../train/trainer/test_training_config.py | 326 +++++++++--------- tests/unittests/utils/test_config_utils.py | 4 +- 129 files changed, 936 insertions(+), 2870 deletions(-) delete mode 100644 clinicadl/API_test.py delete mode 100644 clinicadl/commandline/arguments.py delete mode 100644 clinicadl/commandline/modules_options/callbacks.py delete mode 100644 clinicadl/commandline/modules_options/computational.py delete mode 100644 clinicadl/commandline/modules_options/data.py delete mode 100644 clinicadl/commandline/modules_options/dataloader.py delete mode 100644 clinicadl/commandline/modules_options/early_stopping.py delete mode 100644 
clinicadl/commandline/modules_options/extraction.py delete mode 100644 clinicadl/commandline/modules_options/lr_scheduler.py delete mode 100644 clinicadl/commandline/modules_options/maps_manager.py delete mode 100644 clinicadl/commandline/modules_options/network.py delete mode 100644 clinicadl/commandline/modules_options/optimization.py delete mode 100644 clinicadl/commandline/modules_options/optimizer.py delete mode 100644 clinicadl/commandline/modules_options/preprocessing.py delete mode 100644 clinicadl/commandline/modules_options/reproducibility.py delete mode 100644 clinicadl/commandline/modules_options/split.py delete mode 100644 clinicadl/commandline/modules_options/transforms.py delete mode 100644 clinicadl/commandline/modules_options/validation.py delete mode 100644 clinicadl/commandline/pipelines/generate/artifacts/options.py delete mode 100644 clinicadl/commandline/pipelines/generate/hypometabolic/options.py delete mode 100644 clinicadl/commandline/pipelines/generate/random/options.py delete mode 100644 clinicadl/commandline/pipelines/generate/shepplogan/options.py delete mode 100644 clinicadl/commandline/pipelines/generate/trivial/options.py delete mode 100644 clinicadl/commandline/pipelines/interpret/__init__.py delete mode 100644 clinicadl/commandline/pipelines/interpret/cli.py delete mode 100644 clinicadl/commandline/pipelines/interpret/options.py delete mode 100644 clinicadl/commandline/pipelines/predict/__init__.py delete mode 100644 clinicadl/commandline/pipelines/predict/cli.py delete mode 100644 clinicadl/commandline/pipelines/predict/options.py delete mode 100644 clinicadl/commandline/pipelines/prepare_data/__init__.py delete mode 100644 clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py delete mode 100644 clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py delete mode 100644 clinicadl/commandline/pipelines/train/__init__.py delete mode 100644 clinicadl/commandline/pipelines/train/classification/__init__.py delete mode 100644 clinicadl/commandline/pipelines/train/classification/cli.py delete mode 100644 clinicadl/commandline/pipelines/train/classification/options.py delete mode 100644 clinicadl/commandline/pipelines/train/cli.py delete mode 100644 clinicadl/commandline/pipelines/train/from_json/__init__.py delete mode 100644 clinicadl/commandline/pipelines/train/from_json/cli.py delete mode 100644 clinicadl/commandline/pipelines/train/list_models/__init__.py delete mode 100644 clinicadl/commandline/pipelines/train/list_models/cli.py delete mode 100644 clinicadl/commandline/pipelines/train/reconstruction/__init__.py delete mode 100644 clinicadl/commandline/pipelines/train/reconstruction/cli.py delete mode 100644 clinicadl/commandline/pipelines/train/reconstruction/options.py delete mode 100644 clinicadl/commandline/pipelines/train/regression/__init__.py delete mode 100644 clinicadl/commandline/pipelines/train/regression/cli.py delete mode 100644 clinicadl/commandline/pipelines/train/regression/options.py delete mode 100644 clinicadl/commandline/pipelines/train/resume/__init__.py delete mode 100644 clinicadl/commandline/pipelines/train/resume/cli.py delete mode 100644 clinicadl/commandline/pipelines/transfer_learning/__init__.py delete mode 100644 clinicadl/commandline/pipelines/transfer_learning/options.py delete mode 100644 clinicadl/config/config/__init__.py delete mode 100644 clinicadl/config/config/lr_scheduler.py delete mode 100644 clinicadl/config/config/reproducibility.py rename clinicadl/{commandline/modules_options => 
data}/__init__.py (100%) rename clinicadl/{dataset => data}/config/__init__.py (100%) rename clinicadl/{dataset => data}/config/data.py (100%) rename clinicadl/{dataset => data}/config/file_type.py (100%) rename clinicadl/{dataset => data}/dataloader/__init__.py (100%) rename clinicadl/{dataset => data}/dataloader/config.py (98%) rename clinicadl/{dataset => data}/dataloader/defaults.py (100%) rename clinicadl/{dataset/datasets/___init__.py => data/datasets/__init__.py} (100%) rename clinicadl/{dataset => data}/datasets/caps_dataset.py (98%) rename clinicadl/{dataset => data}/datasets/concat.py (93%) rename clinicadl/{dataset => data}/preprocessing/__init__.py (100%) rename clinicadl/{dataset => data}/preprocessing/base.py (100%) rename clinicadl/{dataset => data}/preprocessing/custom.py (93%) rename clinicadl/{dataset => data}/preprocessing/dti.py (95%) rename clinicadl/{dataset => data}/preprocessing/flair.py (92%) rename clinicadl/{dataset => data}/preprocessing/pet.py (96%) rename clinicadl/{dataset => data}/preprocessing/t1.py (92%) rename clinicadl/{dataset => data}/preprocessing/t2.py (93%) rename clinicadl/{dataset => data}/readers/__init__.py (100%) rename clinicadl/{dataset => data}/readers/bids_reader.py (98%) rename clinicadl/{dataset => data}/readers/caps_reader.py (98%) rename clinicadl/{dataset => data}/readers/multi_caps_reader.py (100%) rename clinicadl/{dataset => data}/readers/reader.py (100%) rename clinicadl/{dataset => data}/utils.py (99%) delete mode 100644 clinicadl/dataset/__init__.py rename clinicadl/{config => utils}/config_utils.py (100%) diff --git a/clinicadl/API/complicated_case.py b/clinicadl/API/complicated_case.py index 4afe5050a..b16c653f9 100644 --- a/clinicadl/API/complicated_case.py +++ b/clinicadl/API/complicated_case.py @@ -2,41 +2,26 @@ import torchio.transforms as transforms -from clinicadl.dataset.dataloader_config import DataLoaderConfig -from clinicadl.dataset.datasets.caps_dataset import CapsDataset -from clinicadl.dataset.datasets.concat import ConcatDataset -from clinicadl.dataset.preprocessing import ( - PreprocessingCustom, +from clinicadl.data.dataloader import DataLoaderConfig +from clinicadl.data.datasets.caps_dataset import CapsDataset +from clinicadl.data.datasets.concat import ConcatDataset +from clinicadl.data.preprocessing import ( PreprocessingPET, PreprocessingT1, ) -from clinicadl.dataset.readers.caps_reader import CapsReader from clinicadl.experiment_manager.experiment_manager import ExperimentManager from clinicadl.losses.config import CrossEntropyLossConfig -from clinicadl.losses.factory import get_loss_function from clinicadl.model.clinicadl_model import ClinicaDLModel -from clinicadl.networks.config import ImplementedNetworks -from clinicadl.networks.factory import ( - ConvEncoderOptions, - create_network_config, - get_network_from_config, -) -from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig -from clinicadl.optimization.optimizer.factory import get_optimizer -from clinicadl.predictor.predictor import Predictor -from clinicadl.splitter.kfold import KFolder -from clinicadl.splitter.split import get_single_split, split_tsv +from clinicadl.networks.config.resnet import ResNetConfig +from clinicadl.optim.optimizers.config import AdamConfig +from clinicadl.splitter import KFold, make_kfold, make_split from clinicadl.trainer.trainer import Trainer -from clinicadl.transforms.extraction import ROI, BaseExtraction, Image, Patch, Slice +from clinicadl.transforms.extraction import Extraction, Image, Patch, 
Slice from clinicadl.transforms.transforms import Transforms -# Create the Maps Manager / Read/write manager / -maps_path = Path("/") -manager = ExperimentManager( - maps_path, overwrite=False -) # a ajouter dans le manager: mlflow/ profiler/ etc ... - -caps_directory = Path("caps_directory") # output of clinica pipelines +caps_directory = Path( + "/Users/camille.brianceau/aramis/CLINICADL/caps" +) # output of clinica pipelines sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv") preprocessing_t1 = PreprocessingT1() @@ -60,7 +45,7 @@ sub_ses_pet_45 = Path( "/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_pet_18FAV45.tsv" ) -preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2") +preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2") # type: ignore dataset_pet_image = CapsDataset( caps_directory=caps_directory, @@ -79,47 +64,27 @@ ) # 3 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention -config_file = Path("config_file") -trainer = Trainer.from_json(config_file=config_file, manager=manager) - # CAS CROSS-VALIDATION -splitter = KFolder(caps_dataset=dataset_multi_modality_multi_extract, manager=manager) -split_dir = splitter.make_splits( - n_splits=3, output_dir=Path(""), subset_name="validation", stratification="" -) # Optional data tsv and output_dir -dataloader_config = DataLoaderConfig(n_procs=3, batch_size=10) +split_dir = make_split(sub_ses_t1, n_test=0.2) # Optional data tsv and output_dir +fold_dir = make_kfold(split_dir / "train.tsv", n_splits=2) +splitter = KFold(fold_dir) -# CAS 1 -# Prérequis : déjà avoir des fichiers avec les listes train et validation -split_dir = make_kfold( - "dataset.tsv" -) # lit dataset.tsv => fait le kfold => ecrit la sortie dans split_dir -splitter = KFolder( - dataset_multi_modality, split_dir -) # c'est plutôt un iterable de dataloader +maps_path = Path("/") +manager = ExperimentManager(maps_path, overwrite=False) -# CAS 2 -splitter = KFolder(caps_dataset=dataset_t1_image) -splitter.make_splits(n_splits=3) -splitter.write(split_dir) +config_file = Path("config_file") +trainer = Trainer.from_json(config_file=config_file, manager=manager) -# or -splitter = KFolder(caps_dataset=dataset_t1_image) -splitter.read(split_dir) -for split in splitter.get_splits(splits=(0, 3, 4), dataloader_config=dataloader_config): - # bien définir ce qu'il y a dans l'objet split +for split in splitter.get_splits(dataset=dataset_t1_image): + train_loader = split.build_train_loader(batch_size=2) + val_loader = split.build_val_loader(DataLoaderConfig()) - network_config = create_network_config(ImplementedNetworks.CNN)( - in_shape=[2, 2, 2], - num_outputs=1, - conv_args=ConvEncoderOptions(channels=[3, 2, 2]), - ) - model = ClinicaDLModelClassif.from_config( - network_config=network_config, + model = ClinicaDLModel.from_config( + network_config=ResNetConfig(num_outputs=1, spatial_dims=1, in_channels=1), loss_config=CrossEntropyLossConfig(), optimizer_config=AdamConfig(), ) @@ -133,9 +98,6 @@ dataset_test = CapsDataset( caps_directory=caps_directory, preprocessing=preprocessing_t1, - sub_ses_tsv=Path("test.tsv"), # test only on data from the first dataset + data=Path("test.tsv"), # test only on data from the first dataset transforms=transforms_image, ) - -predictor = Predictor(model=model, manager=manager) -predictor.predict(dataset_test=dataset_test, split_number=2) diff --git a/clinicadl/API/cross_val.py b/clinicadl/API/cross_val.py index 
54f7b9d6b..dbdb7f6ce 100644 --- a/clinicadl/API/cross_val.py +++ b/clinicadl/API/cross_val.py @@ -1,11 +1,10 @@ from pathlib import Path -from clinicadl.dataset.datasets.caps_dataset import CapsDataset -from clinicadl.experiment_manager.experiment_manager import ExperimentManager -from clinicadl.predictor.predictor import Predictor -from clinicadl.splitter.new_splitter.dataloader import DataLoaderConfig -from clinicadl.splitter.new_splitter.splitter.kfold import KFold -from clinicadl.trainer.trainer import Trainer +from clinicadl.data.dataloader import DataLoaderConfig +from clinicadl.data.datasets import CapsDataset +from clinicadl.experiment_manager import ExperimentManager +from clinicadl.splitter import KFold, make_kfold, make_split +from clinicadl.trainer import Trainer # SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING @@ -19,20 +18,18 @@ config_file=config_file, manager=manager ) # gpu, amp, fsdp, seed -splitter = KFold(dataset=dataset_t1_image) -splitter.make_splits(n_splits=3) -split_dir = Path("") -splitter.write(split_dir) +split_dir = make_split( + dataset_t1_image.df, n_test=0.2, subset_name="validation", output_dir="test" +) # Optional data tsv and output_dir +fold_dir = make_kfold(split_dir / "train.tsv", n_splits=2) + +splitter = KFold(fold_dir) -splitter.read(split_dir) # define the needed parameters for the dataloader dataloader_config = DataLoaderConfig(num_workers=3, batch_size=10) -for split in splitter.get_splits(splits=(0, 3, 4)): - print(split) +for split in splitter.get_splits(dataset=dataset_t1_image): split.build_train_loader(dataloader_config) split.build_val_loader(num_workers=3, batch_size=10) - - print(split) diff --git a/clinicadl/API/dataset_test.py b/clinicadl/API/dataset_test.py index bc8e2eb92..dc2f14e43 100644 --- a/clinicadl/API/dataset_test.py +++ b/clinicadl/API/dataset_test.py @@ -2,15 +2,15 @@ import torchio.transforms as transforms -from clinicadl.dataset.datasets.caps_dataset import CapsDataset -from clinicadl.dataset.datasets.concat import ConcatDataset -from clinicadl.dataset.preprocessing import ( +from clinicadl.data.datasets import CapsDataset, ConcatDataset +from clinicadl.data.preprocessing import ( BasePreprocessing, PreprocessingFlair, PreprocessingPET, PreprocessingT1, ) -from clinicadl.experiment_manager.experiment_manager import ExperimentManager +from clinicadl.data.preprocessing.pet import SUVRReferenceRegions, Tracer +from clinicadl.experiment_manager import ExperimentManager from clinicadl.losses.config import CrossEntropyLossConfig from clinicadl.model.clinicadl_model import ClinicaDLModel from clinicadl.networks.factory import ( @@ -18,10 +18,9 @@ create_network_config, get_network_from_config, ) -from clinicadl.splitter.kfold import KFolder -from clinicadl.splitter.split import get_single_split, split_tsv -from clinicadl.transforms.extraction import ROI, Image, Patch, Slice -from clinicadl.transforms.transforms import Transforms +from clinicadl.splitter import KFold, make_kfold, make_split +from clinicadl.transforms import Transforms +from clinicadl.transforms.extraction import Image, Patch, Slice sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv") sub_ses_pet_45 = Path( @@ -38,8 +37,12 @@ "/Users/camille.brianceau/aramis/CLINICADL/caps" ) # output of clinica pipelines -preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2") -preprocessing_pet_11 = PreprocessingPET(tracer="11CPIB", suvr_reference_region="pons2") +preprocessing_pet_45 = PreprocessingPET( + 
tracer=Tracer.FAV45, suvr_reference_region=SUVRReferenceRegions.PONS2 +) +preprocessing_pet_11 = PreprocessingPET( + tracer=Tracer.CPIB, suvr_reference_region=SUVRReferenceRegions.PONS2 +) preprocessing_t1 = PreprocessingT1() preprocessing_flair = PreprocessingFlair() @@ -55,18 +58,6 @@ transforms_slice = Transforms(extraction=Slice()) -transforms_roi = Transforms( - object_augmentation=[transforms.Ghosting(2, 1, 0.1, 0.1)], - object_transforms=[transforms.RandomMotion()], - extraction=ROI( - roi_list=["leftHippocampusBox", "rightHippocampusBox"], - roi_mask_location=Path( - "/Users/camille.brianceau/aramis/CLINICADL/caps/masks/tpl-MNI152NLin2009cSym" - ), - roi_crop_input=True, - ), -) - transforms_image = Transforms( image_augmentation=[transforms.RandomMotion()], extraction=Image(), @@ -96,25 +87,25 @@ ) -print("Pet 11 and ROI ") +print("Pet 11 and Image ") -dataset_pet_11_roi = CapsDataset( +dataset_pet_11_image = CapsDataset( caps_directory=caps_directory, data=sub_ses_pet_11, preprocessing=preprocessing_pet_11, - transforms=transforms_roi, + transforms=transforms_image, ) -dataset_pet_11_roi.prepare_data( +dataset_pet_11_image.prepare_data( n_proc=2 ) # to extract the tensor of the PET file this time -print(dataset_pet_11_roi) -print(dataset_pet_11_roi.__len__()) -print(dataset_pet_11_roi._get_meta_data(0)) -print(dataset_pet_11_roi._get_meta_data(1)) -# print(dataset_pet_11_roi._get_full_image()) -print(dataset_pet_11_roi.__getitem__(1).elem_idx) -print(dataset_pet_11_roi.elem_per_image) +print(dataset_pet_11_image) +print(dataset_pet_11_image.__len__()) +print(dataset_pet_11_image._get_meta_data(0)) +print(dataset_pet_11_image._get_meta_data(1)) +# print(dataset_pet_11_image._get_full_image()) +print(dataset_pet_11_image.__getitem__(1).elem_idx) +print(dataset_pet_11_image.elem_per_image) print("T1 and image ") @@ -161,7 +152,7 @@ lity_multi_extract = ConcatDataset( [ - dataset_t1, - dataset_pet, + dataset_t1_image, + dataset_pet_11_image, ] ) # 3 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention diff --git a/clinicadl/API/single_split.py b/clinicadl/API/single_split.py index af0f0a699..3498551a8 100644 --- a/clinicadl/API/single_split.py +++ b/clinicadl/API/single_split.py @@ -2,89 +2,71 @@ import torchio.transforms as transforms -from clinicadl.dataset.caps_reader import CapsReader -from clinicadl.dataset.concat import ConcatDataset -from clinicadl.dataset.config.extraction import ExtractionConfig, ExtractionPatchConfig -from clinicadl.dataset.config.preprocessing import ( - PreprocessingConfig, - T1PreprocessingConfig, +from clinicadl.data.dataloader import DataLoaderConfig +from clinicadl.data.datasets.caps_dataset import CapsDataset +from clinicadl.data.datasets.concat import ConcatDataset +from clinicadl.data.preprocessing import ( + PreprocessingPET, + PreprocessingT1, ) from clinicadl.experiment_manager.experiment_manager import ExperimentManager from clinicadl.losses.config import CrossEntropyLossConfig -from clinicadl.losses.factory import get_loss_function from clinicadl.model.clinicadl_model import ClinicaDLModel -from clinicadl.networks.config import ImplementedNetworks -from clinicadl.networks.factory import ( - ConvEncoderOptions, - create_network_config, - get_network_from_config, -) -from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig -from clinicadl.optimization.optimizer.factory import get_optimizer +from clinicadl.networks.config.resnet import ResNetConfig +from clinicadl.optim.optimizers.config 
import AdamConfig from clinicadl.predictor.predictor import Predictor -from clinicadl.splitter.kfold import KFolder -from clinicadl.splitter.split import get_single_split, split_tsv +from clinicadl.splitter.make_splits import make_kfold, make_split +from clinicadl.splitter.splitter import KFold, SingleSplit from clinicadl.trainer.trainer import Trainer -from clinicadl.transforms.config import TransformsConfig +from clinicadl.transforms.extraction import Image from clinicadl.transforms.transforms import Transforms -from clinicadl.utils.enum import ExtractionMethod - -# SIMPLE EXPERIMENT +caps_directory = Path( + "/Users/camille.brianceau/aramis/CLINICADL/caps" +) # output of clinica pipelines -caps_directory = Path("caps_directory") # output of clinica pipelines -caps_reader = CapsReader(caps_directory) -# un peu bizarre de passer un maps_path a cet endroit via le manager pq on veut pas forcmeent faire un entrainement ?? +sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv") +preprocessing_t1 = PreprocessingT1() +transforms_image = Transforms( + image_augmentation=[transforms.RandomMotion()], + extraction=Image(), + image_transforms=[transforms.Blur((0.5, 0.6, 0.3))], +) -preprocessing_t1 = caps_reader.get_preprocessing("t1-linear") -caps_reader.prepare_data( +dataset_t1_image = CapsDataset( + caps_directory=caps_directory, + data=sub_ses_t1, preprocessing=preprocessing_t1, - data_tsv=Path(""), - n_proc=2, - use_uncropped_images=False, + transforms=transforms_image, ) -transforms_1 = Transforms( - object_augmentation=[transforms.RandomMotion()], # default = no transforms - image_augmentation=[transforms.RandomMotion()], # default = no transforms - object_transforms=[transforms.Blur((0.4, 0.5, 0.6))], # default = none - image_transforms=[transforms.Noise(0.2, 0.5, 3)], # default = MiniMax - extraction=ExtractionPatchConfig(patch_size=30, stride_size=20), # default = Image -) # not mandatory +dataset_t1_image.prepare_data(n_proc=2) # to extract the tensor of the T1 file -sub_ses_tsv = Path("") -split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv -dataset_t1_image = caps_reader.get_dataset( - preprocessing=preprocessing_t1, - sub_ses_tsv=split_dir / "train.tsv", - transforms=transforms_1, -) # do we give config or ob -> dataset.json -# we can create a dataset.json in the CAPS ? or elsewhere ? -# but maybe we need to create a json file with the infos from the dataset (preprocessing, tsv file, transforms options and caps_directory) -dataset_t1_image = caps_reader.get_dataset_from_json("dataset.json") +split_dir = make_split( + sub_ses_t1, n_test=0.2, subset_name="test" +) # Optional data tsv and output_dir +split_dir_val = make_split( + split_dir / "train.tsv", n_test=0.2, subset_name="validation" +) + +splitter = SingleSplit(split_dir_val) -# CAS SINGLE SPLIT -split = get_single_split( - n_subject_validation=0, - caps_dataset=dataset_t1_image, - # manager=manager, -) # as we said, maybe we do not need to pass the manager in this function maps_path = Path("/") manager = ExperimentManager(maps_path, overwrite=False) config_file = Path("config_file") trainer = Trainer.from_json(config_file=config_file, manager=manager) -# how to create the trainer not from a config file ? 
-network_config = create_network_config(ImplementedNetworks.CNN)( - in_shape=[2, 2, 2], num_outputs=1, conv_args=ConvEncoderOptions(channels=[3, 2, 2]) -) -model = ClinicaDLModelClassif.from_config( - network_config=network_config, +split = splitter.get_splits(dataset=dataset_t1_image) + +train_loader = split.build_train_loader(batch_size=2) +val_loader = split.build_val_loader(DataLoaderConfig()) + +model = ClinicaDLModel.from_config( + network_config=ResNetConfig(num_outputs=1, spatial_dims=1, in_channels=1), loss_config=CrossEntropyLossConfig(), optimizer_config=AdamConfig(), ) trainer.train(model, split) -# le trainer va instancier un predictor/valdiator dans le train ou dans le init diff --git a/clinicadl/API_test.py b/clinicadl/API_test.py deleted file mode 100644 index 0639f1ffb..000000000 --- a/clinicadl/API_test.py +++ /dev/null @@ -1,225 +0,0 @@ -from pathlib import Path - -import torchio - -from clinicadl.dataset.caps_dataset import ( - CapsDatasetPatch, - CapsDatasetRoi, - CapsDatasetSlice, -) -from clinicadl.dataset.caps_reader import CapsReader -from clinicadl.dataset.concat import ConcatDataset -from clinicadl.dataset.config.extraction import ExtractionConfig -from clinicadl.dataset.config.preprocessing import ( - PreprocessingConfig, - T1PreprocessingConfig, -) -from clinicadl.experiment_manager.experiment_manager import ExperimentManager -from clinicadl.losses.config import CrossEntropyLossConfig -from clinicadl.losses.factory import get_loss_function -from clinicadl.model.clinicadl_model import ClinicaDLModel -from clinicadl.networks.config import ImplementedNetworks -from clinicadl.networks.factory import ( - ConvEncoderOptions, - create_network_config, - get_network_from_config, -) -from clinicadl.optim.optimizers.config import AdamConfig, OptimizerConfig -from clinicadl.optim.optimizers.factory import get_optimizer -from clinicadl.predictor.predictor import Predictor -from clinicadl.splitter.kfold import KFolder -from clinicadl.splitter.split import get_single_split, split_tsv -from clinicadl.trainer.trainer import Trainer -from clinicadl.transforms.transforms import Transforms - -# Create the Maps Manager / Read/write manager / -maps_path = Path("/") -manager = ExperimentManager(maps_path, overwrite=False) - -caps_directory = Path("caps_directory") # output of clinica pipelines -caps_reader = CapsReader(caps_directory, manager=manager) - -preprocessing_1 = caps_reader.get_preprocessing("t1-linear") -extraction_1 = caps_reader.extract_slice(preprocessing=preprocessing_1, arg_slice=2) -transforms_1 = Transforms( - data_augmentation=[torchio.t1, torchio.t2], - image_transforms=[torchio.t1, torchio.t2], - object_transforms=[torchio.t1, torchio.t2], -) # not mandatory - -preprocessing_2 = caps_reader.get_preprocessing("pet-linear") -extraction_2 = caps_reader.extract_patch(preprocessing=preprocessing_2, arg_patch=2) -transforms_2 = Transforms( - data_augmentation=[torchio.t2], - image_transforms=[torchio.t1], - object_transforms=[torchio.t1, torchio.t2], -) - -sub_ses_tsv = Path("") -split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv - -dataset_t1_roi = caps_reader.get_dataset( - extraction=extraction_1, - preprocessing=preprocessing_1, - sub_ses_tsv=split_dir / "train.tsv", - transforms=transforms_1, -) # do we give config or object for transforms ? 
-dataset_pet_patch = caps_reader.get_dataset( - extraction=extraction_2, - preprocessing=preprocessing_2, - sub_ses_tsv=split_dir / "train.tsv", - transforms=transforms_2, -) - -dataset_multi_modality_multi_extract = ConcatDataset( - [dataset_t1_roi, dataset_pet_patch] -) # 2 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention - -config_file = Path("config_file") -trainer = Trainer.from_json(config_file=config_file, manager=manager) - -# CAS CROSS-VALIDATION -splitter = KFolder( - n_splits=3, caps_dataset=dataset_multi_modality_multi_extract, manager=manager -) - -for split in splitter.split_iterator(split_list=[0, 1]): - # bien définir ce qu'il y a dans l'objet split - - loss, loss_config = get_loss_function(CrossEntropyLossConfig()) - network_config = create_network_config(ImplementedNetworks.CNN)( - in_shape=[2, 2, 2], - num_outputs=1, - conv_args=ConvEncoderOptions(channels=[3, 2, 2]), - ) - network, _ = get_network_from_config(network_config) - optimizer, _ = get_optimizer(network, AdamConfig()) - model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer) - - trainer.train(model, split) - # le trainer va instancier un predictor/valdiator dans le train ou dans le init - - -# CAS SINGLE SPLIT -split = get_single_split( - n_subject_validation=0, - caps_dataset=dataset_multi_modality_multi_extract, - manager=manager, -) - -loss, loss_config = get_loss_function(CrossEntropyLossConfig()) -network_config = create_network_config(ImplementedNetworks.CNN)( - in_shape=[2, 2, 2], num_outputs=1, conv_args=ConvEncoderOptions(channels=[3, 2, 2]) -) -network, _ = get_network_from_config(network_config) -optimizer, _ = get_optimizer(network, AdamConfig()) -model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer) - -trainer.train(model, split) -# le trainer va instancier un predictor/valdiator dans le train ou dans le init - - -# TEST - -preprocessing_test: PreprocessingConfig = caps_reader.get_preprocessing("pet-linear") -extraction_test: ExtractionConfig = caps_reader.extract_patch( - preprocessing=preprocessing_2, arg_patch=2 -) -transforms_test = Transforms( - data_augmentation=[torchio.t2], - image_transforms=[torchio.t1], - object_transforms=[torchio.t1, torchio.t2], -) - -dataset_test = caps_reader.get_dataset( - extraction=extraction_test, - preprocessing=preprocessing_test, - sub_ses_tsv=split_dir / "test.tsv", - transforms=transforms_test, -) - -predictor = Predictor(manager=manager) -predictor.predict(dataset_test=dataset_test, split=2) - - -# SIMPLE EXPERIMENT - - -maps_path = Path("/") -manager = ExperimentManager(maps_path, overwrite=False) - -caps_directory = Path("caps_directory") # output of clinica pipelines -caps_reader = CapsReader(caps_directory, manager=manager) - -extraction_1 = caps_reader.extract_image(preprocessing=T1PreprocessingConfig()) -transforms_1 = Transforms( - data_augmentation=[torchio.transforms.RandomMotion] -) # not mandatory - -sub_ses_tsv = Path("") -split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv - -dataset_t1_image = caps_reader.get_dataset( - extraction=extraction_1, - preprocessing=T1PreprocessingConfig(), - sub_ses_tsv=split_dir / "train.tsv", - transforms=transforms_1, -) # do we give config or ob - - -config_file = Path("config_file") -trainer = Trainer.from_json(config_file=config_file, manager=manager) - -# CAS CROSS-VALIDATION -splitter = KFolder(n_splits=3, caps_dataset=dataset_t1_image, manager=manager) - -for split in splitter.split_iterator(split_list=[0, 1]): - 
# bien définir ce qu'il y a dans l'objet split - - loss, loss_config = get_loss_function(CrossEntropyLossConfig()) - network_config = create_network_config(ImplementedNetworks.CNN)( - in_shape=[2, 2, 2], - num_outputs=1, - conv_args=ConvEncoderOptions(channels=[3, 2, 2]), - ) - network, _ = get_network_from_config(network_config) - optimizer, _ = get_optimizer(network, AdamConfig()) - model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer) - - trainer.train(model, split) - # le trainer va instancier un predictor/valdiator dans le train ou dans le init - - -# SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING - -maps_path = Path("/") -manager = ExperimentManager(maps_path, overwrite=False) - -# sub_ses_tsv = Path("") -# split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv - -dataset_t1_image = CapsDatasetPatch.from_json( - extraction=extract_json, - sub_ses_tsv=split_dir / "train.tsv", -) -config_file = Path("config_file") -trainer = Trainer.from_json(config_file=config_file, manager=manager) - -# CAS CROSS-VALIDATION -splitter = KFolder(n_splits=3, caps_dataset=dataset_t1_image, manager=manager) - -for split in splitter.split_iterator(split_list=[0, 1]): - # bien définir ce qu'il y a dans l'objet split - - loss, loss_config = get_loss_function(CrossEntropyLossConfig()) - network_config = create_network_config(ImplementedNetworks.CNN)( - in_shape=[2, 2, 2], - num_outputs=1, - conv_args=ConvEncoderOptions(channels=[3, 2, 2]), - ) - network, _ = get_network_from_config(network_config) - optimizer, _ = get_optimizer(network, AdamConfig()) - model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer) - - trainer.train(model, split) - # le trainer va instancier un predictor/valdiator dans le train ou dans le init diff --git a/clinicadl/commandline/arguments.py b/clinicadl/commandline/arguments.py deleted file mode 100644 index 76c6ad8c6..000000000 --- a/clinicadl/commandline/arguments.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Common CLI arguments used by ClinicaDL pipelines.""" - -from pathlib import Path - -import click - -# TODO trier les arguments par configclasses et voir les arguments utils et ceux qui ne le sont pas -bids_directory = click.argument( - "bids_directory", type=click.Path(exists=True, path_type=Path) -) -caps_directory = click.argument("caps_directory", type=click.Path(path_type=Path)) -input_maps = click.argument( - "input_maps_directory", type=click.Path(exists=True, path_type=Path) -) -output_maps = click.argument("output_maps_directory", type=click.Path(path_type=Path)) -results_tsv = click.argument("results_tsv", type=click.Path(path_type=Path)) -data_tsv = click.argument("data_tsv", type=click.Path(exists=True, path_type=Path)) -# ANALYSIS -merged_tsv = click.argument("merged_tsv", type=click.Path(exists=True, path_type=Path)) - -# TSV TOOLS -tsv_path = click.argument("tsv_path", type=click.Path(exists=True, path_type=Path)) -old_tsv_dir = click.argument( - "old_tsv_dir", type=click.Path(exists=True, path_type=Path) -) -new_tsv_dir = click.argument("new_tsv_dir", type=click.Path(path_type=Path)) -output_directory = click.argument("output_directory", type=click.Path(path_type=Path)) -dataset = click.argument("dataset", type=click.Choice(["AIBL", "OASIS"])) - - -# TRAIN -preprocessing_json = click.argument("preprocessing_json", type=str) - -modality_bids = click.argument( - "modality_bids", - type=click.Choice(["t1", "pet", "flair", "dwi", "custom"]), -) -tracer = click.argument( - "tracer", - type=str, -) -suvr_reference_region = 
-suvr_reference_region = click.argument(
-    "suvr_reference_region",
-    type=str,
-)
-generated_caps_directory = click.argument("generated_caps_directory", type=Path)
-
-data_group = click.argument("data_group", type=str)
-config_file = click.argument(
-    "config_file", type=click.Path(exists=True, path_type=Path)
-)
-preprocessing = click.argument(
-    "preprocessing", type=click.Choice(["t1", "pet", "flair", "dwi", "custom"])
-)
diff --git a/clinicadl/commandline/modules_options/callbacks.py b/clinicadl/commandline/modules_options/callbacks.py
deleted file mode 100644
index 71028ce0d..000000000
--- a/clinicadl/commandline/modules_options/callbacks.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import click
-
-from clinicadl.callbacks.config import CallbacksConfig
-from clinicadl.config.config_utils import get_default_from_config_class as get_default
-from clinicadl.config.config_utils import get_type_from_config_class as get_type
-
-emissions_calculator = click.option(
-    "--calculate_emissions/--dont_calculate_emissions",
-    default=get_default("emissions_calculator", CallbacksConfig),
-    help="Flag to enable computation of the carbon emissions during training.",
-    show_default=True,
-)
-track_exp = click.option(
-    "--track_exp",
-    "-te",
-    type=get_type("track_exp", CallbacksConfig),
-    default=get_default("track_exp", CallbacksConfig),
-    help="Use `--track_exp` to enable wandb/mlflow to track the metrics (loss, accuracy, etc.) during training.",
-    show_default=True,
-)
diff --git a/clinicadl/commandline/modules_options/computational.py b/clinicadl/commandline/modules_options/computational.py
deleted file mode 100644
index 5bc05e158..000000000
--- a/clinicadl/commandline/modules_options/computational.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import click
-
-from clinicadl.config.config_utils import get_default_from_config_class as get_default
-from clinicadl.utils.computational.computational import ComputationalConfig
-
-# Computational
-amp = click.option(
-    "--amp/--no-amp",
-    default=get_default("amp", ComputationalConfig),
-    help="Enables automatic mixed precision during training and inference.",
-    show_default=True,
-)
-fully_sharded_data_parallel = click.option(
-    "--fully_sharded_data_parallel",
-    "-fsdp",
-    is_flag=True,
-    help="Enables Fully Sharded Data Parallel with Pytorch to save memory at the cost of communications. "
-    "Currently this only enables ZeRO Stage 1, but it will be entirely replaced by FSDP in a later patch; "
-    "the flag is already named FSDP so that the ZeRO flag never actually has to be removed.",
-)
-gpu = click.option(
-    "--gpu/--no-gpu",
-    default=get_default("gpu", ComputationalConfig),
-    help="Use GPU by default. 
Please specify `--no-gpu` to force using CPU.", - show_default=True, -) diff --git a/clinicadl/commandline/modules_options/data.py b/clinicadl/commandline/modules_options/data.py deleted file mode 100644 index 9cf79e654..000000000 --- a/clinicadl/commandline/modules_options/data.py +++ /dev/null @@ -1,64 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.dataset.config.data import DataConfig - -# Data -baseline = click.option( - "--baseline/--longitudinal", - default=get_default("baseline", DataConfig), - help="If provided, only the baseline sessions are used for training.", - show_default=True, -) -diagnoses = click.option( - "--diagnoses", - "-d", - type=get_type("diagnoses", DataConfig), - default=get_default("diagnoses", DataConfig), - multiple=True, - help="List of diagnoses used for training.", - show_default=True, -) -multi_cohort = click.option( - "--multi_cohort/--single_cohort", - default=get_default("multi_cohort", DataConfig), - help="Performs multi-cohort training. In this case, caps_dir and tsv_path must be paths to TSV files.", - show_default=True, -) -participants_tsv = click.option( - "--participants_tsv", - type=get_type("data_tsv", DataConfig), - default=get_default("data_tsv", DataConfig), - help="Path to a TSV file including a list of participants/sessions.", - show_default=True, -) -n_subjects = click.option( - "--n_subjects", - type=get_type("n_subjects", DataConfig), - default=get_default("n_subjects", DataConfig), - help="Number of subjects in each class of the synthetic dataset.", -) -caps_directory = click.option( - "--caps_directory", - type=get_type("caps_directory", DataConfig), - default=get_default("caps_directory", DataConfig), - help="Data using CAPS structure, if different from the one used during network training.", - show_default=True, -) -label = click.option( - "--label", - type=get_type("label", DataConfig), - default=get_default("label", DataConfig), - show_default=True, - help="Target label used for training (if NETWORK_TASK in [`regression`, `classification`]). " - "Default will reuse the same label as during the training task.", -) -mask_path = click.option( - "--mask_path", - type=get_type("mask_path", DataConfig), - default=get_default("mask_path", DataConfig), - help="Path to the extracted masks to generate the two labels. 
" - "Default will try to download masks and store them at '~/.cache/clinicadl'.", - show_default=True, -) diff --git a/clinicadl/commandline/modules_options/dataloader.py b/clinicadl/commandline/modules_options/dataloader.py deleted file mode 100644 index bf4d4c781..000000000 --- a/clinicadl/commandline/modules_options/dataloader.py +++ /dev/null @@ -1,30 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.dataset.dataloader_config import DataLoaderConfig - -# DataLoader -batch_size = click.option( - "--batch_size", - type=get_type("batch_size", DataLoaderConfig), - default=get_default("batch_size", DataLoaderConfig), - help="Batch size for data loading.", - show_default=True, -) -n_proc = click.option( - "-np", - "--n_proc", - type=get_type("n_proc", DataLoaderConfig), - default=get_default("n_proc", DataLoaderConfig), - help="Number of cores used during the task.", - show_default=True, -) -sampler = click.option( - "--sampler", - "-s", - type=get_type("sampler", DataLoaderConfig), - default=get_default("sampler", DataLoaderConfig), - help="Sampler used to load the training data set.", - show_default=True, -) diff --git a/clinicadl/commandline/modules_options/early_stopping.py b/clinicadl/commandline/modules_options/early_stopping.py deleted file mode 100644 index 5fa3791e2..000000000 --- a/clinicadl/commandline/modules_options/early_stopping.py +++ /dev/null @@ -1,21 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.optim.early_stopping import EarlyStoppingConfig - -# Early Stopping -patience = click.option( - "--patience", - type=get_type("patience", EarlyStoppingConfig), - default=get_default("patience", EarlyStoppingConfig), - help="Number of epochs for early stopping patience.", - show_default=True, -) -tolerance = click.option( - "--tolerance", - type=get_type("min_delta", EarlyStoppingConfig), - default=get_default("min_delta", EarlyStoppingConfig), - help="Value for early stopping tolerance.", - show_default=True, -) diff --git a/clinicadl/commandline/modules_options/extraction.py b/clinicadl/commandline/modules_options/extraction.py deleted file mode 100644 index e382eecc2..000000000 --- a/clinicadl/commandline/modules_options/extraction.py +++ /dev/null @@ -1,109 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.dataset.config.extraction import ( - ExtractionConfig, - ExtractionImageConfig, - ExtractionPatchConfig, - ExtractionROIConfig, - ExtractionSliceConfig, -) - -extract_json = click.option( - "-ej", - "--extract_json", - type=get_type("extract_json", ExtractionConfig), - default=get_default("extract_json", ExtractionConfig), - help="Name of the JSON file created to describe the tensor extraction. " - "Default will use format extract_{time_stamp}.json", -) - - -save_features = click.option( - "--save_features", - is_flag=True, - help="""Extract the selected mode to save the tensor. 
By default, the pipeline only saves images, and the mode extraction
-    is done when images are loaded during training.""",
-)
-
-patch_size = click.option(
-    "-ps",
-    "--patch_size",
-    type=get_type("patch_size", ExtractionPatchConfig),
-    default=get_default("patch_size", ExtractionPatchConfig),
-    show_default=True,
-    help="Patch size.",
-)
-stride_size = click.option(
-    "-ss",
-    "--stride_size",
-    type=get_type("stride_size", ExtractionPatchConfig),
-    default=get_default("stride_size", ExtractionPatchConfig),
-    show_default=True,
-    help="Stride size.",
-)
-
-
-slice_direction = click.option(
-    "-sd",
-    "--slice_direction",
-    type=get_type("slice_direction", ExtractionSliceConfig),
-    default=get_default("slice_direction", ExtractionSliceConfig),
-    show_default=True,
-    help="Slice direction. 0: Sagittal plane, 1: Coronal plane, 2: Axial plane.",
-)
-slice_mode = click.option(
-    "-sm",
-    "--slice_mode",
-    type=get_type("slice_mode", ExtractionSliceConfig),
-    default=get_default("slice_mode", ExtractionSliceConfig),
-    show_default=True,
-    help=(
-        "rgb: Save the slice in three identical channels, "
-        "single: Save the slice in a single channel."
-    ),
-)
-discarded_slices = click.option(
-    "-ds",
-    "--discarded_slices",
-    type=get_type("discarded_slices", ExtractionSliceConfig),
-    default=get_default("discarded_slices", ExtractionSliceConfig),
-    multiple=2,
-    help="""Number of slices discarded from respectively the beginning and
-    the end of the MRI volume. If only one argument is given, it will be
-    used for both sides.""",
-)
-roi_list = click.option(
-    "--roi_list",
-    type=get_type("roi_list", ExtractionROIConfig),
-    default=get_default("roi_list", ExtractionROIConfig),
-    required=True,
-    multiple=True,
-    help="List of regions to be extracted.",
-)
-roi_uncrop_output = click.option(
-    "--roi_uncrop_output",
-    type=get_type("roi_uncrop_output", ExtractionROIConfig),
-    default=get_default("roi_uncrop_output", ExtractionROIConfig),
-    is_flag=True,
-    help="Disable cropping option so the output tensors "
-    "have the same size as the whole image.",
-)
-roi_custom_template = click.option(
-    "--roi_custom_template",
-    "-ct",
-    type=get_type("roi_custom_template", ExtractionROIConfig),
-    default=get_default("roi_custom_template", ExtractionROIConfig),
-    help="""Template name if MODALITY is `custom`.
-    Name of the template used for registration during the preprocessing procedure.""",
-)
-roi_custom_mask_pattern = click.option(
-    "--roi_custom_mask_pattern",
-    "-cmp",
-    type=get_type("roi_custom_mask_pattern", ExtractionROIConfig),
-    default=get_default("roi_custom_mask_pattern", ExtractionROIConfig),
-    help="""Mask pattern if MODALITY is `custom`.
-    If given, will select only the masks containing the given string. 
- The mask with the shortest name is taken.""", -) diff --git a/clinicadl/commandline/modules_options/lr_scheduler.py b/clinicadl/commandline/modules_options/lr_scheduler.py deleted file mode 100644 index 184ce3deb..000000000 --- a/clinicadl/commandline/modules_options/lr_scheduler.py +++ /dev/null @@ -1,9 +0,0 @@ -import click - -# LR scheduler -adaptive_learning_rate = click.option( - "--adaptive_learning_rate", - "-alr", - is_flag=True, - help="Whether to diminish the learning rate", -) diff --git a/clinicadl/commandline/modules_options/maps_manager.py b/clinicadl/commandline/modules_options/maps_manager.py deleted file mode 100644 index 69574d42c..000000000 --- a/clinicadl/commandline/modules_options/maps_manager.py +++ /dev/null @@ -1,21 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.experiment_manager.config import MapsManagerConfig - -maps_dir = click.argument("maps_dir", type=get_type("maps_dir", MapsManagerConfig)) -data_group = click.option("data_group", type=get_type("data_group", MapsManagerConfig)) - - -overwrite = click.option( - "--overwrite", - "-o", - is_flag=True, - help="Will overwrite data group if existing. Please give caps_directory and participants_tsv to" - " define new data group.", -) -save_nifti = click.option( - "--save_nifti", - is_flag=True, - help="Save the output map(s) in the MAPS in NIfTI format.", -) diff --git a/clinicadl/commandline/modules_options/network.py b/clinicadl/commandline/modules_options/network.py deleted file mode 100644 index c0b8716e1..000000000 --- a/clinicadl/commandline/modules_options/network.py +++ /dev/null @@ -1,20 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.networks.old_network.config import NetworkConfig - -# Model -multi_network = click.option( - "--multi_network/--single_network", - default=get_default("multi_network", NetworkConfig), - help="If provided uses a multi-network framework.", - show_default=True, -) -dropout = click.option( - "--dropout", - type=get_type("dropout", NetworkConfig), - default=get_default("dropout", NetworkConfig), - help="Rate value applied to dropout layers in a CNN architecture.", - show_default=True, -) diff --git a/clinicadl/commandline/modules_options/optimization.py b/clinicadl/commandline/modules_options/optimization.py deleted file mode 100644 index 1f1eb6296..000000000 --- a/clinicadl/commandline/modules_options/optimization.py +++ /dev/null @@ -1,30 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.optim.config import OptimizationConfig - -# Optimization -accumulation_steps = click.option( - "--accumulation_steps", - "-asteps", - type=get_type("accumulation_steps", OptimizationConfig), - default=get_default("accumulation_steps", OptimizationConfig), - help="Accumulates gradients during the given number of iterations before performing the weight update " - "in order to virtually increase the size of the batch.", - show_default=True, -) -epochs = click.option( - "--epochs", - type=get_type("epochs", OptimizationConfig), - default=get_default("epochs", OptimizationConfig), - help="Maximum number of epochs.", - show_default=True, -) -profiler = click.option( - "--profiler/--no-profiler", - 
default=get_default("profiler", OptimizationConfig), - help="Use `--profiler` to enable Pytorch profiler for the first 30 steps after a short warmup. " - "It will make an execution trace and some statistics about the CPU and GPU usage.", - show_default=True, -) diff --git a/clinicadl/commandline/modules_options/optimizer.py b/clinicadl/commandline/modules_options/optimizer.py deleted file mode 100644 index b8a128436..000000000 --- a/clinicadl/commandline/modules_options/optimizer.py +++ /dev/null @@ -1,30 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.optim.optimizers import OptimizerConfig - -# Optimizer -learning_rate = click.option( - "--learning_rate", - "-lr", - type=get_type("learning_rate", OptimizerConfig), - default=get_default("learning_rate", OptimizerConfig), - help="Learning rate of the optimization.", - show_default=True, -) -optimizer = click.option( - "--optimizer", - type=get_type("optimizer", OptimizerConfig), - default=get_default("optimizer", OptimizerConfig), - help="Optimizer used to train the network.", - show_default=True, -) -weight_decay = click.option( - "--weight_decay", - "-wd", - type=get_type("weight_decay", OptimizerConfig), - default=get_default("weight_decay", OptimizerConfig), - help="Weight decay value used in optimization.", - show_default=True, -) diff --git a/clinicadl/commandline/modules_options/preprocessing.py b/clinicadl/commandline/modules_options/preprocessing.py deleted file mode 100644 index 2d8f05ee2..000000000 --- a/clinicadl/commandline/modules_options/preprocessing.py +++ /dev/null @@ -1,73 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.dataset.preprocessing import ( - BasePreprocessing, - PreprocessingCustom, - PreprocessingDTI, - PreprocessingPET, -) - -tracer = click.option( - "--tracer", - default=get_default("tracer", PreprocessingPET), - type=get_type("tracer", PreprocessingPET), - help=( - "Acquisition label if MODALITY is `pet-linear`. " - "Name of the tracer used for the PET acquisition (trc-). " - "For instance it can be '18FFDG' for fluorodeoxyglucose or '18FAV45' for florbetapir." - ), -) -suvr_reference_region = click.option( - "-suvr", - "--suvr_reference_region", - default=get_default("suvr_reference_region", PreprocessingPET), - type=get_type("suvr_reference_region", PreprocessingPET), - help=( - "Regions used for normalization if MODALITY is `pet-linear`. " - "Intensity normalization using the average PET uptake in reference regions resulting in a standardized uptake " - "value ratio (SUVR) map. It can be cerebellumPons or cerebellumPon2 (used for amyloid tracers) or pons or " - "pons2 (used for 18F-FDG tracers)." - ), -) -custom_suffix = click.option( - "-cn", - "--custom_suffix", - default=get_default("custom_suffix", PreprocessingCustom), - type=get_type("custom_suffix", PreprocessingCustom), - help=( - "Suffix of output files if MODALITY is `custom`. 
" - "Suffix to append to filenames, for instance " - "`graymatter_space-Ixi549Space_modulated-off_probability.nii.gz`, or " - "`segm-whitematter_probability.nii.gz`" - ), -) -dti_measure = click.option( - "--dti_measure", - "-dm", - type=get_type("dti_measure", PreprocessingDTI), - help="Possible DTI measures.", - default=get_default("dti_measure", PreprocessingDTI), -) -dti_space = click.option( - "--dti_space", - "-ds", - type=get_type("dti_space", PreprocessingDTI), - help="Possible DTI space.", - default=get_default("dti_space", PreprocessingDTI), -) -preprocessing = click.option( - "--preprocessing", - type=get_type("preprocessing", BasePreprocessing), - default=get_default("preprocessing", BasePreprocessing), - required=True, - help="Extraction used to generate synthetic data.", - show_default=True, -) -use_uncropped_image = click.option( - "-uui", - "--use_uncropped_image", - is_flag=True, - help="Use the uncropped image instead of the cropped image generated by t1-linear or pet-linear.", -) diff --git a/clinicadl/commandline/modules_options/reproducibility.py b/clinicadl/commandline/modules_options/reproducibility.py deleted file mode 100644 index b161c0c73..000000000 --- a/clinicadl/commandline/modules_options/reproducibility.py +++ /dev/null @@ -1,43 +0,0 @@ -import click - -from clinicadl.config.config.reproducibility import ReproducibilityConfig -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type - -# Reproducibility -compensation = click.option( - "--compensation", - type=get_type("compensation", ReproducibilityConfig), - default=get_default("compensation", ReproducibilityConfig), - help="Allow the user to choose how CUDA will compensate the deterministic behaviour.", - show_default=True, -) -deterministic = click.option( - "--deterministic/--nondeterministic", - default=get_default("deterministic", ReproducibilityConfig), - help="Forces Pytorch to be deterministic even when using a GPU. " - "Will raise a RuntimeError if a non-deterministic function is encountered.", - show_default=True, -) -save_all_models = click.option( - "--save_all_models/--save_only_best_model", - type=get_type("save_all_models", ReproducibilityConfig), - default=get_default("save_all_models", ReproducibilityConfig), - help="If provided, enables the saving of models weights for each epochs.", - show_default=True, -) -seed = click.option( - "--seed", - type=get_type("seed", ReproducibilityConfig), - default=get_default("seed", ReproducibilityConfig), - help="Value to set the seed for all random operations." 
-    "Default will sample a random value for the seed.",
-    show_default=True,
-)
-config_file = click.option(
-    "--config_file",
-    "-c",
-    type=get_type("config_file", ReproducibilityConfig),
-    default=get_default("config_file", ReproducibilityConfig),
-    help="Path to the TOML or JSON file containing the values of the options needed for training.",
-)
diff --git a/clinicadl/commandline/modules_options/split.py b/clinicadl/commandline/modules_options/split.py
deleted file mode 100644
index 4579a6f7f..000000000
--- a/clinicadl/commandline/modules_options/split.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import click
-
-from clinicadl.config.config_utils import get_default_from_config_class as get_default
-from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.splitter.splitter.kfold import KFoldConfig
-from clinicadl.splitter.splitter.single_split import SingleSplitConfig
-
-# Cross Validation
-n_splits = click.option(
-    "--n_splits",
-    type=get_type("n_splits", KFoldConfig),
-    default=get_default("n_splits", KFoldConfig),
-    help="If a value k is given, the data of a k-fold CV will be loaded. "
-    "Default value (0) will load a single split.",
-    show_default=True,
-)
-split = click.option(
-    "--split",
-    "-s",
-    type=int,  # get_type("split", config.ValidationConfig)
-    default=get_default("split", SingleSplitConfig),
-    multiple=True,
-    help="Train the list of given splits. By default, all the splits are trained.",
-    show_default=True,
-)
diff --git a/clinicadl/commandline/modules_options/transforms.py b/clinicadl/commandline/modules_options/transforms.py
deleted file mode 100644
index 7cdc3531d..000000000
--- a/clinicadl/commandline/modules_options/transforms.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import click
-
-from clinicadl.config.config_utils import get_default_from_config_class as get_default
-from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.transforms.config import TransformsConfig
-
-# Transform
-data_augmentation = click.option(
-    "--data_augmentation",
-    "-da",
-    type=get_type("data_augmentation", TransformsConfig),
-    default=get_default("data_augmentation", TransformsConfig),
-    multiple=True,
-    help="Randomly applies transforms on the training set.",
-    show_default=True,
-)
-normalize = click.option(
-    "--normalize/--unnormalize",
-    default=get_default("normalize", TransformsConfig),
-    help="Use `--unnormalize` to disable the default MinMaxNormalization.",
-    show_default=True,
-)
diff --git a/clinicadl/commandline/modules_options/validation.py b/clinicadl/commandline/modules_options/validation.py
deleted file mode 100644
index 089357866..000000000
--- a/clinicadl/commandline/modules_options/validation.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import click
-
-from clinicadl.config.config_utils import get_default_from_config_class as get_default
-from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.predictor.validation import ValidationConfig
-
-# Validation
-valid_longitudinal = click.option(
-    "--valid_longitudinal/--valid_baseline",
-    default=get_default("valid_longitudinal", ValidationConfig),
-    help="If provided, not only the baseline sessions are used for validation (careful with this bad habit).",
-    show_default=True,
-)
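-# (note: evaluation_steps=0, the default, is interpreted as a single
-# evaluation at the end of each epoch, as described in the help text below)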
-evaluation_steps = click.option(
-    "--evaluation_steps",
-    "-esteps",
-    type=get_type("evaluation_steps", ValidationConfig),
-    default=get_default("evaluation_steps", ValidationConfig),
-    help="Fix the number of iterations to perform before computing an evaluation. Default will only "
-    "perform one evaluation at the end of each epoch.",
-    show_default=True,
-)
-
-selection_metrics = click.option(
-    "--selection_metrics",
-    "-sm",
-    type=get_type("selection_metrics", ValidationConfig),  # str list ?
-    default=get_default("selection_metrics", ValidationConfig),  # ["loss"]
-    multiple=True,
-    help="""Allows selecting a list of models based on their selection metric. Default will
-    only infer the result of the best model selected on loss.""",
-    show_default=True,
-)
-skip_leak_check = click.option(
-    "--skip_leak_check",
-    "-slc",
-    is_flag=True,
-    help="""Allows skipping the data leakage check usually performed. Not recommended.""",
-)
diff --git a/clinicadl/commandline/pipelines/generate/artifacts/cli.py b/clinicadl/commandline/pipelines/generate/artifacts/cli.py
index cb25a26ff..c25107e0d 100644
--- a/clinicadl/commandline/pipelines/generate/artifacts/cli.py
+++ b/clinicadl/commandline/pipelines/generate/artifacts/cli.py
@@ -13,8 +13,8 @@
     preprocessing,
 )
 from clinicadl.commandline.pipelines.generate.artifacts import options as artifacts
-from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig
-from clinicadl.dataset.caps_dataset_utils import find_file_type
+from clinicadl.data.caps_dataset_config import CapsDatasetConfig
+from clinicadl.data.caps_dataset_utils import find_file_type
 from clinicadl.generate.generate_config import GenerateArtifactsConfig
 from clinicadl.generate.generate_utils import (
     load_and_check_tsv,
diff --git a/clinicadl/commandline/pipelines/generate/artifacts/options.py b/clinicadl/commandline/pipelines/generate/artifacts/options.py
deleted file mode 100644
index 38fcd7d91..000000000
--- a/clinicadl/commandline/pipelines/generate/artifacts/options.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import click
-
-from clinicadl.config.config_utils import get_default_from_config_class as get_default
-from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.generate.generate_config import GenerateArtifactsConfig
-
-contrast = click.option(
-    "--contrast/--no-contrast",
-    default=get_default("contrast", GenerateArtifactsConfig),
-    help="",
-    show_default=True,
-)
-gamma = click.option(
-    "--gamma",
-    multiple=2,
-    type=get_type("gamma", GenerateArtifactsConfig),
-    default=get_default("gamma", GenerateArtifactsConfig),
-    help="Range between -1 and 1 for gamma augmentation.",
-    show_default=True,
-)
-# Motion
-motion = click.option(
-    "--motion/--no-motion",
-    default=get_default("motion", GenerateArtifactsConfig),
-    help="",
-    show_default=True,
-)
-translation = click.option(
-    "--translation",
-    multiple=2,
-    type=get_type("translation", GenerateArtifactsConfig),
-    default=get_default("translation", GenerateArtifactsConfig),
-    help="Range in mm for the translation.",
-    show_default=True,
-)
-rotation = click.option(
-    "--rotation",
-    multiple=2,
-    type=get_type("rotation", GenerateArtifactsConfig),
-    default=get_default("rotation", GenerateArtifactsConfig),
-    help="Range in degrees for the rotation.",
-    show_default=True,
-)
-num_transforms = click.option(
-    "--num_transforms",
-    type=get_type("num_transforms", GenerateArtifactsConfig),
-    default=get_default("num_transforms", GenerateArtifactsConfig),
-    help="Number of transforms.",
-    show_default=True,
-)
-# Noise
-noise = click.option(
-    "--noise/--no-noise",
-    default=get_default("noise", GenerateArtifactsConfig),
-    help="",
-    show_default=True,
-)
-noise_std = click.option(
-    "--noise_std",
-    multiple=2,
-    type=get_type("noise_std", GenerateArtifactsConfig),
-    default=get_default("noise_std", GenerateArtifactsConfig),
-    help="Range for the noise standard deviation.",
-    show_default=True,
-)
diff --git a/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py b/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py
index 3216d4ff1..d81ad9215 100644
--- a/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py
+++ b/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py
@@ -12,8 +12,8 @@
 from clinicadl.commandline.pipelines.generate.hypometabolic import (
     options as hypometabolic,
 )
-from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig
-from clinicadl.dataset.caps_dataset_utils import find_file_type
+from clinicadl.data.caps_dataset_config import CapsDatasetConfig
+from clinicadl.data.caps_dataset_utils import find_file_type
 from clinicadl.generate.generate_config import GenerateHypometabolicConfig
 from clinicadl.generate.generate_utils import (
     load_and_check_tsv,
diff --git a/clinicadl/commandline/pipelines/generate/hypometabolic/options.py b/clinicadl/commandline/pipelines/generate/hypometabolic/options.py
deleted file mode 100644
index aab2165e7..000000000
--- a/clinicadl/commandline/pipelines/generate/hypometabolic/options.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import click
-
-from clinicadl.config.config_utils import get_default_from_config_class as get_default
-from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.generate.generate_config import GenerateHypometabolicConfig
-
-pathology = click.option(
-    "--pathology",
-    "-p",
-    type=get_type("pathology", GenerateHypometabolicConfig),
-    default=get_default("pathology", GenerateHypometabolicConfig),
-    help="Pathology applied. To choose from the following list: [ad, bvftd, lvppa, nfvppa, pca, svppa].",
-    show_default=True,
-)
-anomaly_degree = click.option(
-    "--anomaly_degree",
-    "-anod",
-    type=get_type("anomaly_degree", GenerateHypometabolicConfig),
-    default=get_default("anomaly_degree", GenerateHypometabolicConfig),
-    help="Degree of hypo-metabolism applied (in percent).",
-    show_default=True,
-)
-sigma = click.option(
-    "--sigma",
-    type=get_type("sigma", GenerateHypometabolicConfig),
-    default=get_default("sigma", GenerateHypometabolicConfig),
-    help="Parameter of the Gaussian filter used for smoothing.",
-    show_default=True,
-)
diff --git a/clinicadl/commandline/pipelines/generate/random/cli.py b/clinicadl/commandline/pipelines/generate/random/cli.py
index 7268eb083..892999e13 100644
--- a/clinicadl/commandline/pipelines/generate/random/cli.py
+++ b/clinicadl/commandline/pipelines/generate/random/cli.py
@@ -14,8 +14,8 @@
     preprocessing,
 )
 from clinicadl.commandline.pipelines.generate.random import options as random
-from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig
-from clinicadl.dataset.caps_dataset_utils import find_file_type
+from clinicadl.data.caps_dataset_config import CapsDatasetConfig
+from clinicadl.data.caps_dataset_utils import find_file_type
 from clinicadl.generate.generate_config import GenerateRandomConfig
 from clinicadl.generate.generate_utils import (
     load_and_check_tsv,
diff --git a/clinicadl/commandline/pipelines/generate/random/options.py b/clinicadl/commandline/pipelines/generate/random/options.py
deleted file mode 100644
index 28a5b24c8..000000000
--- a/clinicadl/commandline/pipelines/generate/random/options.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import click
-
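-# (get_default/get_type read the default value and the type annotation off the
-# given config class, so the CLI option stays in sync with the configuration;
-# this is an assumption based on the helper names used throughout these modules)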
-from clinicadl.config.config_utils import get_default_from_config_class as get_default
-from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.generate.generate_config import GenerateRandomConfig
-
-mean = click.option(
-    "--mean",
-    type=get_type("mean", GenerateRandomConfig),
-    default=get_default("mean", GenerateRandomConfig),
-    help="Mean value of the Gaussian noise added to synthetic images.",
-    show_default=True,
-)
-sigma = click.option(
-    "--sigma",
-    type=get_type("sigma", GenerateRandomConfig),
-    default=get_default("sigma", GenerateRandomConfig),
-    help="Standard deviation of the Gaussian noise added to synthetic images.",
-    show_default=True,
-)
diff --git a/clinicadl/commandline/pipelines/generate/shepplogan/options.py b/clinicadl/commandline/pipelines/generate/shepplogan/options.py
deleted file mode 100644
index fc15c2009..000000000
--- a/clinicadl/commandline/pipelines/generate/shepplogan/options.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import click
-
-from clinicadl.config.config_utils import get_default_from_config_class as get_default
-from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.generate.generate_config import GenerateSheppLoganConfig
-
-extract_json = click.option(
-    "-ej",
-    "--extract_json",
-    type=get_type("extract_json", GenerateSheppLoganConfig),
-    default=get_default("extract_json", GenerateSheppLoganConfig),
-    help="Name of the JSON file created to describe the tensor extraction. "
-    "Default will use format extract_{time_stamp}.json",
-    show_default=True,
-)
-
-image_size = click.option(
-    "--image_size",
-    help="Size in pixels of the squared images.",
-    type=get_type("image_size", GenerateSheppLoganConfig),
-    default=get_default("image_size", GenerateSheppLoganConfig),
-    show_default=True,
-)
-
-cn_subtypes_distribution = click.option(
-    "--cn_subtypes_distribution",
-    "-csd",
-    multiple=3,
-    type=get_type("cn_subtypes_distribution", GenerateSheppLoganConfig),
-    default=get_default("cn_subtypes_distribution", GenerateSheppLoganConfig),
-    help="Probability of each subtype to be drawn in CN label.",
-    show_default=True,
-)
-
-ad_subtypes_distribution = click.option(
-    "--ad_subtypes_distribution",
-    "-asd",
-    multiple=3,
-    type=get_type("ad_subtypes_distribution", GenerateSheppLoganConfig),
-    default=get_default("ad_subtypes_distribution", GenerateSheppLoganConfig),
-    help="Probability of each subtype to be drawn in AD label.",
-    show_default=True,
-)
-
-smoothing = click.option(
-    "--smoothing/--no-smoothing",
-    default=get_default("smoothing", GenerateSheppLoganConfig),
-    help="Adds random smoothing to generated data.",
-    show_default=True,
-)
diff --git a/clinicadl/commandline/pipelines/generate/trivial/cli.py b/clinicadl/commandline/pipelines/generate/trivial/cli.py
index b8378ec17..ec9a53cb8 100644
--- a/clinicadl/commandline/pipelines/generate/trivial/cli.py
+++ b/clinicadl/commandline/pipelines/generate/trivial/cli.py
@@ -13,8 +13,8 @@
     preprocessing,
 )
 from clinicadl.commandline.pipelines.generate.trivial import options as trivial
-from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig
-from clinicadl.dataset.caps_dataset_utils import find_file_type
+from clinicadl.data.caps_dataset_config import CapsDatasetConfig
+from clinicadl.data.caps_dataset_utils import find_file_type
 from clinicadl.generate.generate_config import GenerateTrivialConfig
 from clinicadl.generate.generate_utils import (
     im_loss_roi_gaussian_distribution,
diff --git a/clinicadl/commandline/pipelines/generate/trivial/options.py 
b/clinicadl/commandline/pipelines/generate/trivial/options.py deleted file mode 100644 index b2e5c018b..000000000 --- a/clinicadl/commandline/pipelines/generate/trivial/options.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import get_args - -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.generate.generate_config import GenerateTrivialConfig - -atrophy_percent = click.option( - "--atrophy_percent", - type=get_type("atrophy_percent", GenerateTrivialConfig), - default=get_default("atrophy_percent", GenerateTrivialConfig), - help="Percentage of atrophy applied.", - show_default=True, -) diff --git a/clinicadl/commandline/pipelines/interpret/__init__.py b/clinicadl/commandline/pipelines/interpret/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/clinicadl/commandline/pipelines/interpret/cli.py b/clinicadl/commandline/pipelines/interpret/cli.py deleted file mode 100644 index 9f4fb8a87..000000000 --- a/clinicadl/commandline/pipelines/interpret/cli.py +++ /dev/null @@ -1,56 +0,0 @@ -from pathlib import Path - -import click - -from clinicadl.commandline import arguments -from clinicadl.commandline.modules_options import ( - computational, - data, - dataloader, - maps_manager, - validation, -) -from clinicadl.commandline.pipelines.interpret import options -from clinicadl.interpret.config import InterpretConfig -from clinicadl.predictor.old_predictor import Predictor - - -@click.command("interpret", no_args_is_help=True) -@arguments.input_maps -@arguments.data_group -@maps_manager.overwrite -@maps_manager.save_nifti -@options.name -@options.method -@options.level -@options.target_node -@options.save_individual -@options.overwrite_name -@data.participants_tsv -@data.caps_directory -@data.multi_cohort -@data.diagnoses -@dataloader.n_proc -@dataloader.batch_size -@computational.gpu -@computational.amp -@validation.selection_metrics -def cli(**kwargs): - """Interpretation of trained models using saliency map method. - INPUT_MAPS_DIRECTORY is the MAPS folder from where the model to interpret will be loaded. - DATA_GROUP is the name of the subjects and sessions list used for the interpretation. - NAME is the name of the saliency map task. - METHOD is the method used to extract an attribution map. 
- """ - from clinicadl.utils.iotools.train_utils import merge_cli_and_maps_json_options - - dict_ = merge_cli_and_maps_json_options( - Path(kwargs["input_maps"]) / "maps.json", **kwargs - ) - interpret_config = InterpretConfig(**dict_) - predict_manager = Predictor(interpret_config) - predict_manager.interpret() - - -if __name__ == "__main__": - cli() diff --git a/clinicadl/commandline/pipelines/interpret/options.py b/clinicadl/commandline/pipelines/interpret/options.py deleted file mode 100644 index 43cada4c4..000000000 --- a/clinicadl/commandline/pipelines/interpret/options.py +++ /dev/null @@ -1,40 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.interpret.config import InterpretBaseConfig - -# interpret specific -name = click.argument( - "name", - type=get_type("name", InterpretBaseConfig), -) -method = click.argument( - "method", - type=get_type("method", InterpretBaseConfig), # ["gradients", "grad-cam"] -) -level = click.option( - "--level_grad_cam", - type=get_type("level", InterpretBaseConfig), - default=get_default("level", InterpretBaseConfig), - help="level of the feature map (after the layer corresponding to the number) chosen for grad-cam.", - show_default=True, -) -target_node = click.option( - "--target_node", - type=get_type("target_node", InterpretBaseConfig), - default=get_default("target_node", InterpretBaseConfig), - help="Which target node the gradients explain. Default takes the first output node.", - show_default=True, -) -save_individual = click.option( - "--save_individual", - is_flag=True, - help="Save individual saliency maps in addition to the mean saliency map.", -) -overwrite_name = click.option( - "--overwrite_name", - "-on", - is_flag=True, - help="Overwrite the name if it already exists.", -) diff --git a/clinicadl/commandline/pipelines/predict/__init__.py b/clinicadl/commandline/pipelines/predict/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/clinicadl/commandline/pipelines/predict/cli.py b/clinicadl/commandline/pipelines/predict/cli.py deleted file mode 100644 index 119c12678..000000000 --- a/clinicadl/commandline/pipelines/predict/cli.py +++ /dev/null @@ -1,69 +0,0 @@ -import click - -from clinicadl.commandline import arguments -from clinicadl.commandline.modules_options import ( - computational, - data, - dataloader, - maps_manager, - split, - validation, -) -from clinicadl.commandline.pipelines.predict import options -from clinicadl.predictor.config import PredictConfig -from clinicadl.predictor.old_predictor import Predictor - - -@click.command(name="predict", no_args_is_help=True) -@arguments.input_maps -@arguments.data_group -@maps_manager.save_nifti -@maps_manager.overwrite -@options.use_labels -@data.label -@options.save_tensor -@options.save_latent_tensor -@data.caps_directory -@data.participants_tsv -@data.multi_cohort -@data.diagnoses -@validation.skip_leak_check -@validation.selection_metrics -@split.split -@computational.gpu -@computational.amp -@dataloader.n_proc -@dataloader.batch_size -def cli(input_maps_directory, data_group, **kwargs): - """This function loads a MAPS and predicts the global metrics and individual values - for all the models selected using a metric in selection_metrics. - Args: - maps_dir: path to the MAPS. - data_group: name of the data group tested. - caps_directory: path to the CAPS folder. 
For more information please refer to - [clinica documentation](https://aramislab.paris.inria.fr/clinica/docs/public/latest/CAPS/Introduction/). - tsv_path: path to a TSV file containing the list of participants and sessions to interpret. - use_labels: by default is True. If False no metrics tsv files will be written. - label: Name of the target value, if different from training. - gpu: if true, it uses gpu. - amp: If enabled, uses Automatic Mixed Precision (requires GPU usage). - n_proc: num_workers used in DataLoader - batch_size: batch size of the DataLoader - selection_metrics: list of metrics to find best models to be evaluated. - diagnoses: list of diagnoses to be tested if tsv_path is a folder. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. - overwrite: If True former definition of data group is erased - save_tensor: For reconstruction task only, if True it will save the reconstruction as .pt file in the MAPS. - save_nifti: For reconstruction task only, if True it will save the reconstruction as NIfTI file in the MAPS. - Infer the outputs of a trained model on a test set. - INPUT_MAPS_DIRECTORY is the MAPS folder from where the model used for prediction will be loaded. - DATA_GROUP is the name of the subjects and sessions list used for the interpretation. - """ - - predict_config = PredictConfig(**kwargs) - predict_manager = Predictor(predict_config) - predict_manager.predict() - - -if __name__ == "__main__": - cli() diff --git a/clinicadl/commandline/pipelines/predict/options.py b/clinicadl/commandline/pipelines/predict/options.py deleted file mode 100644 index cbb8980ca..000000000 --- a/clinicadl/commandline/pipelines/predict/options.py +++ /dev/null @@ -1,20 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.predictor.config import PredictConfig - -# predict specific -use_labels = click.option( - "--use_labels/--no_labels", - help="Set this option to --no_labels if your dataset does not contain ground truth labels.", -) -save_tensor = click.option( - "--save_tensor", - is_flag=True, - help="Save the reconstruction output in the MAPS in Pytorch tensor format.", -) -save_latent_tensor = click.option( - "--save_latent_tensor", - is_flag=True, - help="""Save the latent representation of the image.""", -) diff --git a/clinicadl/commandline/pipelines/prepare_data/__init__.py b/clinicadl/commandline/pipelines/prepare_data/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py b/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py deleted file mode 100644 index d162dcf97..000000000 --- a/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py +++ /dev/null @@ -1,162 +0,0 @@ -import click - -from clinicadl.commandline import arguments -from clinicadl.commandline.modules_options import ( - data, - dataloader, - extraction, - preprocessing, -) -from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.prepare_data.prepare_data import DeepLearningPrepareData -from clinicadl.utils.enum import ExtractionMethod - - -@click.command(name="image", no_args_is_help=True) -@arguments.caps_directory -@arguments.preprocessing -@dataloader.n_proc -@data.participants_tsv -@extraction.extract_json -@preprocessing.use_uncropped_image -@preprocessing.tracer -@preprocessing.suvr_reference_region -@preprocessing.custom_suffix 
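-# (modality-specific options: tracer and suvr_reference_region apply to
-# pet-linear, custom_suffix to custom, and the dti_* options presumably to dwi,
-# judging from the help texts of the preprocessing options module)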
-@preprocessing.dti_measure -@preprocessing.dti_space -def image_cli(**kwargs): - """Extract image from nifti images. - - CAPS_DIRECTORY is the CAPS folder where nifti images are stored and tensor will be saved. - - PREPROCESSING [t1-linear|pet-linear|custom] is the clinica pipeline name used for image preprocessing. - """ - kwargs["save_features"] = True - image_config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - extraction=ExtractionMethod.IMAGE, - preprocessing_type=kwargs["preprocessing"], - **kwargs, - ) - - DeepLearningPrepareData(image_config) - - -@click.command(name="patch", no_args_is_help=True) -@arguments.caps_directory -@arguments.preprocessing -@dataloader.n_proc -@extraction.save_features -@data.participants_tsv -@extraction.extract_json -@preprocessing.use_uncropped_image -@preprocessing.tracer -@preprocessing.suvr_reference_region -@preprocessing.custom_suffix -@preprocessing.dti_measure -@preprocessing.dti_space -@extraction.patch_size -@extraction.stride_size -def patch_cli(**kwargs): - """Extract patch from nifti images. - - CAPS_DIRECTORY is the CAPS folder where nifti images are stored and tensor will be saved. - - PREPROCESSING [t1-linear|pet-linear|custom] is the clinica pipeline name used for image preprocessing. - """ - - patch_config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - extraction=ExtractionMethod.PATCH, - preprocessing_type=kwargs["preprocessing"], - **kwargs, - ) - - DeepLearningPrepareData(patch_config) - - -@click.command(name="slice", no_args_is_help=True) -@arguments.caps_directory -@arguments.preprocessing -@dataloader.n_proc -@extraction.save_features -@data.participants_tsv -@extraction.extract_json -@preprocessing.use_uncropped_image -@preprocessing.tracer -@preprocessing.suvr_reference_region -@preprocessing.custom_suffix -@preprocessing.dti_measure -@preprocessing.dti_space -@extraction.slice_mode -@extraction.slice_direction -@extraction.discarded_slices -def slice_cli(**kwargs): - """Extract slice from nifti images. - - CAPS_DIRECTORY is the CAPS folder where nifti images are stored and tensor will be saved. - - PREPROCESSING [t1-linear|pet-linear|custom] is the clinica pipeline name used for image preprocessing. - """ - slice_config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - extraction=ExtractionMethod.SLICE, - preprocessing_type=kwargs["preprocessing"], - **kwargs, - ) - - DeepLearningPrepareData(slice_config) - - -@click.command(name="roi", no_args_is_help=True) -@arguments.caps_directory -@arguments.preprocessing -@dataloader.n_proc -@extraction.save_features -@data.participants_tsv -@extraction.extract_json -@preprocessing.use_uncropped_image -@preprocessing.tracer -@preprocessing.suvr_reference_region -@preprocessing.custom_suffix -@preprocessing.dti_measure -@preprocessing.dti_space -@extraction.roi_list -@extraction.roi_uncrop_output -@extraction.roi_custom_template -@extraction.roi_custom_mask_pattern -def roi_cli(**kwargs): - """Extract roi from nifti images. - - CAPS_DIRECTORY is the CAPS folder where nifti images are stored and tensor will be saved. - - PREPROCESSING [t1-linear|pet-linear|custom] is the clinica pipeline name used for image preprocessing. 
-    """
-
-    roi_config = CapsDatasetConfig.from_preprocessing_and_extraction_method(
-        extraction=ExtractionMethod.ROI,
-        preprocessing_type=kwargs["preprocessing"],
-        **kwargs,
-    )
-
-    DeepLearningPrepareData(roi_config)
-
-
-class RegistrationOrderGroup(click.Group):
-    """CLI group which lists commands by order of registration."""
-
-    def list_commands(self, ctx):
-        return self.commands.keys()
-
-
-@click.group(cls=RegistrationOrderGroup, name="prepare-data", no_args_is_help=True)
-def cli() -> None:
-    """Extract Pytorch tensors from nifti images."""
-    pass
-
-
-cli.add_command(image_cli)
-cli.add_command(slice_cli)
-cli.add_command(patch_cli)
-cli.add_command(roi_cli)
-
-
-if __name__ == "__main__":
-    cli()
diff --git a/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py b/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py
deleted file mode 100644
index f4f888a71..000000000
--- a/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import click
-
-from clinicadl.commandline import arguments
-from clinicadl.commandline.modules_options import (
-    data,
-    dataloader,
-    extraction,
-    preprocessing,
-)
-from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig
-from clinicadl.dataset.prepare_data.prepare_data import DeepLearningPrepareData
-from clinicadl.utils.enum import ExtractionMethod
-
-
-@click.command(name="image", no_args_is_help=True)
-@arguments.bids_directory
-@arguments.caps_directory
-@arguments.modality_bids
-@dataloader.n_proc
-@extraction.extract_json
-@preprocessing.use_uncropped_image
-@preprocessing.tracer
-@preprocessing.suvr_reference_region
-@preprocessing.custom_suffix
-@data.participants_tsv
-def image_bids_cli(**kwargs):
-    image_config = CapsDatasetConfig.from_preprocessing_and_extraction_method(
-        extraction=ExtractionMethod.IMAGE,
-        preprocessing_type=kwargs["preprocessing"],
-        **kwargs,
-    )
-
-    DeepLearningPrepareData(image_config, from_bids=kwargs["bids_directory"])
-
-
-@click.command(name="patch", no_args_is_help=True)
-@arguments.bids_directory
-@arguments.caps_directory
-@arguments.modality_bids
-@dataloader.n_proc
-@extraction.save_features
-@data.participants_tsv
-@extraction.extract_json
-@preprocessing.use_uncropped_image
-@extraction.patch_size
-@extraction.stride_size
-@preprocessing.tracer
-@preprocessing.suvr_reference_region
-@preprocessing.custom_suffix
-def patch_bids_cli(**kwargs):
-    """Extract patch from nifti images.
-    CAPS_DIRECTORY is the CAPS folder where nifti images are stored and tensors will be saved.
-    MODALITY [t1-linear|pet-linear|custom] is the clinica pipeline name used for image preprocessing.
-    """
-    patch_config = CapsDatasetConfig.from_preprocessing_and_extraction_method(
-        extraction=ExtractionMethod.PATCH,
-        preprocessing_type=kwargs["preprocessing"],
-        **kwargs,
-    )
-
-    DeepLearningPrepareData(patch_config, from_bids=kwargs["bids_directory"])
-
-
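-# Usage sketch (hypothetical invocation, assuming this group is registered on
-# the main `clinicadl` CLI; the paths and sizes are illustrative only):
-#
-#     clinicadl prepare-data-from-bids patch BIDS_DIR CAPS_DIR t1 \
-#         --patch_size 50 --stride_size 50
-#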
-@click.command(name="slice", no_args_is_help=True)
-@arguments.bids_directory
-@arguments.caps_directory
-@arguments.modality_bids
-@dataloader.n_proc
-@extraction.save_features
-@data.participants_tsv
-@extraction.extract_json
-@preprocessing.use_uncropped_image
-@extraction.slice_direction
-@extraction.slice_mode
-@extraction.discarded_slices
-@preprocessing.tracer
-@preprocessing.suvr_reference_region
-@preprocessing.custom_suffix
-def slice_bids_cli(**kwargs):
-    """Extract slice from nifti images.
-    CAPS_DIRECTORY is the CAPS folder where nifti images are stored and tensors will be saved.
-    MODALITY [t1-linear|pet-linear|custom] is the clinica pipeline name used for image preprocessing.
-    """
-
-    slice_config = CapsDatasetConfig.from_preprocessing_and_extraction_method(
-        extraction=ExtractionMethod.SLICE,
-        preprocessing_type=kwargs["preprocessing"],
-        **kwargs,
-    )
-
-    DeepLearningPrepareData(slice_config, from_bids=kwargs["bids_directory"])
-
-
-@click.command(name="roi", no_args_is_help=True)
-@arguments.bids_directory
-@arguments.caps_directory
-@arguments.modality_bids
-@dataloader.n_proc
-@extraction.save_features
-@data.participants_tsv
-@extraction.extract_json
-@preprocessing.use_uncropped_image
-@extraction.roi_custom_mask_pattern
-@extraction.roi_custom_template
-@extraction.roi_list
-@extraction.roi_uncrop_output
-@preprocessing.tracer
-@preprocessing.suvr_reference_region
-@preprocessing.custom_suffix
-def roi_bids_cli(**kwargs):
-    """Extract roi from nifti images.
-
-    CAPS_DIRECTORY is the CAPS folder where nifti images are stored and tensors will be saved.
-
-    MODALITY [t1-linear|pet-linear|custom] is the clinica pipeline name used for image preprocessing.
-    """
-    roi_config = CapsDatasetConfig.from_preprocessing_and_extraction_method(
-        extraction=ExtractionMethod.ROI,
-        preprocessing_type=kwargs["preprocessing"],
-        **kwargs,
-    )
-
-    DeepLearningPrepareData(roi_config, from_bids=kwargs["bids_directory"])
-
-
-class RegistrationOrderGroup(click.Group):
-    """CLI group which lists commands by order of registration."""
-
-    def list_commands(self, ctx):
-        return self.commands.keys()
-
-
-@click.group(
-    cls=RegistrationOrderGroup, name="prepare-data-from-bids", no_args_is_help=True
-)
-def cli() -> None:
-    """Extract Pytorch tensors from nifti images."""
-    pass
-
-
-cli.add_command(image_bids_cli)
-cli.add_command(slice_bids_cli)
-cli.add_command(patch_bids_cli)
-cli.add_command(roi_bids_cli)
-
-
-if __name__ == "__main__":
-    cli()
diff --git a/clinicadl/commandline/pipelines/quality_check/pet_linear/cli.py b/clinicadl/commandline/pipelines/quality_check/pet_linear/cli.py
index 938895b82..ff9be6a66 100644
--- a/clinicadl/commandline/pipelines/quality_check/pet_linear/cli.py
+++ b/clinicadl/commandline/pipelines/quality_check/pet_linear/cli.py
@@ -44,7 +44,7 @@ def cli(
     SUVR_REFERENCE_REGION is the reference region used to perform intensity normalization {pons|cerebellumPons|pons2|cerebellumPons2}. 
""" - from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig + from clinicadl.data.caps_dataset_config import CapsDatasetConfig from .....quality_check.pet_linear.quality_check import ( quality_check as pet_linear_qc, diff --git a/clinicadl/commandline/pipelines/quality_check/t1_linear/cli.py b/clinicadl/commandline/pipelines/quality_check/t1_linear/cli.py index 6c55b3586..f4bd3bb42 100755 --- a/clinicadl/commandline/pipelines/quality_check/t1_linear/cli.py +++ b/clinicadl/commandline/pipelines/quality_check/t1_linear/cli.py @@ -2,7 +2,7 @@ from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import computational, data, dataloader -from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.data.caps_dataset_config import CapsDatasetConfig from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.enum import ExtractionMethod, Preprocessing diff --git a/clinicadl/commandline/pipelines/train/__init__.py b/clinicadl/commandline/pipelines/train/__init__.py deleted file mode 100644 index ae6937b4f..000000000 --- a/clinicadl/commandline/pipelines/train/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .classification import options -from .reconstruction import options -from .regression import options diff --git a/clinicadl/commandline/pipelines/train/classification/__init__.py b/clinicadl/commandline/pipelines/train/classification/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/clinicadl/commandline/pipelines/train/classification/cli.py b/clinicadl/commandline/pipelines/train/classification/cli.py deleted file mode 100644 index 21d57f365..000000000 --- a/clinicadl/commandline/pipelines/train/classification/cli.py +++ /dev/null @@ -1,112 +0,0 @@ -import click - -from clinicadl.commandline import arguments -from clinicadl.commandline.modules_options import ( - callbacks, - computational, - data, - dataloader, - early_stopping, - lr_scheduler, - network, - optimization, - optimizer, - reproducibility, - split, - transforms, - validation, -) -from clinicadl.commandline.pipelines.train.classification import ( - options as classification, -) -from clinicadl.commandline.pipelines.transfer_learning import ( - options as transfer_learning, -) -from clinicadl.trainer.config.classification import ClassificationConfig -from clinicadl.trainer.old_trainer import Trainer -from clinicadl.utils.enum import Task -from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options - - -@click.command(name="classification", no_args_is_help=True) -# Mandatory arguments -@arguments.caps_directory -@arguments.preprocessing_json -@arguments.tsv_path -@arguments.output_maps -# Options -# Computational -@computational.gpu -@computational.fully_sharded_data_parallel -@computational.amp -# Reproducibility -@reproducibility.seed -@reproducibility.deterministic -@reproducibility.compensation -@reproducibility.save_all_models -@reproducibility.config_file -# Model -@network.dropout -@network.multi_network -# Data -@data.multi_cohort -@data.diagnoses -@data.baseline -# validation -@validation.valid_longitudinal -@validation.evaluation_steps -# transforms -@transforms.normalize -@transforms.data_augmentation -# dataloader -@dataloader.batch_size -@dataloader.sampler -@dataloader.n_proc -# Cross validation -@split.n_splits -@split.split -# Optimization -@optimizer.optimizer -@optimizer.weight_decay -@optimizer.learning_rate -# lr scheduler -@lr_scheduler.adaptive_learning_rate -# 
early stopping -@early_stopping.patience -@early_stopping.tolerance -# optimization -@optimization.accumulation_steps -@optimization.profiler -@optimization.epochs -# transfer learning -@transfer_learning.transfer_path -@transfer_learning.transfer_selection_metric -@transfer_learning.nb_unfrozen_layer -# callbacks -@callbacks.emissions_calculator -@callbacks.track_exp -# Task-related -@classification.architecture -@classification.label -@classification.selection_metrics -@classification.threshold -@classification.loss -def cli(**kwargs): - """ - Train a deep learning model to learn a classification task on neuroimaging data. - CAPS_DIRECTORY is the CAPS folder from where tensors will be loaded. - PREPROCESSING_JSON is the name of the JSON file in CAPS_DIRECTORY/tensor_extraction folder where - all information about extraction are stored in order to read the wanted tensors. - TSV_DIRECTORY is a folder were TSV files defining train and validation sets are stored. - OUTPUT_MAPS_DIRECTORY is the path to the MAPS folder where outputs and results will be saved. - Options for this command can be input by declaring argument on the command line or by providing a - configuration file in TOML format. For more details, please visit the documentation: - https://clinicadl.readthedocs.io/en/stable/Train/Introduction/#configuration-file - - """ - - options = merge_cli_and_config_file_options(Task.CLASSIFICATION, **kwargs) - config = ClassificationConfig(**options) - trainer = Trainer(config) - - trainer.train(split_list=config.split.split, overwrite=True) diff --git a/clinicadl/commandline/pipelines/train/classification/options.py b/clinicadl/commandline/pipelines/train/classification/options.py deleted file mode 100644 index b74bb430a..000000000 --- a/clinicadl/commandline/pipelines/train/classification/options.py +++ /dev/null @@ -1,52 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.trainer.config.classification import ( - DataConfig, - NetworkConfig, - ValidationConfig, -) - -# Data -label = click.option( - "--label", - type=get_type("label", DataConfig), - default=get_default("label", DataConfig), - help="Target label used for training.", - show_default=True, -) -# Model -architecture = click.option( - "-a", - "--architecture", - type=get_type("architecture", NetworkConfig), - default=get_default("architecture", NetworkConfig), - help="Architecture of the chosen model to train. A set of model is available in ClinicaDL, default architecture depends on the NETWORK_TASK (see the documentation for more information).", -) -loss = click.option( - "--loss", - "-l", - type=get_type("loss", NetworkConfig), - default=get_default("loss", NetworkConfig), - help="Loss used by the network to optimize its training task.", - show_default=True, -) -threshold = click.option( - "--selection_threshold", - type=get_type("selection_threshold", NetworkConfig), - default=get_default("selection_threshold", NetworkConfig), - help="""Selection threshold for soft-voting. Will only be used if num_networks > 1.""", - show_default=True, -) -# Validation -selection_metrics = click.option( - "--selection_metrics", - "-sm", - multiple=True, - type=get_type("selection_metrics", ValidationConfig), - default=get_default("selection_metrics", ValidationConfig), - help="""Allow to save a list of models based on their selection metric. 
Default will - only save the best model selected on loss.""", - show_default=True, -) diff --git a/clinicadl/commandline/pipelines/train/cli.py b/clinicadl/commandline/pipelines/train/cli.py deleted file mode 100644 index 8e536efd2..000000000 --- a/clinicadl/commandline/pipelines/train/cli.py +++ /dev/null @@ -1,26 +0,0 @@ -import click - -from .classification.cli import cli as classification_cli -from .from_json.cli import cli as from_json_cli -from .list_models.cli import cli as list_models_cli -from .reconstruction.cli import cli as reconstruction_cli -from .regression.cli import cli as regression_cli -from .resume.cli import cli as resume_cli - - -@click.group(name="train", no_args_is_help=True) -def cli(): - """Train a deep learning model for a specific task.""" - pass - - -cli.add_command(classification_cli) -cli.add_command(regression_cli) -cli.add_command(reconstruction_cli) -cli.add_command(from_json_cli) -cli.add_command(resume_cli) -cli.add_command(list_models_cli) - - -if __name__ == "__main__": - cli() diff --git a/clinicadl/commandline/pipelines/train/from_json/__init__.py b/clinicadl/commandline/pipelines/train/from_json/__init__.py deleted file mode 100644 index 4b7029b4e..000000000 --- a/clinicadl/commandline/pipelines/train/from_json/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .cli import cli diff --git a/clinicadl/commandline/pipelines/train/from_json/cli.py b/clinicadl/commandline/pipelines/train/from_json/cli.py deleted file mode 100644 index 5e6771258..000000000 --- a/clinicadl/commandline/pipelines/train/from_json/cli.py +++ /dev/null @@ -1,34 +0,0 @@ -from logging import getLogger - -import click - -from clinicadl.commandline import arguments -from clinicadl.commandline.modules_options import ( - split, -) -from clinicadl.trainer.old_trainer import Trainer - - -@click.command(name="from_json", no_args_is_help=True) -@arguments.config_file -@arguments.output_maps -@split.split -def cli(**kwargs): - """ - Replicate a deep learning training based on a previously created JSON file. - This is particularly useful to retrain random architectures obtained with a random search. - - CONFIG_JSON is the path to the JSON file with the configuration of the training procedure. - - OUTPUT_MAPS_DIRECTORY is the path to the MAPS folder where outputs and results will be saved. 
- """ - - logger = getLogger("clinicadl") - logger.info(f"Reading JSON file at path {kwargs['config_file']}...") - - trainer = Trainer.from_json( - config_file=kwargs["config_file"], - maps_path=kwargs["output_maps_directory"], - split=kwargs["split"], - ) - trainer.train(split_list=kwargs["split"], overwrite=True) diff --git a/clinicadl/commandline/pipelines/train/list_models/__init__.py b/clinicadl/commandline/pipelines/train/list_models/__init__.py deleted file mode 100644 index 4b7029b4e..000000000 --- a/clinicadl/commandline/pipelines/train/list_models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .cli import cli diff --git a/clinicadl/commandline/pipelines/train/list_models/cli.py b/clinicadl/commandline/pipelines/train/list_models/cli.py deleted file mode 100644 index 95632aefc..000000000 --- a/clinicadl/commandline/pipelines/train/list_models/cli.py +++ /dev/null @@ -1,33 +0,0 @@ -import click - - -@click.command(name="list_models") -@click.option( - "-a", - "--architecture", - type=str, - help="Name of the network for which information will be displayed.", -) -@click.option( - "-i", - "--input_size", - type=str, - help="Size of the input image in the shape C@HxW if the image is 2D or C@DxHxW if the image is 3D.", -) -@click.option( - "-m", - "--model_layers", - type=bool, - default=False, - is_flag=True, - help="Display the detailed Pytorch architecture.", -) -def cli( - architecture, - input_size, - model_layers, -): - """Show the list of available models in ClinicaDL.""" - from clinicadl.utils.iotools.train_utils import get_model_list - - get_model_list(architecture, input_size, model_layers) diff --git a/clinicadl/commandline/pipelines/train/reconstruction/__init__.py b/clinicadl/commandline/pipelines/train/reconstruction/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/clinicadl/commandline/pipelines/train/reconstruction/cli.py b/clinicadl/commandline/pipelines/train/reconstruction/cli.py deleted file mode 100644 index 1bad88443..000000000 --- a/clinicadl/commandline/pipelines/train/reconstruction/cli.py +++ /dev/null @@ -1,108 +0,0 @@ -import click - -from clinicadl.commandline import arguments -from clinicadl.commandline.modules_options import ( - callbacks, - computational, - data, - dataloader, - early_stopping, - lr_scheduler, - network, - optimization, - optimizer, - reproducibility, - split, - transforms, - validation, -) -from clinicadl.commandline.pipelines.train.reconstruction import ( - options as reconstruction, -) -from clinicadl.commandline.pipelines.transfer_learning import ( - options as transfer_learning, -) -from clinicadl.trainer.config.reconstruction import ReconstructionConfig -from clinicadl.trainer.old_trainer import Trainer -from clinicadl.utils.enum import Task -from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options - - -@click.command(name="reconstruction", no_args_is_help=True) -# Mandatory arguments -@arguments.caps_directory -@arguments.preprocessing_json -@arguments.tsv_path -@arguments.output_maps -# Options -# Computational -@computational.gpu -@computational.fully_sharded_data_parallel -@computational.amp -# Reproducibility -@reproducibility.seed -@reproducibility.deterministic -@reproducibility.compensation -@reproducibility.save_all_models -@reproducibility.config_file -# Model -@network.dropout -@network.multi_network -# Data -@data.multi_cohort -@data.diagnoses -@data.baseline -# validation -@validation.valid_longitudinal -@validation.evaluation_steps -# transforms 
-@transforms.normalize -@transforms.data_augmentation -# dataloader -@dataloader.batch_size -@dataloader.sampler -@dataloader.n_proc -# Cross validation -@split.n_splits -@split.split -# Optimization -@optimizer.optimizer -@optimizer.weight_decay -@optimizer.learning_rate -# lr scheduler -@lr_scheduler.adaptive_learning_rate -# early stopping -@early_stopping.patience -@early_stopping.tolerance -# optimization -@optimization.accumulation_steps -@optimization.profiler -@optimization.epochs -# transfer learning -@transfer_learning.transfer_path -@transfer_learning.transfer_selection_metric -@transfer_learning.nb_unfrozen_layer -# callbacks -@callbacks.emissions_calculator -@callbacks.track_exp -# Task-related -@reconstruction.architecture -@reconstruction.selection_metrics -@reconstruction.loss -def cli(**kwargs): - """ - Train a deep learning model to learn a reconstruction task on neuroimaging data. - CAPS_DIRECTORY is the CAPS folder from where tensors will be loaded. - PREPROCESSING_JSON is the name of the JSON file in CAPS_DIRECTORY/tensor_extraction folder where - all information about extraction are stored in order to read the wanted tensors. - TSV_DIRECTORY is a folder were TSV files defining train and validation sets are stored. - OUTPUT_MAPS_DIRECTORY is the path to the MAPS folder where outputs and results will be saved. - Options for this command can be input by declaring argument on the command line or by providing a - configuration file in TOML format. For more details, please visit the documentation: - https://clinicadl.readthedocs.io/en/stable/Train/Introduction/#configuration-file - """ - - options = merge_cli_and_config_file_options(Task.RECONSTRUCTION, **kwargs) - config = ReconstructionConfig(**options) - trainer = Trainer(config) - trainer.train(split_list=config.split.split, overwrite=True) diff --git a/clinicadl/commandline/pipelines/train/reconstruction/options.py b/clinicadl/commandline/pipelines/train/reconstruction/options.py deleted file mode 100644 index b78a549fb..000000000 --- a/clinicadl/commandline/pipelines/train/reconstruction/options.py +++ /dev/null @@ -1,36 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.trainer.config.reconstruction import ( - NetworkConfig, - ValidationConfig, -) - -# Model -architecture = click.option( - "-a", - "--architecture", - type=get_type("architecture", NetworkConfig), - default=get_default("architecture", NetworkConfig), - help="Architecture of the chosen model to train. A set of model is available in ClinicaDL, default architecture depends on the NETWORK_TASK (see the documentation for more information).", -) -loss = click.option( - "--loss", - "-l", - type=get_type("loss", NetworkConfig), - default=get_default("loss", NetworkConfig), - help="Loss used by the network to optimize its training task.", - show_default=True, -) -# Validation -selection_metrics = click.option( - "--selection_metrics", - "-sm", - multiple=True, - type=get_type("selection_metrics", ValidationConfig), - default=get_default("selection_metrics", ValidationConfig), - help="""Allow to save a list of models based on their selection metric. 
Default will - only save the best model selected on loss.""", - show_default=True, -) diff --git a/clinicadl/commandline/pipelines/train/regression/__init__.py b/clinicadl/commandline/pipelines/train/regression/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/clinicadl/commandline/pipelines/train/regression/cli.py b/clinicadl/commandline/pipelines/train/regression/cli.py deleted file mode 100644 index 95a623604..000000000 --- a/clinicadl/commandline/pipelines/train/regression/cli.py +++ /dev/null @@ -1,107 +0,0 @@ -import click - -from clinicadl.commandline import arguments -from clinicadl.commandline.modules_options import ( - callbacks, - computational, - data, - dataloader, - early_stopping, - lr_scheduler, - network, - optimization, - optimizer, - reproducibility, - split, - transforms, - validation, -) -from clinicadl.commandline.pipelines.train.regression import options as regression -from clinicadl.commandline.pipelines.transfer_learning import ( - options as transfer_learning, -) -from clinicadl.trainer.config.regression import RegressionConfig -from clinicadl.trainer.old_trainer import Trainer -from clinicadl.utils.enum import Task -from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options - - -@click.command(name="regression", no_args_is_help=True) -# Mandatory arguments -@arguments.caps_directory -@arguments.preprocessing_json -@arguments.tsv_path -@arguments.output_maps -# Options -# Computational -@computational.gpu -@computational.fully_sharded_data_parallel -@computational.amp -# Reproducibility -@reproducibility.seed -@reproducibility.deterministic -@reproducibility.compensation -@reproducibility.save_all_models -@reproducibility.config_file -# Model -@network.dropout -@network.multi_network -# Data -@data.multi_cohort -@data.diagnoses -@data.baseline -# validation -@validation.valid_longitudinal -@validation.evaluation_steps -# transforms -@transforms.normalize -@transforms.data_augmentation -# dataloader -@dataloader.batch_size -@dataloader.sampler -@dataloader.n_proc -# Cross validation -@split.n_splits -@split.split -# Optimization -@optimizer.optimizer -@optimizer.weight_decay -@optimizer.learning_rate -# lr scheduler -@lr_scheduler.adaptive_learning_rate -# early stopping -@early_stopping.patience -@early_stopping.tolerance -# optimization -@optimization.accumulation_steps -@optimization.profiler -@optimization.epochs -# transfer learning -@transfer_learning.transfer_path -@transfer_learning.transfer_selection_metric -@transfer_learning.nb_unfrozen_layer -# callbacks -@callbacks.emissions_calculator -@callbacks.track_exp -# Task-related -@regression.architecture -@regression.label -@regression.selection_metrics -@regression.loss -def cli(**kwargs): - """ - Train a deep learning model to learn a regression task on neuroimaging data. - CAPS_DIRECTORY is the CAPS folder from where tensors will be loaded. - PREPROCESSING_JSON is the name of the JSON file in CAPS_DIRECTORY/tensor_extraction folder where - all information about extraction are stored in order to read the wanted tensors. - TSV_DIRECTORY is a folder were TSV files defining train and validation sets are stored. - OUTPUT_MAPS_DIRECTORY is the path to the MAPS folder where outputs and results will be saved. - Options for this command can be input by declaring argument on the command line or by providing a - configuration file in TOML format. 
For more details, please visit the documentation: - https://clinicadl.readthedocs.io/en/stable/Train/Introduction/#configuration-file - """ - - options = merge_cli_and_config_file_options(Task.REGRESSION, **kwargs) - config = RegressionConfig(**options) - trainer = Trainer(config) - trainer.train(split_list=config.split.split, overwrite=True) diff --git a/clinicadl/commandline/pipelines/train/regression/options.py b/clinicadl/commandline/pipelines/train/regression/options.py deleted file mode 100644 index 5e33ea187..000000000 --- a/clinicadl/commandline/pipelines/train/regression/options.py +++ /dev/null @@ -1,45 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.trainer.config.regression import ( - DataConfig, - NetworkConfig, - ValidationConfig, -) - -# Data -label = click.option( - "--label", - type=get_type("label", DataConfig), - default=get_default("label", DataConfig), - help="Target label used for training.", - show_default=True, -) -# Model -architecture = click.option( - "-a", - "--architecture", - type=get_type("architecture", NetworkConfig), - default=get_default("architecture", NetworkConfig), - help="Architecture of the chosen model to train. A set of model is available in ClinicaDL, default architecture depends on the NETWORK_TASK (see the documentation for more information).", -) -loss = click.option( - "--loss", - "-l", - type=get_type("loss", NetworkConfig), - default=get_default("loss", NetworkConfig), - help="Loss used by the network to optimize its training task.", - show_default=True, -) -# Validation -selection_metrics = click.option( - "--selection_metrics", - "-sm", - multiple=True, - type=get_type("selection_metrics", ValidationConfig), - default=get_default("selection_metrics", ValidationConfig), - help="""Allow to save a list of models based on their selection metric. Default will - only save the best model selected on loss.""", - show_default=True, -) diff --git a/clinicadl/commandline/pipelines/train/resume/__init__.py b/clinicadl/commandline/pipelines/train/resume/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/clinicadl/commandline/pipelines/train/resume/cli.py b/clinicadl/commandline/pipelines/train/resume/cli.py deleted file mode 100644 index 90efa4244..000000000 --- a/clinicadl/commandline/pipelines/train/resume/cli.py +++ /dev/null @@ -1,19 +0,0 @@ -import click - -from clinicadl.commandline import arguments -from clinicadl.commandline.modules_options import ( - split, -) -from clinicadl.trainer.old_trainer import Trainer - - -@click.command(name="resume", no_args_is_help=True) -@arguments.input_maps -@split.split -def cli(input_maps_directory, split): - """Resume training job in specified maps. - - INPUT_MAPS_DIRECTORY is the path to the MAPS folder where training job has started. 
- """ - trainer = Trainer.from_maps(input_maps_directory) - trainer.resume() diff --git a/clinicadl/commandline/pipelines/transfer_learning/__init__.py b/clinicadl/commandline/pipelines/transfer_learning/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/clinicadl/commandline/pipelines/transfer_learning/options.py b/clinicadl/commandline/pipelines/transfer_learning/options.py deleted file mode 100644 index 870f3e66b..000000000 --- a/clinicadl/commandline/pipelines/transfer_learning/options.py +++ /dev/null @@ -1,30 +0,0 @@ -import click - -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.trainer.transfer_learning import TransferLearningConfig - -nb_unfrozen_layer = click.option( - "-nul", - "--nb_unfrozen_layer", - type=get_type("nb_unfrozen_layer", TransferLearningConfig), - default=get_default("nb_unfrozen_layer", TransferLearningConfig), - help="Number of layer that will be retrain during training. For example, if it is 2, the last two layers of the model will not be freezed.", - show_default=True, -) -transfer_path = click.option( - "-tp", - "--transfer_path", - type=get_type("transfer_path", TransferLearningConfig), - default=get_default("transfer_path", TransferLearningConfig), - help="Path of to a MAPS used for transfer learning.", - show_default=True, -) -transfer_selection_metric = click.option( - "-tsm", - "--transfer_selection_metric", - type=get_type("transfer_selection_metric", TransferLearningConfig), - default=get_default("transfer_selection_metric", TransferLearningConfig), - help="Metric used to select the model for transfer learning in the MAPS defined by transfer_path.", - show_default=True, -) diff --git a/clinicadl/config/config/__init__.py b/clinicadl/config/config/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/clinicadl/config/config/lr_scheduler.py b/clinicadl/config/config/lr_scheduler.py deleted file mode 100644 index 75ffe86ea..000000000 --- a/clinicadl/config/config/lr_scheduler.py +++ /dev/null @@ -1,13 +0,0 @@ -from logging import getLogger - -from pydantic import BaseModel, ConfigDict - -logger = getLogger("clinicadl.lr_config") - - -class LRschedulerConfig(BaseModel): - """Config class to instantiate an LR Scheduler.""" - - adaptive_learning_rate: bool = False - # pydantic config - model_config = ConfigDict(validate_assignment=True) diff --git a/clinicadl/config/config/reproducibility.py b/clinicadl/config/config/reproducibility.py deleted file mode 100644 index 2926f3fbc..000000000 --- a/clinicadl/config/config/reproducibility.py +++ /dev/null @@ -1,21 +0,0 @@ -from logging import getLogger -from pathlib import Path -from typing import Optional - -from pydantic import BaseModel, ConfigDict - -from clinicadl.utils.enum import Compensation - -logger = getLogger("clinicadl.reproducibility_config") - - -class ReproducibilityConfig(BaseModel): - """Config class to handle reproducibility parameters.""" - - compensation: Compensation = Compensation.MEMORY - deterministic: bool = False - save_all_models: bool = False - seed: int = 0 - config_file: Optional[Path] = None - # pydantic config - model_config = ConfigDict(validate_assignment=True) diff --git a/clinicadl/commandline/modules_options/__init__.py b/clinicadl/data/__init__.py similarity index 100% rename from clinicadl/commandline/modules_options/__init__.py rename to clinicadl/data/__init__.py diff --git 
a/clinicadl/dataset/config/__init__.py b/clinicadl/data/config/__init__.py similarity index 100% rename from clinicadl/dataset/config/__init__.py rename to clinicadl/data/config/__init__.py diff --git a/clinicadl/dataset/config/data.py b/clinicadl/data/config/data.py similarity index 100% rename from clinicadl/dataset/config/data.py rename to clinicadl/data/config/data.py diff --git a/clinicadl/dataset/config/file_type.py b/clinicadl/data/config/file_type.py similarity index 100% rename from clinicadl/dataset/config/file_type.py rename to clinicadl/data/config/file_type.py diff --git a/clinicadl/dataset/dataloader/__init__.py b/clinicadl/data/dataloader/__init__.py similarity index 100% rename from clinicadl/dataset/dataloader/__init__.py rename to clinicadl/data/dataloader/__init__.py diff --git a/clinicadl/dataset/dataloader/config.py b/clinicadl/data/dataloader/config.py similarity index 98% rename from clinicadl/dataset/dataloader/config.py rename to clinicadl/data/dataloader/config.py index 7c4acfb6d..a21cece3e 100644 --- a/clinicadl/dataset/dataloader/config.py +++ b/clinicadl/data/dataloader/config.py @@ -4,7 +4,7 @@ from torch.utils.data import DataLoader, DistributedSampler, Sampler from torch.utils.data import WeightedRandomSampler as BaseWeightedRandomSampler -from clinicadl.dataset.datasets.caps_dataset import CapsDataset +from clinicadl.data.datasets import CapsDataset from clinicadl.utils.config import ClinicaDLConfig from clinicadl.utils.seed import pl_worker_init_function diff --git a/clinicadl/dataset/dataloader/defaults.py b/clinicadl/data/dataloader/defaults.py similarity index 100% rename from clinicadl/dataset/dataloader/defaults.py rename to clinicadl/data/dataloader/defaults.py diff --git a/clinicadl/dataset/datasets/___init__.py b/clinicadl/data/datasets/__init__.py similarity index 100% rename from clinicadl/dataset/datasets/___init__.py rename to clinicadl/data/datasets/__init__.py diff --git a/clinicadl/dataset/datasets/caps_dataset.py b/clinicadl/data/datasets/caps_dataset.py similarity index 98% rename from clinicadl/dataset/datasets/caps_dataset.py rename to clinicadl/data/datasets/caps_dataset.py index 5a877c8d2..8f2bc0f5f 100644 --- a/clinicadl/dataset/datasets/caps_dataset.py +++ b/clinicadl/data/datasets/caps_dataset.py @@ -16,9 +16,9 @@ from tqdm import tqdm from typing_extensions import Self -from clinicadl.dataset.preprocessing import BasePreprocessing -from clinicadl.dataset.readers.caps_reader import CapsReader -from clinicadl.dataset.utils import ( +from clinicadl.data.preprocessing import BasePreprocessing +from clinicadl.data.readers import CapsReader +from clinicadl.data.utils import ( CapsDatasetSample, check_df, get_infos_from_json, @@ -416,7 +416,7 @@ def __getitem__(self, idx: NonNegativeInt) -> CapsDatasetSample: image = image_augmentation(image) if not isinstance(self.extraction, Image): - tensor = self.transforms.extraction.extract_tensor( + tensor = self.transforms.extraction.extract_sample( image, elem_index, ) diff --git a/clinicadl/dataset/datasets/concat.py b/clinicadl/data/datasets/concat.py similarity index 93% rename from clinicadl/dataset/datasets/concat.py rename to clinicadl/data/datasets/concat.py index 345fbe5b4..ea53f4bc9 100644 --- a/clinicadl/dataset/datasets/concat.py +++ b/clinicadl/data/datasets/concat.py @@ -4,8 +4,8 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union -from clinicadl.dataset.datasets.caps_dataset import CapsDataset -from clinicadl.dataset.utils import CapsDatasetSample 
+from clinicadl.data.datasets import CapsDataset +from clinicadl.data.utils import CapsDatasetSample logger = getLogger("clinicadl") diff --git a/clinicadl/dataset/preprocessing/__init__.py b/clinicadl/data/preprocessing/__init__.py similarity index 100% rename from clinicadl/dataset/preprocessing/__init__.py rename to clinicadl/data/preprocessing/__init__.py diff --git a/clinicadl/dataset/preprocessing/base.py b/clinicadl/data/preprocessing/base.py similarity index 100% rename from clinicadl/dataset/preprocessing/base.py rename to clinicadl/data/preprocessing/base.py diff --git a/clinicadl/dataset/preprocessing/custom.py b/clinicadl/data/preprocessing/custom.py similarity index 93% rename from clinicadl/dataset/preprocessing/custom.py rename to clinicadl/data/preprocessing/custom.py index dd3ee8c57..4e8e9cfff 100644 --- a/clinicadl/dataset/preprocessing/custom.py +++ b/clinicadl/data/preprocessing/custom.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Optional -from clinicadl.dataset.preprocessing.base import BasePreprocessing +from clinicadl.data.preprocessing.base import BasePreprocessing from clinicadl.utils.enum import Preprocessing from clinicadl.utils.iotools.clinica_utils import FileType diff --git a/clinicadl/dataset/preprocessing/dti.py b/clinicadl/data/preprocessing/dti.py similarity index 95% rename from clinicadl/dataset/preprocessing/dti.py rename to clinicadl/data/preprocessing/dti.py index f0852b4eb..2e8b6e498 100644 --- a/clinicadl/dataset/preprocessing/dti.py +++ b/clinicadl/data/preprocessing/dti.py @@ -1,7 +1,7 @@ from logging import getLogger from typing import Optional -from clinicadl.dataset.preprocessing.base import BasePreprocessing +from clinicadl.data.preprocessing.base import BasePreprocessing from clinicadl.utils.enum import ( DTIMeasure, DTISpace, diff --git a/clinicadl/dataset/preprocessing/flair.py b/clinicadl/data/preprocessing/flair.py similarity index 92% rename from clinicadl/dataset/preprocessing/flair.py rename to clinicadl/data/preprocessing/flair.py index 3b94bcefc..477451d37 100644 --- a/clinicadl/dataset/preprocessing/flair.py +++ b/clinicadl/data/preprocessing/flair.py @@ -1,7 +1,7 @@ from logging import getLogger from typing import Optional -from clinicadl.dataset.preprocessing.base import BasePreprocessing +from clinicadl.data.preprocessing.base import BasePreprocessing from clinicadl.utils.enum import LinearModality, Preprocessing from clinicadl.utils.iotools.clinica_utils import FileType diff --git a/clinicadl/dataset/preprocessing/pet.py b/clinicadl/data/preprocessing/pet.py similarity index 96% rename from clinicadl/dataset/preprocessing/pet.py rename to clinicadl/data/preprocessing/pet.py index c299c9c1a..36e54bde5 100644 --- a/clinicadl/dataset/preprocessing/pet.py +++ b/clinicadl/data/preprocessing/pet.py @@ -3,7 +3,7 @@ from pydantic import field_validator -from clinicadl.dataset.preprocessing.base import BasePreprocessing +from clinicadl.data.preprocessing.base import BasePreprocessing from clinicadl.utils.enum import ( Preprocessing, SUVRReferenceRegions, diff --git a/clinicadl/dataset/preprocessing/t1.py b/clinicadl/data/preprocessing/t1.py similarity index 92% rename from clinicadl/dataset/preprocessing/t1.py rename to clinicadl/data/preprocessing/t1.py index 4c557b4a1..879a0eff3 100644 --- a/clinicadl/dataset/preprocessing/t1.py +++ b/clinicadl/data/preprocessing/t1.py @@ -1,7 +1,7 @@ from logging import getLogger from typing import Optional -from clinicadl.dataset.preprocessing.base import BasePreprocessing +from 
clinicadl.data.preprocessing.base import BasePreprocessing from clinicadl.utils.enum import LinearModality, Preprocessing from clinicadl.utils.iotools.clinica_utils import FileType diff --git a/clinicadl/dataset/preprocessing/t2.py b/clinicadl/data/preprocessing/t2.py similarity index 93% rename from clinicadl/dataset/preprocessing/t2.py rename to clinicadl/data/preprocessing/t2.py index 127fc76c2..c1575c009 100644 --- a/clinicadl/dataset/preprocessing/t2.py +++ b/clinicadl/data/preprocessing/t2.py @@ -1,7 +1,7 @@ from logging import getLogger from typing import Optional -from clinicadl.dataset.preprocessing.base import BasePreprocessing +from clinicadl.data.preprocessing.base import BasePreprocessing from clinicadl.utils.enum import LinearModality, Preprocessing from clinicadl.utils.iotools.clinica_utils import FileType diff --git a/clinicadl/dataset/readers/__init__.py b/clinicadl/data/readers/__init__.py similarity index 100% rename from clinicadl/dataset/readers/__init__.py rename to clinicadl/data/readers/__init__.py diff --git a/clinicadl/dataset/readers/bids_reader.py b/clinicadl/data/readers/bids_reader.py similarity index 98% rename from clinicadl/dataset/readers/bids_reader.py rename to clinicadl/data/readers/bids_reader.py index 3abbacd67..fa2294d5f 100644 --- a/clinicadl/dataset/readers/bids_reader.py +++ b/clinicadl/data/readers/bids_reader.py @@ -2,8 +2,8 @@ from logging import getLogger from pathlib import Path -from clinicadl.dataset.config import FileType -from clinicadl.dataset.utils import insensitive_glob +from clinicadl.data.config import FileType +from clinicadl.data.utils import insensitive_glob from clinicadl.utils.exceptions import ClinicaDLBIDSError from .reader import Reader diff --git a/clinicadl/dataset/readers/caps_reader.py b/clinicadl/data/readers/caps_reader.py similarity index 98% rename from clinicadl/dataset/readers/caps_reader.py rename to clinicadl/data/readers/caps_reader.py index 506e0c7ac..523e2b691 100644 --- a/clinicadl/dataset/readers/caps_reader.py +++ b/clinicadl/data/readers/caps_reader.py @@ -4,9 +4,9 @@ import pandas as pd -from clinicadl.dataset.preprocessing import BasePreprocessing -from clinicadl.dataset.readers.reader import Reader -from clinicadl.dataset.utils import insensitive_glob +from clinicadl.data.preprocessing import BasePreprocessing +from clinicadl.data.readers.reader import Reader +from clinicadl.data.utils import insensitive_glob from clinicadl.transforms.transforms import Transforms from clinicadl.utils.enum import Preprocessing from clinicadl.utils.exceptions import ( diff --git a/clinicadl/dataset/readers/multi_caps_reader.py b/clinicadl/data/readers/multi_caps_reader.py similarity index 100% rename from clinicadl/dataset/readers/multi_caps_reader.py rename to clinicadl/data/readers/multi_caps_reader.py diff --git a/clinicadl/dataset/readers/reader.py b/clinicadl/data/readers/reader.py similarity index 100% rename from clinicadl/dataset/readers/reader.py rename to clinicadl/data/readers/reader.py diff --git a/clinicadl/dataset/utils.py b/clinicadl/data/utils.py similarity index 99% rename from clinicadl/dataset/utils.py rename to clinicadl/data/utils.py index a1f2c1889..1d8dd614c 100644 --- a/clinicadl/dataset/utils.py +++ b/clinicadl/data/utils.py @@ -9,14 +9,14 @@ import torchio as tio from pydantic import BaseModel, ConfigDict -from clinicadl.dataset import preprocessing +from clinicadl.data import preprocessing from clinicadl.transforms import extraction from clinicadl.transforms.transforms import Transforms from 
clinicadl.utils.enum import ExtractionMethod, Preprocessing from clinicadl.utils.exceptions import ClinicaDLTSVError from clinicadl.utils.iotools.utils import read_preprocessing -logger = getLogger("clinicadl.dataset.utils") +logger = getLogger("clinicadl.data.utils") PARTICIPANT_ID = "participant_id" SESSION_ID = "session_id" diff --git a/clinicadl/dataset/__init__.py b/clinicadl/dataset/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/clinicadl/experiment_manager/__init__.py b/clinicadl/experiment_manager/__init__.py index e69de29bb..294f94337 100644 --- a/clinicadl/experiment_manager/__init__.py +++ b/clinicadl/experiment_manager/__init__.py @@ -0,0 +1 @@ +from .experiment_manager import ExperimentManager diff --git a/clinicadl/experiment_manager/experiment_manager.py b/clinicadl/experiment_manager/experiment_manager.py index dea504e6e..9948cfa19 100644 --- a/clinicadl/experiment_manager/experiment_manager.py +++ b/clinicadl/experiment_manager/experiment_manager.py @@ -8,8 +8,8 @@ import pandas as pd from pydantic import BaseModel -from clinicadl.dataset.preprocessing import BasePreprocessing -from clinicadl.dataset.readers import CapsReader +from clinicadl.data.preprocessing import BasePreprocessing +from clinicadl.data.readers import CapsReader from clinicadl.metrics.old_metrics.utils import check_selection_metric from clinicadl.model.clinicadl_model import ClinicaDLModel from clinicadl.networks.config import NetworkConfig diff --git a/clinicadl/experiment_manager/maps_manager.py b/clinicadl/experiment_manager/maps_manager.py index 0a3903136..792ae15b0 100644 --- a/clinicadl/experiment_manager/maps_manager.py +++ b/clinicadl/experiment_manager/maps_manager.py @@ -291,7 +291,7 @@ def _write_requirements_version(self): def _write_training_data(self): """Writes the TSV file containing the participant and session IDs used for training.""" logger.debug("Writing training data...") - from clinicadl.dataset.data_utils import load_data_test + from clinicadl.data.data_utils import load_data_test train_df = load_data_test( self.tsv_path, diff --git a/clinicadl/hugging_face/hugging_face.py b/clinicadl/hugging_face/hugging_face.py index 22b6bbb02..8b3536a5f 100644 --- a/clinicadl/hugging_face/hugging_face.py +++ b/clinicadl/hugging_face/hugging_face.py @@ -5,7 +5,7 @@ import toml -from clinicadl.dataset.caps_dataset_utils import read_json +from clinicadl.data.caps_dataset_utils import read_json from clinicadl.utils.exceptions import ClinicaDLArgumentError from clinicadl.utils.iotools.maps_manager_utils import ( remove_unused_tasks, diff --git a/clinicadl/interpret/config.py b/clinicadl/interpret/config.py index 5a1ebeaf3..4bf168792 100644 --- a/clinicadl/interpret/config.py +++ b/clinicadl/interpret/config.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, field_validator -from clinicadl.dataset.config.data import DataConfig +from clinicadl.data.config.data import DataConfig from clinicadl.experiment_manager.config import ( MapsManagerConfig as MapsManagerConfigBase, ) diff --git a/clinicadl/model/clinicadl_model.py b/clinicadl/model/clinicadl_model.py index 27adad432..b8ee21a8d 100644 --- a/clinicadl/model/clinicadl_model.py +++ b/clinicadl/model/clinicadl_model.py @@ -1,8 +1,61 @@ +from pathlib import Path + +import torch import torch.nn as nn -import torch.optim as optim +from torch.optim.optimizer import Optimizer + +from clinicadl.losses import get_loss_function_from_config +from clinicadl.losses.config import LossConfig +from clinicadl.losses.utils import Loss 
+from clinicadl.networks import get_network_from_config
+from clinicadl.networks.config import NetworkConfig
+from clinicadl.optim import get_optimizer_from_config
+from clinicadl.optim.optimizers import OptimizerConfig
+from clinicadl.utils.computational.ddp import DDP
 
 
 class ClinicaDLModel:
-    def __init__(self, network: nn.Module, loss: nn.Module, optimizer):
+    def __init__(self, network: nn.Module, loss: Loss, optimizer: Optimizer):
+        self.network = network
+        self.loss = loss
+        self.optimizer = optimizer
+
+        self.network = DDP(
+            self.network,
+            fsdp=fully_sharded_data_parallel,
+            amp=amp,
+        )  # to check: fully_sharded_data_parallel and amp are not defined in this scope
+
+    @classmethod
+    def from_config(
+        cls,
+        network_config: NetworkConfig,
+        loss_config: LossConfig,
+        optimizer_config: OptimizerConfig,
+    ):
+        loss, _ = get_loss_function_from_config(loss_config)
+        network, _ = get_network_from_config(network_config)
+        optimizer, _ = get_optimizer_from_config(optimizer_config, network)
+
+        return ClinicaDLModel(network, loss, optimizer)
+
+    def load_optim_state_dict(self, optimizer_path: Path):
+        checkpoint_state = torch.load(
+            optimizer_path, map_location=self.network.device, weights_only=True
+        )
+        self.network.load_optim_state_dict(
+            self.optimizer, checkpoint_state["optimizer"]
+        )
+
+    def load_state_dict(self, checkpoint_path: Path):
+        checkpoint_state = torch.load(
+            checkpoint_path, map_location=self.network.device, weights_only=True
+        )
+        self.network.load_state_dict(checkpoint_state["model"])
+
+        return checkpoint_state["epoch"]
+
+    def _init_from_maps(self, maps_path: Path):
         """TO COMPLETE"""
+        # Load network and optimizer from maps_path, for transfer learning
         pass
diff --git a/clinicadl/predictor/config.py b/clinicadl/predictor/config.py
index 2c793045c..89af868f2 100644
--- a/clinicadl/predictor/config.py
+++ b/clinicadl/predictor/config.py
@@ -3,7 +3,7 @@
 
 from pydantic import BaseModel, ConfigDict, computed_field
 
-from clinicadl.dataset.config.data import DataConfig as DataBaseConfig
+from clinicadl.data.config.data import DataConfig as DataBaseConfig
 from clinicadl.experiment_manager.config import (
     MapsManagerConfig as MapsManagerBaseConfig,
 )
diff --git a/clinicadl/predictor/old_predictor.py b/clinicadl/predictor/old_predictor.py
index 96d4764a5..f6ddf1377 100644
--- a/clinicadl/predictor/old_predictor.py
+++ b/clinicadl/predictor/old_predictor.py
@@ -12,7 +12,7 @@
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-from clinicadl.dataset.caps_dataset import (
+from clinicadl.data.caps_dataset import (
     return_dataset,
 )
 from clinicadl.experiment_manager.maps_manager import MapsManager
diff --git a/clinicadl/predictor/predictor.py b/clinicadl/predictor/predictor.py
index f173e3dde..78ba6eb29 100644
--- a/clinicadl/predictor/predictor.py
+++ b/clinicadl/predictor/predictor.py
@@ -1,4 +1,4 @@
-from clinicadl.dataset.caps_dataset import CapsDataset
+from clinicadl.data.caps_dataset import CapsDataset
 
 from clinicadl.experiment_manager.experiment_manager import ExperimentManager
diff --git a/clinicadl/quality_check/pet_linear/quality_check.py b/clinicadl/quality_check/pet_linear/quality_check.py
index 87edcc7c1..fbc0e0a0c 100644
--- a/clinicadl/quality_check/pet_linear/quality_check.py
+++ b/clinicadl/quality_check/pet_linear/quality_check.py
@@ -12,8 +12,8 @@
 import pandas as pd
 from joblib import Parallel, delayed
 
-from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig
-from clinicadl.dataset.utils import pet_linear_nii
+from clinicadl.data.caps_dataset_config import CapsDatasetConfig
+from clinicadl.data.utils import pet_linear_nii
 from clinicadl.utils.iotools.clinica_utils import (
     RemoteFileStructure,
     clinicadl_file_reader,
diff --git a/clinicadl/quality_check/t1_linear/quality_check.py b/clinicadl/quality_check/t1_linear/quality_check.py
index 373f5228c..85158c506 100755
--- a/clinicadl/quality_check/t1_linear/quality_check.py
+++ b/clinicadl/quality_check/t1_linear/quality_check.py
@@ -11,7 +11,7 @@
 from torch.amp import autocast
 from torch.utils.data import DataLoader
 
-from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig
+from clinicadl.data.caps_dataset_config import CapsDatasetConfig
 from clinicadl.generate.generate_utils import load_and_check_tsv
 from clinicadl.utils.computational.computational import ComputationalConfig
 from clinicadl.utils.exceptions import ClinicaDLArgumentError
diff --git a/clinicadl/quality_check/t1_linear/utils.py b/clinicadl/quality_check/t1_linear/utils.py
index 71d2cb1d2..8b778e367 100755
--- a/clinicadl/quality_check/t1_linear/utils.py
+++ b/clinicadl/quality_check/t1_linear/utils.py
@@ -8,9 +8,9 @@
 import torch
 from torch.utils.data import Dataset
 
-from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig
-from clinicadl.dataset.caps_dataset_utils import compute_folder_and_file_type
-from clinicadl.dataset.utils import linear_nii
+from clinicadl.data.caps_dataset_config import CapsDatasetConfig
+from clinicadl.data.caps_dataset_utils import compute_folder_and_file_type
+from clinicadl.data.utils import linear_nii
 from clinicadl.utils.enum import Preprocessing
 from clinicadl.utils.exceptions import ClinicaDLException
 from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader
diff --git a/clinicadl/splitter/make_splits/kfold.py b/clinicadl/splitter/make_splits/kfold.py
index ed9dd41e2..64235f5ea 100644
--- a/clinicadl/splitter/make_splits/kfold.py
+++ b/clinicadl/splitter/make_splits/kfold.py
@@ -6,11 +6,11 @@
 from pydantic import PositiveInt
 from sklearn.model_selection import KFold, StratifiedKFold
 
-from clinicadl.dataset.utils import tsv_to_df
+from clinicadl.data.utils import tsv_to_df
 from clinicadl.splitter.make_splits.utils import write_to_csv
 from clinicadl.splitter.splitter.kfold import KFoldConfig
 from clinicadl.tsvtools.tsvtools_utils import extract_baseline
-from clinicadl.utils.exceptions import ClinicaDLConfigurationError
+from clinicadl.utils.exceptions import ClinicaDLConfigurationError, ClinicaDLTSVError
 
 
 def _validate_stratification(
@@ -98,7 +98,7 @@ def preprocess_stratification(
 
 def make_kfold(
-    tsv_path: Path,
+    data: Union[pd.DataFrame, Path, str],
     output_dir: Optional[Union[Path, str]] = None,
     subset_name: str = "validation",
     valid_longitudinal: bool = False,
@@ -110,10 +110,10 @@
 
     Parameters
     ----------
-    tsv_path : Path
-        Path to the input TSV file.
-    output_dir : Optional[Path]
-        Directory to save the split files. Defaults to the parent directory of `tsv_path`.
+    data : Union[pd.DataFrame, Path, str]
+        Path to the TSV file or a DataFrame containing participant/session pairs.
+    output_dir : Optional[Union[Path, str]]
+        Directory to save the split files. Required when `data` is a DataFrame; defaults to the parent directory of `data` when a path is given.
     subset_name : str, default="validation"
         Name of the subset used for output files.
     valid_longitudinal : bool, default=False
@@ -134,9 +134,25 @@
         If invalid configuration options are provided.
     """
-    # Set default output directory
-    output_dir = output_dir or tsv_path.parent
+    if isinstance(data, (str, Path)):
+        data = Path(data)
+
+        # Set default output directory
+        output_dir = output_dir or data.parent
+        # Load dataset and preprocess
+        df = tsv_to_df(data)
+
+    elif isinstance(data, pd.DataFrame):
+        if not output_dir:
+            raise ValueError("You must specify the output directory.")
+
+        if data.empty:
+            raise ClinicaDLTSVError(f"The input data is empty: {data}")
+        else:
+            df = data
+
+    output_dir = Path(output_dir)
+
+    baseline_df = extract_baseline(df)
 
     # Initialize KFold configuration
     config = KFoldConfig(
@@ -150,10 +166,6 @@
     config._check_split_dir()
     config._write_json()
 
-    # Load and process dataset
-    df = tsv_to_df(tsv_path)
-    baseline_df = extract_baseline(df)
-
     stratify_labels = preprocess_stratification(
         df=baseline_df,
         stratification=config.stratification,
diff --git a/clinicadl/splitter/make_splits/single_split.py b/clinicadl/splitter/make_splits/single_split.py
index 1854b5449..21c15af3d 100644
--- a/clinicadl/splitter/make_splits/single_split.py
+++ b/clinicadl/splitter/make_splits/single_split.py
@@ -8,7 +8,7 @@
 from scipy.stats import chisquare, ks_2samp, ttest_ind
 from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit
 
-from clinicadl.dataset.utils import tsv_to_df
+from clinicadl.data.utils import tsv_to_df
 from clinicadl.splitter.make_splits.utils import write_to_csv
 from clinicadl.splitter.splitter.single_split import SingleSplitConfig
 from clinicadl.tsvtools.tsvtools_utils import extract_baseline
@@ -128,7 +128,7 @@ def _chi2_test(x_test: List[int], x_train: List[int]) -> float:
 
 def make_split(
-    tsv_path: Path,
+    data: Union[pd.DataFrame, Path, str],
     output_dir: Optional[Union[Path, str]] = None,
     n_test: PositiveFloat = 100,
     subset_name: str = "test",
@@ -143,9 +143,9 @@
 
     Parameters
     ----------
-    tsv_path : Path
-        Path to the input TSV file.
-    output_dir : Optional[Path]
+    data : Union[pd.DataFrame, Path, str]
+        Path to the TSV file or a DataFrame containing participant/session pairs.
+    output_dir : Optional[Union[Path, str]]
         Directory to save the split files.
     n_test : PositiveFloat
         If >= 1, specifies the absolute number of test samples. If < 1, treated as a proportion of the dataset.
@@ -167,13 +167,24 @@
     Path
         Directory containing the split files.
     """
+    if isinstance(data, (str, Path)):
+        data = Path(data)
 
-    # Set default output directory
-    output_dir = output_dir or tsv_path.parent
-    output_dir = Path(output_dir)
+        # Set default output directory
+        output_dir = output_dir or data.parent
+        # Load dataset and preprocess
+        df = tsv_to_df(data)
+
+    elif isinstance(data, pd.DataFrame):
+        if not output_dir:
+            raise ValueError("You must specify the output directory.")
 
-    # Load dataset and preprocess
-    df = tsv_to_df(tsv_path)
+        if data.empty:
+            raise ClinicaDLTSVError(f"The input data is empty: {data}")
+        else:
+            df = data
+
+    output_dir = Path(output_dir)
 
     baseline_df = extract_baseline(df)
     n_test = int(n_test) if n_test >= 1 else int(n_test * len(baseline_df))
diff --git a/clinicadl/splitter/split.py b/clinicadl/splitter/split.py
index ff9ec21f8..761595685 100644
--- a/clinicadl/splitter/split.py
+++ b/clinicadl/splitter/split.py
@@ -4,8 +4,8 @@
 from pydantic import NonNegativeInt
 from torch.utils.data import DataLoader
 
-from clinicadl.dataset.dataloader import DataLoaderConfig
-from clinicadl.dataset.dataloader.defaults import (
+from clinicadl.data.dataloader import DataLoaderConfig
+from clinicadl.data.dataloader.defaults import (
     BATCH_SIZE,
     DP_DEGREE,
     DROP_LAST,
@@ -16,7 +16,7 @@
     SAMPLING_WEIGHTS,
     SHUFFLE,
 )
-from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.data.datasets.caps_dataset import CapsDataset
 from clinicadl.utils.config import ClinicaDLConfig
diff --git a/clinicadl/splitter/splitter/kfold.py b/clinicadl/splitter/splitter/kfold.py
index d4a306508..c100c3818 100644
--- a/clinicadl/splitter/splitter/kfold.py
+++ b/clinicadl/splitter/splitter/kfold.py
@@ -3,7 +3,7 @@
 
 from pydantic import PositiveInt
 
-from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.data.datasets.caps_dataset import CapsDataset
 from clinicadl.splitter.split import Split
 from clinicadl.splitter.splitter.splitter import (
     Splitter,
diff --git a/clinicadl/splitter/splitter/single_split.py b/clinicadl/splitter/splitter/single_split.py
index daee8fd9f..96eb682d9 100644
--- a/clinicadl/splitter/splitter/single_split.py
+++ b/clinicadl/splitter/splitter/single_split.py
@@ -3,7 +3,7 @@
 
 from pydantic import PositiveInt, field_validator
 
-from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.data.datasets.caps_dataset import CapsDataset
 from clinicadl.splitter.split import Split
 from clinicadl.splitter.splitter.splitter import (
     Splitter,
diff --git a/clinicadl/splitter/splitter/splitter.py b/clinicadl/splitter/splitter/splitter.py
index ae7f65ca0..20b28b130 100644
--- a/clinicadl/splitter/splitter/splitter.py
+++ b/clinicadl/splitter/splitter/splitter.py
@@ -9,7 +9,7 @@
     field_validator,
 )
 
-from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.data.datasets.caps_dataset import CapsDataset
 from clinicadl.splitter.split import Split
 from clinicadl.utils.config import ClinicaDLConfig
 from clinicadl.utils.exceptions import ClinicaDLTSVError
diff --git a/clinicadl/splitter/test.py b/clinicadl/splitter/test.py
index 83010184e..fb3c2fda0 100644
--- a/clinicadl/splitter/test.py
+++ b/clinicadl/splitter/test.py
@@ -3,8 +3,8 @@
 import pandas as pd
 import torchio.transforms as transforms
 
-from clinicadl.dataset.datasets.caps_dataset import CapsDataset
-from clinicadl.dataset.preprocessing import PreprocessingT1
+from clinicadl.data.datasets.caps_dataset import CapsDataset
+from clinicadl.data.preprocessing import PreprocessingT1
 from clinicadl.splitter
import make_kfold, make_split from clinicadl.splitter.dataloader import DataLoaderConfig from clinicadl.splitter.splitter import KFold, SingleSplit diff --git a/clinicadl/tmp_config.py b/clinicadl/tmp_config.py index 54a791b1e..6a52b4154 100644 --- a/clinicadl/tmp_config.py +++ b/clinicadl/tmp_config.py @@ -19,7 +19,7 @@ ) from typing_extensions import Self -from clinicadl.dataset.caps_dataset import return_dataset +from clinicadl.data.caps_dataset import return_dataset from clinicadl.metrics.old_metrics.metric_module import MetricModule from clinicadl.splitter.split_utils import find_splits from clinicadl.trainer.tasks_utils import ( @@ -380,7 +380,7 @@ def check_preprocessing_dict(self) -> Self: ValueError In case of multi-cohort dataset, if no preprocessing file is found in any CAPS. """ - from clinicadl.dataset.caps_dataset import CapsDataset + from clinicadl.data.caps_dataset import CapsDataset if self.preprocessing_dict is None: if self.preprocessing_json is not None: diff --git a/clinicadl/trainer/__init__.py b/clinicadl/trainer/__init__.py index e69de29bb..260e4c8d6 100644 --- a/clinicadl/trainer/__init__.py +++ b/clinicadl/trainer/__init__.py @@ -0,0 +1 @@ +from .trainer import Trainer diff --git a/clinicadl/trainer/config/classification.py b/clinicadl/trainer/config/classification.py index 64a4f1c29..ae9266c8a 100644 --- a/clinicadl/trainer/config/classification.py +++ b/clinicadl/trainer/config/classification.py @@ -3,7 +3,7 @@ from pydantic import computed_field, field_validator -from clinicadl.dataset.config.data import DataConfig as BaseDataConfig +from clinicadl.data.config.data import DataConfig as BaseDataConfig from clinicadl.networks.old_network.config import NetworkConfig as BaseNetworkConfig from clinicadl.predictor.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig diff --git a/clinicadl/trainer/config/regression.py b/clinicadl/trainer/config/regression.py index e29bf29e0..78c17b3c3 100644 --- a/clinicadl/trainer/config/regression.py +++ b/clinicadl/trainer/config/regression.py @@ -3,7 +3,7 @@ from pydantic import computed_field, field_validator -from clinicadl.dataset.config.data import DataConfig as BaseDataConfig +from clinicadl.data.config.data import DataConfig as BaseDataConfig from clinicadl.networks.old_network.config import NetworkConfig as BaseNetworkConfig from clinicadl.predictor.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig diff --git a/clinicadl/trainer/config/train.py b/clinicadl/trainer/config/train.py index 538eccecd..2bd83b9d4 100644 --- a/clinicadl/trainer/config/train.py +++ b/clinicadl/trainer/config/train.py @@ -12,7 +12,7 @@ from clinicadl.callbacks.config import CallbacksConfig from clinicadl.config.config.lr_scheduler import LRschedulerConfig from clinicadl.config.config.reproducibility import ReproducibilityConfig -from clinicadl.dataset.config.data import DataConfig +from clinicadl.data.config.data import DataConfig from clinicadl.experiment_manager.config import MapsManagerConfig from clinicadl.networks.old_network.config import NetworkConfig from clinicadl.optim.config import OptimizationConfig diff --git a/clinicadl/trainer/old_trainer.py b/clinicadl/trainer/old_trainer.py index aa3f82c70..dd3613c15 100644 --- a/clinicadl/trainer/old_trainer.py +++ b/clinicadl/trainer/old_trainer.py @@ -16,13 +16,13 @@ from torch.utils.data.distributed import DistributedSampler from clinicadl.splitter.split_utils import 
find_finished_splits, find_stopped_splits -from clinicadl.dataset.caps_dataset import return_dataset +from clinicadl.data.caps_dataset import return_dataset from clinicadl.optim.early_stopping import EarlyStopping from clinicadl.utils.exceptions import MAPSError from clinicadl.utils.computational.ddp import DDP from clinicadl.utils import cluster from clinicadl.utils.logwriter import LogWriter -from clinicadl.dataset.caps_dataset_utils import read_json +from clinicadl.data.caps_dataset_utils import read_json from clinicadl.metrics.old_metrics.metric_module import RetainBest from clinicadl.utils.seed import pl_worker_init_function, seed_everything from clinicadl.experiment_manager.maps_manager import MapsManager diff --git a/clinicadl/trainer/tasks_utils.py b/clinicadl/trainer/tasks_utils.py index a14bfa4a9..920b344d8 100644 --- a/clinicadl/trainer/tasks_utils.py +++ b/clinicadl/trainer/tasks_utils.py @@ -18,7 +18,7 @@ from torch.utils.data import DataLoader, Sampler, sampler from torch.utils.data.distributed import DistributedSampler -from clinicadl.dataset.caps_dataset import CapsDataset +from clinicadl.data.caps_dataset import CapsDataset from clinicadl.metrics.old_metrics.metric_module import MetricModule from clinicadl.networks.old_network.network import Network from clinicadl.trainer.config.train import TrainConfig diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 386a7d148..77a1cdcc3 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -1,20 +1,231 @@ +from __future__ import annotations + +from contextlib import nullcontext +from logging import getLogger from pathlib import Path +from typing import Optional + +import torch +import torch.distributed as dist +from torch.amp.autocast_mode import autocast +from torch.amp.grad_scaler import GradScaler +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler -from clinicadl.dataset.caps_dataset import CapsDataset +from clinicadl.data.datasets.caps_dataset import CapsDataset from clinicadl.experiment_manager.experiment_manager import ExperimentManager +from clinicadl.metrics.old_metrics.metric_module import RetainBest from clinicadl.model.clinicadl_model import ClinicaDLModel -from clinicadl.splitter.kfold import Split +from clinicadl.optim.early_stopping import EarlyStopping +from clinicadl.predictor.predictor import Predictor +from clinicadl.splitter.split import Split +from clinicadl.trainer.tasks_utils import get_criterion +from clinicadl.utils import cluster +from clinicadl.utils.logwriter import LogWriter + +logger = getLogger("clinicadl.trainer") + + +class MapsReader: + maps_path: Path + + def get_model(self) -> ClinicaDLModel: + return ClinicaDLModel() + + def _write_network_weights(self): + pass + + def _write_optim_weights(self): + pass + + def write_tensor(self): + pass + + def optimizer_path(self, split: int, resume: bool = False) -> Path: + """TO COMPLETE""" + + checkpoint_path = ( + self.maps_path / f"split-{split}" / "tmp" / "optimizer.pth.tar" + ) + return checkpoint_path + + def checkpoint_path(self, split: int, resume: bool = False): + checkpoint_path = ( + self.maps_path / f"split-{split}" / "tmp" / "checkpoint.pth.tar" + ) + return checkpoint_path class Trainer: - def __init__(self) -> None: + def __init__(self, maps_path: Path) -> None: """TO COMPLETE""" + self.reader = MapsReader(maps_path) + self.maps_path = maps_path @classmethod def from_json(cls, config_file: Path, manager: ExperimentManager) -> Trainer: """TO 
COMPLETE""" - return cls() + return Trainer() - def train(self, model: ClinicaDLModel, split: Split): + @classmethod + def from_maps(cls, maps_path: str | Path) -> Trainer: """TO COMPLETE""" + return Trainer() + + def _init_profiler(self): pass + + def resume(self, split: Split): + """TO COMPLETE""" + + model = self.reader.get_model() + + model.load_optim_state_dict( + self.reader.optimizer_path(split.index, resume=True) + ) + current_epoch = model.load_state_dict( + self.reader.checkpoint_path(split.index, resume=True) + ) + + def train(self, model: ClinicaDLModel, split: Split, epoch: int = 0): + """TO COMPLETE""" + + # NEEDED ARG + adaptive_learning_rate: bool = False + amp: bool = False + n_epochs: int = 30 + accumulation_steps: int = 3 + evaluation_steps: int = 4 + save_outputs: bool = ( + False # depend on the network task, only ok for reconstruction + ) + network_task: str = "classification" # TASK enum + # + + # INIT + criterion = get_criterion(network_task, model.loss) + early_stopping = EarlyStopping() + metrics_valid = {"loss": None} + retain_best = RetainBest() + scaler = GradScaler("cuda", enabled=amp) + profiler = self._init_profiler() + + if cluster.master: + log_writer = LogWriter() + + model.network.train() + split.train_loader.dataset.train() + + if adaptive_learning_rate: + from torch.optim.lr_scheduler import ReduceLROnPlateau + + scheduler = ReduceLROnPlateau(model.optimizer, mode="min", factor=0.1) + + validator = Predictor() + # + + while epoch < n_epochs and not early_stopping.step(metrics_valid["loss"]): + if isinstance(split.train_loader.sampler, DistributedSampler): + # It should always be true for a random sampler. But just in case + # we get a WeightedRandomSampler or a forgotten RandomSampler, + # we do not want to execute this line. + split.train_loader.sampler.set_epoch(epoch) + + model.network.zero_grad(set_to_none=True) + evaluation_flag, step_flag = True, True + + with profiler: + for i, data in enumerate(split.train_loader): + update: bool = (i + 1) % accumulation_steps == 0 + sync = nullcontext() if update else model.network.no_sync() + with sync: + with autocast("cuda", enabled=amp): + _, loss_dict = model.network(data, criterion) + + loss = loss_dict["loss"] + scaler.scale(loss).backward() + + if update: + step_flag = False + scaler.step(model.optimizer) + scaler.update() + model.optimizer.zero_grad(set_to_none=True) + + del loss + + # Evaluate the model only when no gradients are accumulated + if evaluation_steps != 0 and (i + 1) % evaluation_steps == 0: + evaluation_flag = False + + _, metrics_train = validator.test( + dataloader=split.train_loader + ) + _, metrics_valid = validator.test( + dataloader=split.val_loader + ) + + model.network.train() + split.train_loader.dataset.train() + + if cluster.master: + log_writer.step( + epoch, + i, + metrics_train, + metrics_valid, + len(split.train_loader), + ) + + profiler.step() + + # If no step has been performed, raise Exception + if step_flag: + raise ValueError( + "The model has not been updated once in the epoch. The accumulation step may be too large." + ) + + # If no evaluation has been performed, warn the user + elif evaluation_flag and evaluation_steps != 0: + logger.warning( + f"Your evaluation steps {evaluation_steps} are too big " + f"compared to the size of the dataset. " + f"The model is evaluated only once at the end epochs." 
diff --git a/clinicadl/config/config_utils.py b/clinicadl/utils/config_utils.py
similarity index 100%
rename from clinicadl/config/config_utils.py
rename to clinicadl/utils/config_utils.py
diff --git a/clinicadl/utils/iotools/train_utils.py b/clinicadl/utils/iotools/train_utils.py
index 21fb160d5..b0942ea62 100644
--- a/clinicadl/utils/iotools/train_utils.py
+++ b/clinicadl/utils/iotools/train_utils.py
@@ -220,7 +220,7 @@ def merge_cli_and_maps_json_options(maps_json: Path, **kwargs) -> Dict[str, Any]
     Dict[str, Any]
         A dictionary with training options.
     """
-    from clinicadl.dataset.caps_dataset_utils import read_json
+    from clinicadl.data.caps_dataset_utils import read_json
 
     options = read_json(maps_json)
     for arg in kwargs:
@@ -253,7 +253,7 @@ def merge_options_and_maps_json_options(maps_json: Path, **kwargs) -> Dict[str,
     Dict[str, Any]
         A dictionary with training options.
""" - from clinicadl.dataset.caps_dataset_utils import read_json + from clinicadl.data.caps_dataset_utils import read_json options = read_json(maps_json) for arg in kwargs: diff --git a/tests/unittests/dataset/test_config.py b/tests/unittests/dataset/test_config.py index 4fa9588e9..56378c8a9 100644 --- a/tests/unittests/dataset/test_config.py +++ b/tests/unittests/dataset/test_config.py @@ -1,6 +1,6 @@ import pytest -from clinicadl.dataset.config import FileType +from clinicadl.data.config import FileType from clinicadl.utils.enum import Preprocessing diff --git a/tests/unittests/dataset/test_datasets.py b/tests/unittests/dataset/test_datasets.py index 85de78228..e57f2351c 100644 --- a/tests/unittests/dataset/test_datasets.py +++ b/tests/unittests/dataset/test_datasets.py @@ -4,8 +4,8 @@ import numpy as np import pytest -from clinicadl.dataset.datasets.caps_dataset import CapsDataset -from clinicadl.dataset.preprocessing import PreprocessingT1, PreprocessingT2 +from clinicadl.data.datasets.caps_dataset import CapsDataset +from clinicadl.data.preprocessing import PreprocessingT1, PreprocessingT2 from clinicadl.transforms import Transforms from clinicadl.utils.enum import Preprocessing from clinicadl.utils.exceptions import ( diff --git a/tests/unittests/dataset/test_reader.py b/tests/unittests/dataset/test_reader.py index 20cf4dd24..a01b04221 100644 --- a/tests/unittests/dataset/test_reader.py +++ b/tests/unittests/dataset/test_reader.py @@ -4,8 +4,8 @@ import numpy as np import pytest -from clinicadl.dataset.preprocessing import PreprocessingT1, PreprocessingT2 -from clinicadl.dataset.readers import CapsReader +from clinicadl.data.preprocessing import PreprocessingT1, PreprocessingT2 +from clinicadl.data.readers import CapsReader from clinicadl.transforms import Transforms from clinicadl.utils.enum import Preprocessing from clinicadl.utils.exceptions import ( diff --git a/tests/unittests/random_search/test_random_search_config.py b/tests/unittests/random_search/test_random_search_config.py index f2e195309..c4b109342 100644 --- a/tests/unittests/random_search/test_random_search_config.py +++ b/tests/unittests/random_search/test_random_search_config.py @@ -1,73 +1,73 @@ -from pathlib import Path +# from pathlib import Path -import pytest -from pydantic import ValidationError +# import pytest +# from pydantic import ValidationError -# Test RandomSearchConfig # -def test_random_search_config(): - from clinicadl.random_search.random_search_config import RandomSearchConfig +# # Test RandomSearchConfig # +# def test_random_search_config(): +# from clinicadl.random_search.random_search_config import RandomSearchConfig - config = RandomSearchConfig( - first_conv_width=[1, 2], - n_convblocks=1, - n_fcblocks=(1,), - ) - assert config.first_conv_width == (1, 2) - assert config.n_convblocks == (1,) - assert config.n_fcblocks == (1,) - with pytest.raises(ValidationError): - config.first_conv_width = (1, 0) +# config = RandomSearchConfig( +# first_conv_width=[1, 2], +# n_convblocks=1, +# n_fcblocks=(1,), +# ) +# assert config.first_conv_width == (1, 2) +# assert config.n_convblocks == (1,) +# assert config.n_fcblocks == (1,) +# with pytest.raises(ValidationError): +# config.first_conv_width = (1, 0) -# Test Training Configs # -@pytest.fixture -def caps_example(): - dir_ = Path(__file__).parents[1] / "ressources" / "caps_example" - return dir_ +# # Test Training Configs # +# @pytest.fixture +# def caps_example(): +# dir_ = Path(__file__).parents[1] / "ressources" / "caps_example" +# return dir_ 
-@pytest.fixture -def dummy_arguments(caps_example): - args = { - "caps_directory": caps_example, - "preprocessing_json": "preprocessing.json", - "tsv_path": "", - "maps_dir": "", - "gpu": False, - } - return args +# @pytest.fixture +# def dummy_arguments(caps_example): +# args = { +# "caps_directory": caps_example, +# "preprocessing_json": "preprocessing.json", +# "tsv_path": "", +# "maps_dir": "", +# "gpu": False, +# } +# return args -@pytest.fixture -def random_model_arguments(): - args = { - "convolutions_dict": { - "conv0": { - "in_channels": 1, - "out_channels": 8, - "n_conv": 2, - "d_reduction": "MaxPooling", - }, - "conv1": { - "in_channels": 8, - "out_channels": 16, - "n_conv": 3, - "d_reduction": "MaxPooling", - }, - }, - "n_fcblocks": 2, - } - return args +# @pytest.fixture +# def random_model_arguments(): +# args = { +# "convolutions_dict": { +# "conv0": { +# "in_channels": 1, +# "out_channels": 8, +# "n_conv": 2, +# "d_reduction": "MaxPooling", +# }, +# "conv1": { +# "in_channels": 8, +# "out_channels": 16, +# "n_conv": 3, +# "d_reduction": "MaxPooling", +# }, +# }, +# "n_fcblocks": 2, +# } +# return args -def test_training_config(dummy_arguments, random_model_arguments): - from clinicadl.random_search.random_search_config import ClassificationConfig +# def test_training_config(dummy_arguments, random_model_arguments): +# from clinicadl.random_search.random_search_config import ClassificationConfig - config = ClassificationConfig(**dummy_arguments, **random_model_arguments) - assert config.model.convolutions_dict == random_model_arguments["convolutions_dict"] - assert config.model.n_fcblocks == random_model_arguments["n_fcblocks"] - assert config.model.architecture == "RandomArchitecture" - assert config.network_task == "classification" - with pytest.raises(ValidationError): - config.model.architecture = "abc" +# config = ClassificationConfig(**dummy_arguments, **random_model_arguments) +# assert config.model.convolutions_dict == random_model_arguments["convolutions_dict"] +# assert config.model.n_fcblocks == random_model_arguments["n_fcblocks"] +# assert config.model.architecture == "RandomArchitecture" +# assert config.network_task == "classification" +# with pytest.raises(ValidationError): +# config.model.architecture = "abc" diff --git a/tests/unittests/splitter/test_make_split.py b/tests/unittests/splitter/test_make_split.py index 3216f097f..27af875fe 100644 --- a/tests/unittests/splitter/test_make_split.py +++ b/tests/unittests/splitter/test_make_split.py @@ -6,7 +6,11 @@ import pytest from pydantic import ValidationError +from clinicadl.data.datasets import CapsDataset +from clinicadl.data.preprocessing import PreprocessingT1 from clinicadl.splitter.make_splits import make_kfold, make_split +from clinicadl.transforms import Transforms +from clinicadl.tsvtools.tsvtools_utils import extract_baseline from clinicadl.utils.exceptions import ( ClinicaDLConfigurationError, ClinicaDLTSVError, @@ -96,6 +100,23 @@ def test_good_split(): remove_non_empty_dir(split_dir_bis_bis) +def test_make_split_and_kfold_from_df(): + dataset = CapsDataset(caps_dir, PreprocessingT1(), Transforms()) + with pytest.raises(ValueError): + _ = make_split(dataset.df, n_test=0.2) + + split_dir = make_split(dataset.df, output_dir=Path("from_df"), n_test=1) + train_path = split_dir / "train_baseline.tsv" + test_path = split_dir / "test_baseline.tsv" + assert train_path.exists() + assert test_path.exists() + + train_df = pd.read_csv(train_path, sep="\t") + test_df = pd.read_csv(test_path, sep="\t") + 
assert len(test_df) + len(train_df) == len(extract_baseline(dataset.df)) + remove_non_empty_dir(split_dir) + + def test_bad_split(): with pytest.raises(ClinicaDLTSVError): make_split(caps_dir / "test.tsv", n_test=15) diff --git a/tests/unittests/splitter/test_splitter.py b/tests/unittests/splitter/test_splitter.py index 8cf260bfd..a1c9b0399 100644 --- a/tests/unittests/splitter/test_splitter.py +++ b/tests/unittests/splitter/test_splitter.py @@ -7,8 +7,8 @@ import pytest from pydantic import ValidationError -from clinicadl.dataset.datasets.caps_dataset import CapsDataset -from clinicadl.dataset.preprocessing import PreprocessingT1, PreprocessingT2 +from clinicadl.data.datasets.caps_dataset import CapsDataset +from clinicadl.data.preprocessing import PreprocessingT1, PreprocessingT2 from clinicadl.splitter.split import Split from clinicadl.splitter.splitter.kfold import KFold, KFoldConfig from clinicadl.splitter.splitter.single_split import SingleSplit, SingleSplitConfig diff --git a/tests/unittests/train/tasks/classification/test_classification_config.py b/tests/unittests/train/tasks/classification/test_classification_config.py index 71e853872..5683109cd 100644 --- a/tests/unittests/train/tasks/classification/test_classification_config.py +++ b/tests/unittests/train/tasks/classification/test_classification_config.py @@ -1,86 +1,86 @@ -from pathlib import Path - -import pytest -from pydantic import ValidationError - -import clinicadl.trainer.config.classification as classification - - -# Tests for customed class methods # -def test_model_config(): - with pytest.raises(ValidationError): - classification.NetworkConfig( - **{ - "architecture": "", - "loss": "", - "selection_threshold": 1.1, - } - ) - - -def test_validation_config(): - c = classification.ValidationConfig(selection_metrics=["accuracy"]) - assert c.selection_metrics == ("accuracy",) - - -# Global tests # -@pytest.fixture -def caps_example(): - dir_ = Path(__file__).parents[3] / "ressources" / "caps_example" - return dir_ - - -@pytest.fixture -def dummy_arguments(caps_example): - args = { - "caps_directory": caps_example, - "preprocessing_json": "preprocessing.json", - "tsv_path": "", - "maps_dir": "", - "gpu": False, - } - return args - - -@pytest.fixture( - params=[ - {"loss": "abc"}, - {"selection_metrics": ("abc",)}, - {"selection_metrics": "F1_score"}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -@pytest.fixture -def good_inputs(dummy_arguments): - options = { - "loss": "MultiMarginLoss", - "selection_metrics": ("F1_score",), - "selection_threshold": 0.5, - } - return {**dummy_arguments, **options} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - classification.ClassificationConfig(**bad_inputs) - - -def test_passes_validations(good_inputs): - c = classification.ClassificationConfig(**good_inputs) - assert c.model.loss == "MultiMarginLoss" - assert c.validation.selection_metrics == ("F1_score",) - assert c.model.selection_threshold == 0.5 - assert c.network_task == "classification" - - -def test_update_from_toml(dummy_arguments): - toml_path = ( - Path(__file__).parents[3] / "ressources" / "functional_config_example.toml" - ) - c = classification.ClassificationConfig(**dummy_arguments) - c.update_with_toml(toml_path) - assert not c.computational.gpu - assert c.model.loss == "MultiMarginLoss" +# from pathlib import Path + +# import pytest +# from pydantic import ValidationError + +# import clinicadl.trainer.config.classification as 
classification + + +# # Tests for customed class methods # +# def test_model_config(): +# with pytest.raises(ValidationError): +# classification.NetworkConfig( +# **{ +# "architecture": "", +# "loss": "", +# "selection_threshold": 1.1, +# } +# ) + + +# def test_validation_config(): +# c = classification.ValidationConfig(selection_metrics=["accuracy"]) +# assert c.selection_metrics == ("accuracy",) + + +# # Global tests # +# @pytest.fixture +# def caps_example(): +# dir_ = Path(__file__).parents[3] / "ressources" / "caps_example" +# return dir_ + + +# @pytest.fixture +# def dummy_arguments(caps_example): +# args = { +# "caps_directory": caps_example, +# "preprocessing_json": "preprocessing.json", +# "tsv_path": "", +# "maps_dir": "", +# "gpu": False, +# } +# return args + + +# @pytest.fixture( +# params=[ +# {"loss": "abc"}, +# {"selection_metrics": ("abc",)}, +# {"selection_metrics": "F1_score"}, +# ] +# ) +# def bad_inputs(request, dummy_arguments): +# return {**dummy_arguments, **request.param} + + +# @pytest.fixture +# def good_inputs(dummy_arguments): +# options = { +# "loss": "MultiMarginLoss", +# "selection_metrics": ("F1_score",), +# "selection_threshold": 0.5, +# } +# return {**dummy_arguments, **options} + + +# def test_fails_validations(bad_inputs): +# with pytest.raises(ValidationError): +# classification.ClassificationConfig(**bad_inputs) + + +# def test_passes_validations(good_inputs): +# c = classification.ClassificationConfig(**good_inputs) +# assert c.model.loss == "MultiMarginLoss" +# assert c.validation.selection_metrics == ("F1_score",) +# assert c.model.selection_threshold == 0.5 +# assert c.network_task == "classification" + + +# def test_update_from_toml(dummy_arguments): +# toml_path = ( +# Path(__file__).parents[3] / "ressources" / "functional_config_example.toml" +# ) +# c = classification.ClassificationConfig(**dummy_arguments) +# c.update_with_toml(toml_path) +# assert not c.computational.gpu +# assert c.model.loss == "MultiMarginLoss" diff --git a/tests/unittests/train/tasks/reconstruction/test_reconstruction_config.py b/tests/unittests/train/tasks/reconstruction/test_reconstruction_config.py index f013386c1..a27ceda3a 100644 --- a/tests/unittests/train/tasks/reconstruction/test_reconstruction_config.py +++ b/tests/unittests/train/tasks/reconstruction/test_reconstruction_config.py @@ -1,75 +1,75 @@ -from pathlib import Path +# from pathlib import Path -import pytest -from pydantic import ValidationError +# import pytest +# from pydantic import ValidationError -import clinicadl.trainer.config.reconstruction as reconstruction +# import clinicadl.trainer.config.reconstruction as reconstruction -# Tests for customed validators # -def test_validation_config(): - c = reconstruction.ValidationConfig(selection_metrics=["MAE"]) - assert c.selection_metrics == ("MAE",) +# # Tests for customed validators # +# def test_validation_config(): +# c = reconstruction.ValidationConfig(selection_metrics=["MAE"]) +# assert c.selection_metrics == ("MAE",) -# Global tests on the TrainingConfig class # -@pytest.fixture -def caps_example(): - dir_ = Path(__file__).parents[3] / "ressources" / "caps_example" - return dir_ +# # Global tests on the TrainingConfig class # +# @pytest.fixture +# def caps_example(): +# dir_ = Path(__file__).parents[3] / "ressources" / "caps_example" +# return dir_ -@pytest.fixture -def dummy_arguments(caps_example): - args = { - "caps_directory": caps_example, - "preprocessing_json": "preprocessing.json", - "tsv_path": "", - "maps_dir": "", - "gpu": False, - } 
- return args +# @pytest.fixture +# def dummy_arguments(caps_example): +# args = { +# "caps_directory": caps_example, +# "preprocessing_json": "preprocessing.json", +# "tsv_path": "", +# "maps_dir": "", +# "gpu": False, +# } +# return args -@pytest.fixture( - params=[ - {"loss": "abc"}, - {"selection_metrics": ("abc",)}, - {"normalization": "abc"}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} +# @pytest.fixture( +# params=[ +# {"loss": "abc"}, +# {"selection_metrics": ("abc",)}, +# {"normalization": "abc"}, +# ] +# ) +# def bad_inputs(request, dummy_arguments): +# return {**dummy_arguments, **request.param} -@pytest.fixture -def good_inputs(dummy_arguments): - options = { - "loss": "HuberLoss", - "selection_metrics": ("PSNR",), - "normalization": "BatchNorm", - } - return {**dummy_arguments, **options} +# @pytest.fixture +# def good_inputs(dummy_arguments): +# options = { +# "loss": "HuberLoss", +# "selection_metrics": ("PSNR",), +# "normalization": "BatchNorm", +# } +# return {**dummy_arguments, **options} -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - reconstruction.ReconstructionConfig(**bad_inputs) +# def test_fails_validations(bad_inputs): +# with pytest.raises(ValidationError): +# reconstruction.ReconstructionConfig(**bad_inputs) -def test_passes_validations(good_inputs): - c = reconstruction.ReconstructionConfig(**good_inputs) - assert c.model.loss == "HuberLoss" - assert c.validation.selection_metrics == ("PSNR",) - assert c.model.normalization == "BatchNorm" - assert c.network_task == "reconstruction" +# def test_passes_validations(good_inputs): +# c = reconstruction.ReconstructionConfig(**good_inputs) +# assert c.model.loss == "HuberLoss" +# assert c.validation.selection_metrics == ("PSNR",) +# assert c.model.normalization == "BatchNorm" +# assert c.network_task == "reconstruction" -def test_update_from_toml(dummy_arguments): - toml_path = ( - Path(__file__).parents[3] / "ressources" / "functional_config_example.toml" - ) - c = reconstruction.ReconstructionConfig(**dummy_arguments) - c.update_with_toml(toml_path) - assert not c.computational.gpu - assert c.model.loss == "VAEBernoulliLoss" +# def test_update_from_toml(dummy_arguments): +# toml_path = ( +# Path(__file__).parents[3] / "ressources" / "functional_config_example.toml" +# ) +# c = reconstruction.ReconstructionConfig(**dummy_arguments) +# c.update_with_toml(toml_path) +# assert not c.computational.gpu +# assert c.model.loss == "VAEBernoulliLoss" diff --git a/tests/unittests/train/tasks/regression/test_regression_config.py b/tests/unittests/train/tasks/regression/test_regression_config.py index 4b01e1084..d854f17bd 100644 --- a/tests/unittests/train/tasks/regression/test_regression_config.py +++ b/tests/unittests/train/tasks/regression/test_regression_config.py @@ -1,73 +1,73 @@ -from pathlib import Path +# from pathlib import Path -import pytest -from pydantic import ValidationError +# import pytest +# from pydantic import ValidationError -import clinicadl.trainer.config.regression as regression +# import clinicadl.trainer.config.regression as regression -# Tests for customed validators # -def test_validation_config(): - c = regression.ValidationConfig(selection_metrics=["R2_score"]) - assert c.selection_metrics == ("R2_score",) +# # Tests for customed validators # +# def test_validation_config(): +# c = regression.ValidationConfig(selection_metrics=["R2_score"]) +# assert c.selection_metrics == ("R2_score",) -# Global tests on the 
TrainingConfig class # -@pytest.fixture -def caps_example(): - dir_ = Path(__file__).parents[3] / "ressources" / "caps_example" - return dir_ +# # Global tests on the TrainingConfig class # +# @pytest.fixture +# def caps_example(): +# dir_ = Path(__file__).parents[3] / "ressources" / "caps_example" +# return dir_ -@pytest.fixture -def dummy_arguments(caps_example): - args = { - "caps_directory": caps_example, - "preprocessing_json": "preprocessing.json", - "tsv_path": "", - "maps_dir": "", - "gpu": False, - } - return args +# @pytest.fixture +# def dummy_arguments(caps_example): +# args = { +# "caps_directory": caps_example, +# "preprocessing_json": "preprocessing.json", +# "tsv_path": "", +# "maps_dir": "", +# "gpu": False, +# } +# return args -@pytest.fixture( - params=[ - {"loss": "abc"}, - {"selection_metrics": ("abc",)}, - {"selection_metrics": "R2_score"}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} +# @pytest.fixture( +# params=[ +# {"loss": "abc"}, +# {"selection_metrics": ("abc",)}, +# {"selection_metrics": "R2_score"}, +# ] +# ) +# def bad_inputs(request, dummy_arguments): +# return {**dummy_arguments, **request.param} -@pytest.fixture -def good_inputs(dummy_arguments): - options = { - "loss": "KLDivLoss", - "selection_metrics": ("R2_score",), - } - return {**dummy_arguments, **options} +# @pytest.fixture +# def good_inputs(dummy_arguments): +# options = { +# "loss": "KLDivLoss", +# "selection_metrics": ("R2_score",), +# } +# return {**dummy_arguments, **options} -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - regression.RegressionConfig(**bad_inputs) +# def test_fails_validations(bad_inputs): +# with pytest.raises(ValidationError): +# regression.RegressionConfig(**bad_inputs) -def test_passes_validations(good_inputs): - c = regression.RegressionConfig(**good_inputs) - assert c.model.loss == "KLDivLoss" - assert c.validation.selection_metrics == ("R2_score",) - assert c.network_task == "regression" +# def test_passes_validations(good_inputs): +# c = regression.RegressionConfig(**good_inputs) +# assert c.model.loss == "KLDivLoss" +# assert c.validation.selection_metrics == ("R2_score",) +# assert c.network_task == "regression" -def test_update_from_toml(dummy_arguments): - toml_path = ( - Path(__file__).parents[3] / "ressources" / "functional_config_example.toml" - ) - c = regression.RegressionConfig(**dummy_arguments) - c.update_with_toml(toml_path) - assert not c.computational.gpu - assert c.model.loss == "SmoothL1Loss" +# def test_update_from_toml(dummy_arguments): +# toml_path = ( +# Path(__file__).parents[3] / "ressources" / "functional_config_example.toml" +# ) +# c = regression.RegressionConfig(**dummy_arguments) +# c.update_with_toml(toml_path) +# assert not c.computational.gpu +# assert c.model.loss == "SmoothL1Loss" diff --git a/tests/unittests/train/trainer/test_training_config.py b/tests/unittests/train/trainer/test_training_config.py index 07b07fd8f..589f20453 100644 --- a/tests/unittests/train/trainer/test_training_config.py +++ b/tests/unittests/train/trainer/test_training_config.py @@ -1,165 +1,165 @@ -from pathlib import Path - -import pytest -from pydantic import ValidationError - -from clinicadl.dataset.config.data import DataConfig -from clinicadl.networks.old_network.config import NetworkConfig -from clinicadl.predictor.validation import ValidationConfig -from clinicadl.trainer.transfer_learning import TransferLearningConfig -from clinicadl.transforms.config import 
TransformsConfig - - -# Tests for customed validators # -@pytest.fixture -def caps_example(): - dir_ = Path(__file__).parents[2] / "ressources" / "caps_example" - return dir_ - - -# def test_cross_validation_config(): -# c = ValidationConfig( -# split=[0], -# tsv_path="", -# ) -# assert c.split == (0,) - - -# def test_data_config(caps_example): -# c = DataConfig( -# caps_directory=caps_example, -# preprocessing_json="preprocessing.json", -# diagnoses=["AD"], -# ) -# expected_preprocessing_dict = { -# "preprocessing": "t1-linear", -# "mode": "image", -# "use_uncropped_image": False, -# "prepare_dl": False, -# "extract_json": "t1-linear_mode-image.json", -# "file_type": { -# "pattern": "*space-MNI152NLin2009cSym_desc-Crop_res-1x1x1_T1w.nii.gz", -# "description": "T1W Image registered using t1-linear and cropped (matrix size 169\u00d7208\u00d7179, 1 mm isotropic voxels)", -# "needed_pipeline": "t1-linear", -# }, -# } -# # assert c.diagnoses == ("AD",) -# assert ( -# c.preprocessing_dict == expected_preprocessing_dict -# ) # TODO : add test for multi-cohort -# assert c.mode == "image" -# # with pytest.raises(ValidationError): -# # c.preprocessing_dict = {"abc": "abc"} -# # with pytest.raises(FileNotFoundError): -# # c.preprocessing_json = "" -# # c.preprocessing_json = None -# # c.preprocessing_dict = {"abc": "abc"} -# # assert c.preprocessing_dict == {"abc": "abc"} - - -def test_model_config(): - with pytest.raises(ValidationError): - NetworkConfig( - **{ - "architecture": "", - "loss": "", - "dropout": 1.1, - } - ) - - -def test_transferlearning_config(): - c = TransferLearningConfig(transfer_path=False) - assert c.transfer_path is None - - -def test_transforms_config(): - c = TransformsConfig(data_augmentation=False) - assert c.data_augmentation == () - c = TransformsConfig(data_augmentation=["Noise"]) - assert c.data_augmentation == ("Noise",) - - -# Global tests on the TrainingConfig class # -@pytest.fixture -def dummy_arguments(caps_example): - args = { - "caps_directory": caps_example, - "preprocessing_json": "preprocessing.json", - "tsv_path": "", - "maps_dir": "", - "gpu": False, - "architecture": "", - "loss": "", - "selection_metrics": (), - } - return args - - -@pytest.fixture -def training_config(): - from pydantic import computed_field - - from clinicadl.trainer.config.train import TrainConfig - - class TrainingConfig(TrainConfig): - @computed_field - @property - def network_task(self) -> str: - return "" - - return TrainingConfig - - -@pytest.fixture( - params=[ - {"gpu": "abc"}, - {"n_splits": -1}, - {"data_augmentation": ("abc",)}, - {"diagnoses": "AD"}, - {"batch_size": 0}, - {"size_reduction_factor": 1}, - {"split": [-1]}, - {"min_delta": -0.01}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -@pytest.fixture -def good_inputs(dummy_arguments): - options = { - "gpu": False, - "n_splits": 7, - "data_augmentation": ("Smoothing",), - "diagnoses": ("AD",), - "batch_size": 1, - "size_reduction_factor": 5, - "learning_rate": 1e-1, - "split": [0], - "min_delta": 0.0, - } - return {**dummy_arguments, **options} - - -# def test_fails_validations(bad_inputs, training_config): +# from pathlib import Path + +# import pytest +# from pydantic import ValidationError + +# from clinicadl.data.config.data import DataConfig +# from clinicadl.networks.old_network.config import NetworkConfig +# from clinicadl.predictor.validation import ValidationConfig +# from clinicadl.trainer.transfer_learning import TransferLearningConfig +# from 
clinicadl.transforms.config import TransformsConfig + + +# # Tests for customed validators # +# @pytest.fixture +# def caps_example(): +# dir_ = Path(__file__).parents[2] / "ressources" / "caps_example" +# return dir_ + + +# # def test_cross_validation_config(): +# # c = ValidationConfig( +# # split=[0], +# # tsv_path="", +# # ) +# # assert c.split == (0,) + + +# # def test_data_config(caps_example): +# # c = DataConfig( +# # caps_directory=caps_example, +# # preprocessing_json="preprocessing.json", +# # diagnoses=["AD"], +# # ) +# # expected_preprocessing_dict = { +# # "preprocessing": "t1-linear", +# # "mode": "image", +# # "use_uncropped_image": False, +# # "prepare_dl": False, +# # "extract_json": "t1-linear_mode-image.json", +# # "file_type": { +# # "pattern": "*space-MNI152NLin2009cSym_desc-Crop_res-1x1x1_T1w.nii.gz", +# # "description": "T1W Image registered using t1-linear and cropped (matrix size 169\u00d7208\u00d7179, 1 mm isotropic voxels)", +# # "needed_pipeline": "t1-linear", +# # }, +# # } +# # # assert c.diagnoses == ("AD",) +# # assert ( +# # c.preprocessing_dict == expected_preprocessing_dict +# # ) # TODO : add test for multi-cohort +# # assert c.mode == "image" +# # # with pytest.raises(ValidationError): +# # # c.preprocessing_dict = {"abc": "abc"} +# # # with pytest.raises(FileNotFoundError): +# # # c.preprocessing_json = "" +# # # c.preprocessing_json = None +# # # c.preprocessing_dict = {"abc": "abc"} +# # # assert c.preprocessing_dict == {"abc": "abc"} + + +# def test_model_config(): # with pytest.raises(ValidationError): -# training_config(**bad_inputs) - - -# def test_passes_validations(good_inputs, training_config): -# c = training_config(**good_inputs) +# NetworkConfig( +# **{ +# "architecture": "", +# "loss": "", +# "dropout": 1.1, +# } +# ) + + +# def test_transferlearning_config(): +# c = TransferLearningConfig(transfer_path=False) +# assert c.transfer_path is None + + +# def test_transforms_config(): +# c = TransformsConfig(data_augmentation=False) +# assert c.data_augmentation == () +# c = TransformsConfig(data_augmentation=["Noise"]) +# assert c.data_augmentation == ("Noise",) + + +# # Global tests on the TrainingConfig class # +# @pytest.fixture +# def dummy_arguments(caps_example): +# args = { +# "caps_directory": caps_example, +# "preprocessing_json": "preprocessing.json", +# "tsv_path": "", +# "maps_dir": "", +# "gpu": False, +# "architecture": "", +# "loss": "", +# "selection_metrics": (), +# } +# return args + + +# @pytest.fixture +# def training_config(): +# from pydantic import computed_field + +# from clinicadl.trainer.config.train import TrainConfig + +# class TrainingConfig(TrainConfig): +# @computed_field +# @property +# def network_task(self) -> str: +# return "" + +# return TrainingConfig + + +# @pytest.fixture( +# params=[ +# {"gpu": "abc"}, +# {"n_splits": -1}, +# {"data_augmentation": ("abc",)}, +# {"diagnoses": "AD"}, +# {"batch_size": 0}, +# {"size_reduction_factor": 1}, +# {"split": [-1]}, +# {"min_delta": -0.01}, +# ] +# ) +# def bad_inputs(request, dummy_arguments): +# return {**dummy_arguments, **request.param} + + +# @pytest.fixture +# def good_inputs(dummy_arguments): +# options = { +# "gpu": False, +# "n_splits": 7, +# "data_augmentation": ("Smoothing",), +# "diagnoses": ("AD",), +# "batch_size": 1, +# "size_reduction_factor": 5, +# "learning_rate": 1e-1, +# "split": [0], +# "min_delta": 0.0, +# } +# return {**dummy_arguments, **options} + + +# # def test_fails_validations(bad_inputs, training_config): +# # with 
pytest.raises(ValidationError): +# # training_config(**bad_inputs) + + +# # def test_passes_validations(good_inputs, training_config): +# # c = training_config(**good_inputs) +# # assert not c.computational.gpu +# # assert c.split.n_splits == 7 +# # assert c.transforms.data_augmentation == ("Smoothing",) +# # # assert c.data.diagnoses == ("AD",) +# # assert c.dataloader.batch_size == 1 +# # assert c.transforms.size_reduction_factor == 5 +# # assert c.split.split == (0,) +# # assert c.early_stopping.min_delta == 0.0 +# # Test config manipulation # +# def test_assignment(dummy_arguments, training_config): +# c = training_config(**dummy_arguments) +# c.computational = {"gpu": False} # assert not c.computational.gpu -# assert c.split.n_splits == 7 -# assert c.transforms.data_augmentation == ("Smoothing",) -# # assert c.data.diagnoses == ("AD",) -# assert c.dataloader.batch_size == 1 -# assert c.transforms.size_reduction_factor == 5 -# assert c.split.split == (0,) -# assert c.early_stopping.min_delta == 0.0 -# Test config manipulation # -def test_assignment(dummy_arguments, training_config): - c = training_config(**dummy_arguments) - c.computational = {"gpu": False} - assert not c.computational.gpu diff --git a/tests/unittests/utils/test_config_utils.py b/tests/unittests/utils/test_config_utils.py index 1ed36b92e..de9559c4c 100644 --- a/tests/unittests/utils/test_config_utils.py +++ b/tests/unittests/utils/test_config_utils.py @@ -32,7 +32,7 @@ class ConfigTest(BaseModel): def test_get_default_from_config_class(): - from clinicadl.config.config_utils import get_default_from_config_class + from clinicadl.utils.config_utils import get_default_from_config_class test_config = ConfigTest() assert get_default_from_config_class("parameter_str", test_config) == "a string" @@ -72,7 +72,7 @@ def test_get_default_from_config_class(): def test_get_type_from_config_class(): - from clinicadl.config.config_utils import get_type_from_config_class + from clinicadl.utils.config_utils import get_type_from_config_class test_config = ConfigTest() assert get_type_from_config_class("parameter_str", test_config) == str
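Taken together, the new tests in `test_make_split.py` pin down the splitter workflow: `make_split` accepts either a TSV path or a dataframe, requires an explicit `output_dir` in the dataframe case (the test expects a ValueError without one), and writes `train_baseline.tsv` and `test_baseline.tsv` into the directory it returns. A minimal usage sketch (the `make_split` call mirrors `test_make_split_and_kfold_from_df` above; the dataframe columns and the commented `make_kfold` arguments are assumptions, not verified against the new API):

```python
from pathlib import Path

import pandas as pd

from clinicadl.splitter.make_splits import make_kfold, make_split

# Hypothetical subjects/sessions table; real callers pass a TSV path or a
# CapsDataset dataframe, as the tests do (column names are an assumption).
df = pd.DataFrame(
    {
        "participant_id": [f"sub-{i:03d}" for i in range(10)],
        "session_id": ["ses-M000"] * 10,
    }
)

# With a dataframe input, output_dir must be given explicitly; n_test=1
# holds out one subject, as in the test above.
split_dir = make_split(df, output_dir=Path("from_df"), n_test=1)
print(split_dir / "train_baseline.tsv")
print(split_dir / "test_baseline.tsv")

# Presumably the training baseline can then be divided into folds
# (assumption: make_kfold's argument names here are illustrative only).
# fold_dir = make_kfold(split_dir / "train_baseline.tsv", n_splits=2)
```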