From 1a2135b3675f289eef92b538474fe6752f3c5033 Mon Sep 17 00:00:00 2001 From: camillebrianceau Date: Wed, 5 Jun 2024 18:45:15 +0200 Subject: [PATCH] first try after rebase and new config --- clinicadl/__init__.py | 2 - clinicadl/caps_dataset/caps_dataset_config.py | 62 ++---- clinicadl/caps_dataset/caps_dataset_utils.py | 68 +++++++ clinicadl/caps_dataset/data.py | 117 +++++++++--- clinicadl/caps_dataset/data_config.py | 16 +- clinicadl/caps_dataset/data_utils.py | 98 ---------- clinicadl/caps_dataset/dataloader_config.py | 9 - clinicadl/commandline/arguments.py | 3 + .../modules_options/preprocessing.py | 2 - .../pipelines/generate/artifacts/cli.py | 4 +- .../pipelines/generate/hypometabolic/cli.py | 6 +- .../pipelines/generate/trivial/cli.py | 4 +- .../prepare_data/prepare_data_cli.py | 178 ++++++++---------- .../prepare_data_from_bids_cli.py | 9 +- clinicadl/config/config/__init__.py | 17 -- clinicadl/generate/generate_config.py | 3 - clinicadl/predict/predict_manager.py | 33 +++- clinicadl/prepare_data/prepare_data.py | 99 +++++----- .../prepare_data_param/__init__.py | 7 - .../prepare_data_param/argument.py | 21 --- .../prepare_data/prepare_data_param/option.py | 104 ---------- .../prepare_data_param/option_patch.py | 30 --- .../prepare_data_param/option_roi.py | 46 ----- .../prepare_data_param/option_slice.py | 47 ----- clinicadl/prepare_data/prepare_data_utils.py | 60 +----- clinicadl/quality_check/t1_linear/utils.py | 2 +- clinicadl/train/resume.py | 2 +- clinicadl/transforms/transforms.py | 32 +--- clinicadl/utils/maps_manager/maps_manager.py | 2 +- .../utils/maps_manager/maps_manager_utils.py | 2 +- clinicadl/utils/meta_maps/getter.py | 2 +- tests/test_prepare_data.py | 90 +++++---- tests/test_resume.py | 2 +- 33 files changed, 423 insertions(+), 756 deletions(-) create mode 100644 clinicadl/caps_dataset/caps_dataset_utils.py delete mode 100644 clinicadl/prepare_data/prepare_data_param/__init__.py delete mode 100644 clinicadl/prepare_data/prepare_data_param/argument.py delete mode 100644 clinicadl/prepare_data/prepare_data_param/option.py delete mode 100644 clinicadl/prepare_data/prepare_data_param/option_patch.py delete mode 100644 clinicadl/prepare_data/prepare_data_param/option_roi.py delete mode 100644 clinicadl/prepare_data/prepare_data_param/option_slice.py diff --git a/clinicadl/__init__.py b/clinicadl/__init__.py index 01ea9db35..4a4fa2381 100644 --- a/clinicadl/__init__.py +++ b/clinicadl/__init__.py @@ -1,7 +1,5 @@ from importlib.metadata import version -from .utils.maps_manager import MapsManager - __all__ = ["__version__", "MapsManager"] __version__ = version("clinicadl") diff --git a/clinicadl/caps_dataset/caps_dataset_config.py b/clinicadl/caps_dataset/caps_dataset_config.py index 1ef9276d2..ab383e722 100644 --- a/clinicadl/caps_dataset/caps_dataset_config.py +++ b/clinicadl/caps_dataset/caps_dataset_config.py @@ -1,23 +1,20 @@ -import abc -from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Union -import pandas as pd -from pydantic import BaseModel, computed_field +from pydantic import BaseModel, ConfigDict -from clinicadl.caps_dataset.data_config import ConfigDict, DataConfig -from clinicadl.caps_dataset.data_utils import check_multi_cohort_tsv +from clinicadl.caps_dataset.data_config import DataConfig from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig -from clinicadl.config.config import modality +from clinicadl.config.config.modality import ( + CustomModalityConfig, + DTIModalityConfig, + FlairModalityConfig, + ModalityConfig, + PETModalityConfig, + T1ModalityConfig, +) from clinicadl.generate import generate_config as generate_type -from clinicadl.generate.generate_config import GenerateConfig from clinicadl.preprocessing import config as preprocessing from clinicadl.utils.enum import ExtractionMethod, GenerateType, Preprocessing -from clinicadl.utils.exceptions import ( - ClinicaDLArgumentError, - ClinicaDLTSVError, - DownloadError, -) def get_preprocessing(extract_method: ExtractionMethod): @@ -38,15 +35,15 @@ def get_modality(preprocessing: Preprocessing): preprocessing == Preprocessing.T1_EXTENSIVE or preprocessing == Preprocessing.T1_LINEAR ): - return modality.T1ModalityConfig + return T1ModalityConfig elif preprocessing == Preprocessing.PET_LINEAR: - return modality.PETModalityConfig + return PETModalityConfig elif preprocessing == Preprocessing.FLAIR_LINEAR: - return modality.FlairModalityConfig + return FlairModalityConfig elif preprocessing == Preprocessing.CUSTOM: - return modality.CustomModalityConfig + return CustomModalityConfig elif preprocessing == Preprocessing.DWI_DTI: - return modality.DTIModalityConfig + return DTIModalityConfig else: raise ValueError(f"Preprocessing {preprocessing.value} is not implemented.") @@ -69,7 +66,8 @@ def get_generate(generate: Union[str, GenerateType]): class CapsDatasetBase(BaseModel): data: DataConfig - modality: modality.ModalityConfig + dataloader: DataLoaderConfig + modality: ModalityConfig preprocessing: preprocessing.PreprocessingConfig # pydantic config @@ -86,29 +84,7 @@ def from_preprocessing_and_extraction_method( ): return cls( data=DataConfig(**kwargs), + dataloader=DataLoaderConfig(**kwargs), modality=get_modality(Preprocessing(preprocessing_type))(**kwargs), preprocessing=get_preprocessing(ExtractionMethod(extraction))(**kwargs), ) - - -# def create_caps_dataset_config( -# preprocessing: Union[str, Preprocessing], extract: Union[str, ExtractionMethod] -# ): -# try: -# preprocessing_type = Preprocessing(preprocessing) -# except ClinicaDLArgumentError: -# print("Invalid preprocessing configuration") - -# try: -# extract_method = ExtractionMethod(extract) -# except ClinicaDLArgumentError: -# print("Invalid preprocessing configuration") - -# class CapsDatasetConfig(CapsDatasetBase): -# modality: get_modality(preprocessing_type) -# preprocessing: get_preprocessing(extract_method) - -# def __init__(self, **kwargs): -# super().__init__(data=kwargs, modality=kwargs, preprocessing=kwargs) - -# return CapsDatasetConfig diff --git a/clinicadl/caps_dataset/caps_dataset_utils.py b/clinicadl/caps_dataset/caps_dataset_utils.py new file mode 100644 index 000000000..275bafaa1 --- /dev/null +++ b/clinicadl/caps_dataset/caps_dataset_utils.py @@ -0,0 +1,68 @@ +from pathlib import Path +from typing import Dict, Optional, Tuple, Union + +from pydantic import BaseModel, ConfigDict + +from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.utils.enum import LinearModality, Preprocessing + + +def compute_folder_and_file_type( + config: CapsDatasetConfig, from_bids: Optional[Path] = None +) -> Tuple[str, Dict[str, str]]: + from clinicadl.utils.clinica_utils import ( + bids_nii, + dwi_dti, + linear_nii, + pet_linear_nii, + ) + + preprocessing = Preprocessing( + config.preprocessing.preprocessing + ) # replace("-", "_") + if from_bids is not None: + if preprocessing == Preprocessing.CUSTOM: + mod_subfolder = Preprocessing.CUSTOM.value + file_type = { + "pattern": f"*{config.modality.custom_suffix}", + "description": "Custom suffix", + } + else: + mod_subfolder = preprocessing + file_type = bids_nii(preprocessing) + + elif preprocessing not in Preprocessing: + raise NotImplementedError( + f"Extraction of preprocessing {config.preprocessing.preprocessing.value} is not implemented from CAPS directory." + ) + else: + mod_subfolder = preprocessing.value.replace("-", "_") + if preprocessing == Preprocessing.T1_LINEAR: + file_type = linear_nii( + LinearModality.T1W, config.preprocessing.use_uncropped_image + ) + + elif preprocessing == Preprocessing.FLAIR_LINEAR: + file_type = linear_nii( + LinearModality.FLAIR, config.preprocessing.use_uncropped_image + ) + + elif preprocessing == Preprocessing.PET_LINEAR: + file_type = pet_linear_nii( + config.modality.tracer, + config.modality.suvr_reference_region, + config.preprocessing.use_uncropped_image, + ) + elif preprocessing == Preprocessing.DWI_DTI: + file_type = dwi_dti( + config.modality.dti_measure, + config.modality.dti_space, + ) + elif preprocessing == Preprocessing.CUSTOM: + file_type = { + "pattern": f"*{config.modality.custom_suffix}", + "description": "Custom suffix", + } + # custom_suffix["use_uncropped_image"] = None + + return mod_subfolder, file_type diff --git a/clinicadl/caps_dataset/data.py b/clinicadl/caps_dataset/data.py index 5ca88fa2e..d147e1c61 100644 --- a/clinicadl/caps_dataset/data.py +++ b/clinicadl/caps_dataset/data.py @@ -10,6 +10,7 @@ import torch from torch.utils.data import Dataset +from clinicadl.caps_dataset.caps_dataset_utils import compute_folder_and_file_type from clinicadl.prepare_data.prepare_data_config import ( PrepareDataConfig, PrepareDataImageConfig, @@ -19,7 +20,6 @@ ) from clinicadl.prepare_data.prepare_data_utils import ( compute_discarded_slices, - compute_folder_and_file_type, extract_patch_path, extract_patch_tensor, extract_roi_path, @@ -142,30 +142,6 @@ def domain_fn(self, target: Union[str, float, int]) -> Union[float, int]: def __len__(self) -> int: return len(self.df) * self.elem_per_image - @staticmethod - def create_caps_dict(caps_directory: Path, multi_cohort: bool) -> Dict[str, Path]: - from clinicadl.utils.clinica_utils import check_caps_folder - - if multi_cohort: - if not caps_directory.suffix == ".tsv": - raise ClinicaDLArgumentError( - "If multi_cohort is True, the CAPS_DIRECTORY argument should be a path to a TSV file." - ) - else: - caps_df = pd.read_csv(caps_directory, sep="\t") - check_multi_cohort_tsv(caps_df, "CAPS") - caps_dict = dict() - for idx in range(len(caps_df)): - cohort = caps_df.loc[idx, "cohort"] - caps_path = Path(caps_df.loc[idx, "path"]) - check_caps_folder(caps_path) - caps_dict[cohort] = caps_path - else: - check_caps_folder(caps_directory) - caps_dict = {"single": caps_directory} - - return caps_dict - def _get_image_path(self, participant: str, session: str, cohort: str) -> Path: """ Gets the path to the tensor image (*.pt) @@ -804,3 +780,94 @@ def num_elem_per_image(self): - self.discarded_slices[0] - self.discarded_slices[1] ) + + +def return_dataset( + input_dir: Path, + data_df: pd.DataFrame, + preprocessing_dict: Dict[str, Any], + all_transformations: Optional[Callable], + label: str = None, + label_code: Dict[str, int] = None, + train_transformations: Optional[Callable] = None, + cnn_index: int = None, + label_presence: bool = True, + multi_cohort: bool = False, +) -> CapsDataset: + """ + Return appropriate Dataset according to given options. + Args: + input_dir: path to a directory containing a CAPS structure. + data_df: List subjects, sessions and diagnoses. + preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. + train_transformations: Optional transform to be applied during training only. + all_transformations: Optional transform to be applied during training and evaluation. + label: Name of the column in data_df containing the label. + label_code: label code that links the output node number to label value. + cnn_index: Index of the CNN in a multi-CNN paradigm (optional). + label_presence: If True the diagnosis will be extracted from the given DataFrame. + multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. + + Returns: + the corresponding dataset. + """ + if cnn_index is not None and preprocessing_dict["mode"] == "image": + raise NotImplementedError( + f"Multi-CNN is not implemented for {preprocessing_dict['mode']} mode." + ) + + if preprocessing_dict["mode"] == "image": + return CapsDatasetImage( + input_dir, + data_df, + preprocessing_dict, + train_transformations=train_transformations, + all_transformations=all_transformations, + label_presence=label_presence, + label=label, + label_code=label_code, + multi_cohort=multi_cohort, + ) + elif preprocessing_dict["mode"] == "patch": + return CapsDatasetPatch( + input_dir, + data_df, + preprocessing_dict, + train_transformations=train_transformations, + all_transformations=all_transformations, + patch_index=cnn_index, + label_presence=label_presence, + label=label, + label_code=label_code, + multi_cohort=multi_cohort, + ) + elif preprocessing_dict["mode"] == "roi": + return CapsDatasetRoi( + input_dir, + data_df, + preprocessing_dict, + train_transformations=train_transformations, + all_transformations=all_transformations, + roi_index=cnn_index, + label_presence=label_presence, + label=label, + label_code=label_code, + multi_cohort=multi_cohort, + ) + elif preprocessing_dict["mode"] == "slice": + return CapsDatasetSlice( + input_dir, + data_df, + preprocessing_dict, + train_transformations=train_transformations, + all_transformations=all_transformations, + slice_index=cnn_index, + label_presence=label_presence, + label=label, + label_code=label_code, + multi_cohort=multi_cohort, + ) + else: + raise NotImplementedError( + f"Mode {preprocessing_dict['mode']} is not implemented." + ) diff --git a/clinicadl/caps_dataset/data_config.py b/clinicadl/caps_dataset/data_config.py index a5b4daaa7..ce10027c7 100644 --- a/clinicadl/caps_dataset/data_config.py +++ b/clinicadl/caps_dataset/data_config.py @@ -1,4 +1,3 @@ -import tarfile from logging import getLogger from pathlib import Path from typing import Any, Dict, Optional, Tuple, Union @@ -8,19 +7,11 @@ from clinicadl.caps_dataset.data_utils import check_multi_cohort_tsv, load_data_test from clinicadl.preprocessing.preprocessing import read_preprocessing -from clinicadl.utils.clinica_utils import ( - RemoteFileStructure, - clinicadl_file_reader, - fetch_file, -) -from clinicadl.utils.enum import MaskChecksum, Mode, Pathology +from clinicadl.utils.enum import Mode from clinicadl.utils.exceptions import ( ClinicaDLArgumentError, ClinicaDLTSVError, - DownloadError, ) -from clinicadl.utils.maps_manager.maps_manager import MapsManager -from clinicadl.utils.read_utils import get_mask_checksum_and_filename logger = getLogger("clinicadl.data_config") @@ -52,11 +43,6 @@ def validator_diagnoses(cls, v): return tuple(v) return v # TODO : check if columns are in tsv - def adapt_data_with_maps_manager_info(self, maps_manager: MapsManager): - # TEMPORARY - if self.diagnoses is None or len(self.diagnoses) == 0: - self.diagnoses = maps_manager.diagnoses - def create_groupe_df(self): group_df = None if self.data_tsv is not None and self.data_tsv.is_file(): diff --git a/clinicadl/caps_dataset/data_utils.py b/clinicadl/caps_dataset/data_utils.py index f05d09a69..08a49cd9f 100644 --- a/clinicadl/caps_dataset/data_utils.py +++ b/clinicadl/caps_dataset/data_utils.py @@ -7,13 +7,6 @@ import pandas as pd -from clinicadl.caps_dataset.data import ( - CapsDataset, - CapsDatasetImage, - CapsDatasetPatch, - CapsDatasetRoi, - CapsDatasetSlice, -) from clinicadl.utils.exceptions import ( ClinicaDLArgumentError, ClinicaDLConfigurationError, @@ -23,97 +16,6 @@ logger = getLogger("clinicadl") -def return_dataset( - input_dir: Path, - data_df: pd.DataFrame, - preprocessing_dict: Dict[str, Any], - all_transformations: Optional[Callable], - label: str = None, - label_code: Dict[str, int] = None, - train_transformations: Optional[Callable] = None, - cnn_index: int = None, - label_presence: bool = True, - multi_cohort: bool = False, -) -> CapsDataset: - """ - Return appropriate Dataset according to given options. - Args: - input_dir: path to a directory containing a CAPS structure. - data_df: List subjects, sessions and diagnoses. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - train_transformations: Optional transform to be applied during training only. - all_transformations: Optional transform to be applied during training and evaluation. - label: Name of the column in data_df containing the label. - label_code: label code that links the output node number to label value. - cnn_index: Index of the CNN in a multi-CNN paradigm (optional). - label_presence: If True the diagnosis will be extracted from the given DataFrame. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. - - Returns: - the corresponding dataset. - """ - if cnn_index is not None and preprocessing_dict["mode"] == "image": - raise NotImplementedError( - f"Multi-CNN is not implemented for {preprocessing_dict['mode']} mode." - ) - - if preprocessing_dict["mode"] == "image": - return CapsDatasetImage( - input_dir, - data_df, - preprocessing_dict, - train_transformations=train_transformations, - all_transformations=all_transformations, - label_presence=label_presence, - label=label, - label_code=label_code, - multi_cohort=multi_cohort, - ) - elif preprocessing_dict["mode"] == "patch": - return CapsDatasetPatch( - input_dir, - data_df, - preprocessing_dict, - train_transformations=train_transformations, - all_transformations=all_transformations, - patch_index=cnn_index, - label_presence=label_presence, - label=label, - label_code=label_code, - multi_cohort=multi_cohort, - ) - elif preprocessing_dict["mode"] == "roi": - return CapsDatasetRoi( - input_dir, - data_df, - preprocessing_dict, - train_transformations=train_transformations, - all_transformations=all_transformations, - roi_index=cnn_index, - label_presence=label_presence, - label=label, - label_code=label_code, - multi_cohort=multi_cohort, - ) - elif preprocessing_dict["mode"] == "slice": - return CapsDatasetSlice( - input_dir, - data_df, - preprocessing_dict, - train_transformations=train_transformations, - all_transformations=all_transformations, - slice_index=cnn_index, - label_presence=label_presence, - label=label, - label_code=label_code, - multi_cohort=multi_cohort, - ) - else: - raise NotImplementedError( - f"Mode {preprocessing_dict['mode']} is not implemented." - ) - - ################################ # TSV files loaders ################################ diff --git a/clinicadl/caps_dataset/dataloader_config.py b/clinicadl/caps_dataset/dataloader_config.py index e2c02afa1..cc01ba9a9 100644 --- a/clinicadl/caps_dataset/dataloader_config.py +++ b/clinicadl/caps_dataset/dataloader_config.py @@ -4,7 +4,6 @@ from pydantic.types import PositiveInt from clinicadl.utils.enum import Sampler -from clinicadl.utils.maps_manager.maps_manager import MapsManager logger = getLogger("clinicadl.dataloader_config") @@ -17,11 +16,3 @@ class DataLoaderConfig(BaseModel): # TODO : put in data/splitter module sampler: Sampler = Sampler.RANDOM # pydantic config model_config = ConfigDict(validate_assignment=True) - - def adapt_dataloader_with_maps_manager_info(self, maps_manager: MapsManager): - # TEMPORARY - if not self.batch_size: - self.batch_size = maps_manager.batch_size - - if not self.n_proc: - self.n_proc = maps_manager.n_proc diff --git a/clinicadl/commandline/arguments.py b/clinicadl/commandline/arguments.py index 00664eb46..2d85b1fb0 100644 --- a/clinicadl/commandline/arguments.py +++ b/clinicadl/commandline/arguments.py @@ -51,3 +51,6 @@ config_file = click.argument( "config_file", type=click.Path(exists=True, path_type=Path) ) +preprocessing = click.argument( + "preprocessing", type=click.Choice(["t1", "pet", "flair", "dwi", "custom"]) +) diff --git a/clinicadl/commandline/modules_options/preprocessing.py b/clinicadl/commandline/modules_options/preprocessing.py index a33393bba..d7840ed62 100644 --- a/clinicadl/commandline/modules_options/preprocessing.py +++ b/clinicadl/commandline/modules_options/preprocessing.py @@ -92,8 +92,6 @@ the end of the MRI volume. If only one argument is given, it will be used for both sides.""", ) - - roi_list = click.option( "--roi_list", type=get_type("roi_list", PreprocessingROIConfig), diff --git a/clinicadl/commandline/pipelines/generate/artifacts/cli.py b/clinicadl/commandline/pipelines/generate/artifacts/cli.py index 4711a2026..d5fddf867 100644 --- a/clinicadl/commandline/pipelines/generate/artifacts/cli.py +++ b/clinicadl/commandline/pipelines/generate/artifacts/cli.py @@ -47,7 +47,7 @@ @artifacts.translation @artifacts.rotation @artifacts.gamma -def cli(generated_caps_directory, n_proc, **kwargs): +def cli(generated_caps_directory, **kwargs): """ Addition of artifacts (noise, motion or contrast) to brain images @@ -158,7 +158,7 @@ def create_artifacts_image(data_idx: int) -> pd.DataFrame: return row_df - results_df = Parallel(n_jobs=n_proc)( + results_df = Parallel(n_jobs=caps_config.dataloader.n_proc)( delayed(create_artifacts_image)(data_idx) for data_idx in range(len(data_df)) ) output_df = pd.DataFrame() diff --git a/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py b/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py index 4bfe095c9..d993247c4 100644 --- a/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py +++ b/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py @@ -42,7 +42,7 @@ @hypometabolic.sigma @hypometabolic.anomaly_degree @hypometabolic.pathology -def cli(generated_caps_directory, n_proc, **kwargs): +def cli(generated_caps_directory, **kwargs): """Generation of trivial dataset with addition of synthetic brain atrophy. CAPS_DIRECTORY is the CAPS folder from where input brain images will be loaded. GENERATED_CAPS_DIRECTORY is a CAPS folder where the trivial dataset will be saved. @@ -62,7 +62,7 @@ def cli(generated_caps_directory, n_proc, **kwargs): "caps_dir": caps_config.data.caps_directory, "preprocessing": caps_config.preprocessing.preprocessing.value, "n_subjects": caps_config.data.n_subjects, - "n_proc": n_proc, + "n_proc": caps_config.dataloader.n_proc, "pathology": generate_config.pathology.value, "anomaly_degree": generate_config.anomaly_degree, } @@ -139,7 +139,7 @@ def generate_hypometabolic_image( row_df = pd.DataFrame([row], columns=columns) return row_df - results_list = Parallel(n_jobs=n_proc)( + results_list = Parallel(n_jobs=caps_config.dataloader.n_proc)( delayed(generate_hypometabolic_image)(subject_id) for subject_id in range(caps_config.data.n_subjects) ) diff --git a/clinicadl/commandline/pipelines/generate/trivial/cli.py b/clinicadl/commandline/pipelines/generate/trivial/cli.py index b8e4aece6..67188a7bf 100644 --- a/clinicadl/commandline/pipelines/generate/trivial/cli.py +++ b/clinicadl/commandline/pipelines/generate/trivial/cli.py @@ -44,7 +44,7 @@ @modality.suvr_reference_region @trivial.atrophy_percent @data.mask_path -def cli(generated_caps_directory, n_proc, **kwargs): +def cli(generated_caps_directory, **kwargs): """Generation of a trivial dataset""" caps_config = CapsDatasetConfig.from_preprocessing_and_extraction_method( @@ -151,7 +151,7 @@ def create_trivial_image(subject_id: int) -> pd.DataFrame: return row_df - results_df = Parallel(n_jobs=n_proc)( + results_df = Parallel(n_jobs=caps_config.dataloader.n_proc)( delayed(create_trivial_image)(subject_id) for subject_id in range(2 * caps_config.data.n_subjects) ) diff --git a/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py b/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py index 540789d0e..57a266d33 100644 --- a/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py +++ b/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py @@ -3,7 +3,14 @@ import click -from clinicadl.prepare_data import prepare_data_param +from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.commandline import arguments +from clinicadl.commandline.modules_options import ( + data, + dataloader, + modality, + preprocessing, +) from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData from clinicadl.prepare_data.prepare_data_config import ( PrepareDataImageConfig, @@ -24,36 +31,28 @@ @click.command(name="image", no_args_is_help=True) -@prepare_data_param.argument.caps_directory -@prepare_data_param.argument.preprocessing -@prepare_data_param.option.n_proc -@prepare_data_param.option.tsv_file -@prepare_data_param.option.extract_json -@prepare_data_param.option.use_uncropped_image -@prepare_data_param.option.tracer -@prepare_data_param.option.suvr_reference_region -@prepare_data_param.option.custom_suffix -@prepare_data_param.option.dti_measure -@prepare_data_param.option.dti_space -def image_cli( - caps_directory: Path, - preprocessing: Preprocessing, - **kwargs, -): +@arguments.caps_directory +@arguments.preprocessing +@dataloader.n_proc +@data.participants_tsv +@preprocessing.extract_json +@preprocessing.use_uncropped_image +@modality.tracer +@modality.suvr_reference_region +@modality.custom_suffix +@modality.dti_measure +@modality.dti_space +def image_cli(**kwargs): """Extract image from nifti images. CAPS_DIRECTORY is the CAPS folder where nifti images are stored and tensor will be saved. PREPROCESSING [t1-linear|pet-linear|custom] is the clinica pipeline name used for image preprocessing. """ - image_config = PrepareDataImageConfig( - caps_directory=caps_directory, - preprocessing_cls=preprocessing, - tracer_cls=kwargs["tracer"], - suvr_reference_region_cls=kwargs["suvr_reference_region"], - dti_measure_cls=kwargs["dti_measure"], - dti_space_cls=kwargs["dti_space"], - save_features=True, + kwargs["save_features"] = True + image_config = CapsDatasetConfig.from_preprocessing_and_extraction_method( + extraction=ExtractionMethod.IMAGE, + preprocessing_type=kwargs["preprocessing"], **kwargs, ) @@ -61,21 +60,21 @@ def image_cli( @click.command(name="patch", no_args_is_help=True) -@prepare_data_param.argument.caps_directory -@prepare_data_param.argument.preprocessing -@prepare_data_param.option.n_proc -@prepare_data_param.option.save_features -@prepare_data_param.option.tsv_file -@prepare_data_param.option.extract_json -@prepare_data_param.option.use_uncropped_image -@prepare_data_param.option.tracer -@prepare_data_param.option.suvr_reference_region -@prepare_data_param.option.custom_suffix -@prepare_data_param.option.dti_measure -@prepare_data_param.option.dti_space -@prepare_data_param.option_patch.patch_size -@prepare_data_param.option_patch.stride_size -def patch_cli(caps_directory: Path, preprocessing: Preprocessing, **kwargs): +@arguments.caps_directory +@arguments.preprocessing +@dataloader.n_proc +@preprocessing.save_features +@data.participants_tsv +@preprocessing.extract_json +@preprocessing.use_uncropped_image +@modality.tracer +@modality.suvr_reference_region +@modality.custom_suffix +@modality.dti_measure +@modality.dti_space +@preprocessing.patch_size +@preprocessing.stride_size +def patch_cli(**kwargs): """Extract patch from nifti images. CAPS_DIRECTORY is the CAPS folder where nifti images are stored and tensor will be saved. @@ -83,13 +82,9 @@ def patch_cli(caps_directory: Path, preprocessing: Preprocessing, **kwargs): PREPROCESSING [t1-linear|pet-linear|custom] is the clinica pipeline name used for image preprocessing. """ - patch_config = PrepareDataPatchConfig( - caps_directory=caps_directory, - preprocessing_cls=preprocessing, - tracer_cls=kwargs["tracer"], - suvr_reference_region_cls=kwargs["suvr_reference_region"], - dti_measure_cls=kwargs["dti_measure"], - dti_space_cls=kwargs["dti_space"], + patch_config = CapsDatasetConfig.from_preprocessing_and_extraction_method( + extraction=ExtractionMethod.PATCH, + preprocessing_type=kwargs["preprocessing"], **kwargs, ) @@ -97,38 +92,31 @@ def patch_cli(caps_directory: Path, preprocessing: Preprocessing, **kwargs): @click.command(name="slice", no_args_is_help=True) -@prepare_data_param.argument.caps_directory -@prepare_data_param.argument.preprocessing -@prepare_data_param.option.n_proc -@prepare_data_param.option.save_features -@prepare_data_param.option.tsv_file -@prepare_data_param.option.extract_json -@prepare_data_param.option.use_uncropped_image -@prepare_data_param.option.tracer -@prepare_data_param.option.suvr_reference_region -@prepare_data_param.option.custom_suffix -@prepare_data_param.option.dti_measure -@prepare_data_param.option.dti_space -@prepare_data_param.option_slice.slice_method -@prepare_data_param.option_slice.slice_direction -@prepare_data_param.option_slice.discarded_slice -def slice_cli(caps_directory: Path, preprocessing: Preprocessing, **kwargs): +@arguments.caps_directory +@arguments.preprocessing +@dataloader.n_proc +@preprocessing.save_features +@data.participants_tsv +@preprocessing.extract_json +@preprocessing.use_uncropped_image +@modality.tracer +@modality.suvr_reference_region +@modality.custom_suffix +@modality.dti_measure +@modality.dti_space +@preprocessing.slice_mode +@preprocessing.slice_direction +@preprocessing.discarded_slices +def slice_cli(**kwargs): """Extract slice from nifti images. CAPS_DIRECTORY is the CAPS folder where nifti images are stored and tensor will be saved. PREPROCESSING [t1-linear|pet-linear|custom] is the clinica pipeline name used for image preprocessing. """ - - slice_config = PrepareDataSliceConfig( - caps_directory=caps_directory, - preprocessing_cls=preprocessing, - tracer_cls=kwargs["tracer"], - suvr_reference_region_cls=kwargs["suvr_reference_region"], - dti_measure_cls=kwargs["dti_measure"], - dti_space_cls=kwargs["dti_space"], - slice_direction_cls=kwargs["slice_direction"], - slice_mode_cls=kwargs["slice_mode"], + slice_config = CapsDatasetConfig.from_preprocessing_and_extraction_method( + extraction=ExtractionMethod.SLICE, + preprocessing_type=kwargs["preprocessing"], **kwargs, ) @@ -136,23 +124,23 @@ def slice_cli(caps_directory: Path, preprocessing: Preprocessing, **kwargs): @click.command(name="roi", no_args_is_help=True) -@prepare_data_param.argument.caps_directory -@prepare_data_param.argument.preprocessing -@prepare_data_param.option.n_proc -@prepare_data_param.option.save_features -@prepare_data_param.option.tsv_file -@prepare_data_param.option.extract_json -@prepare_data_param.option.use_uncropped_image -@prepare_data_param.option.tracer -@prepare_data_param.option.suvr_reference_region -@prepare_data_param.option.custom_suffix -@prepare_data_param.option.dti_measure -@prepare_data_param.option.dti_space -@prepare_data_param.option_roi.roi_list -@prepare_data_param.option_roi.roi_uncrop_output -@prepare_data_param.option_roi.roi_custom_template -@prepare_data_param.option_roi.roi_custom_mask_pattern -def roi_cli(caps_directory: Path, preprocessing: Preprocessing, **kwargs): +@arguments.caps_directory +@arguments.preprocessing +@dataloader.n_proc +@preprocessing.save_features +@data.participants_tsv +@preprocessing.extract_json +@preprocessing.use_uncropped_image +@modality.tracer +@modality.suvr_reference_region +@modality.custom_suffix +@modality.dti_measure +@modality.dti_space +@preprocessing.roi_list +@preprocessing.roi_uncrop_output +@preprocessing.roi_custom_template +@preprocessing.roi_custom_mask_pattern +def roi_cli(**kwargs): """Extract roi from nifti images. CAPS_DIRECTORY is the CAPS folder where nifti images are stored and tensor will be saved. @@ -160,13 +148,9 @@ def roi_cli(caps_directory: Path, preprocessing: Preprocessing, **kwargs): PREPROCESSING [t1-linear|pet-linear|custom] is the clinica pipeline name used for image preprocessing. """ - roi_config = PrepareDataROIConfig( - caps_directory=caps_directory, - preprocessing_cls=preprocessing, - tracer_cls=kwargs["tracer"], - suvr_reference_region_cls=kwargs["suvr_reference_region"], - dti_measure_cls=kwargs["dti_measure"], - dti_space_cls=kwargs["dti_space"], + roi_config = CapsDatasetConfig.from_preprocessing_and_extraction_method( + extraction=ExtractionMethod.ROI, + preprocessing_type=kwargs["preprocessing"], **kwargs, ) diff --git a/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py b/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py index 472a4bca8..9c94e90ed 100644 --- a/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py +++ b/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py @@ -3,8 +3,13 @@ import click -from clinicadl.config import arguments -from clinicadl.config.options import data, dataloader, modality, preprocessing +from clinicadl.commandline import arguments +from clinicadl.commandline.modules_options import ( + data, + dataloader, + modality, + preprocessing, +) from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData diff --git a/clinicadl/config/config/__init__.py b/clinicadl/config/config/__init__.py index 3d83fede8..e69de29bb 100644 --- a/clinicadl/config/config/__init__.py +++ b/clinicadl/config/config/__init__.py @@ -1,17 +0,0 @@ -from ...network.config import NetworkConfig -from ...transforms.config import TransformsConfig -from .computational import ComputationalConfig -from .cross_validation import CrossValidationConfig -from .early_stopping import EarlyStoppingConfig -from .lr_scheduler import LRschedulerConfig -from .maps_manager import MapsManagerConfig -from .modality import ( - CustomModalityConfig, - DTIModalityConfig, - ModalityConfig, - PETModalityConfig, -) -from .reproducibility import ReproducibilityConfig -from .ssda import SSDAConfig -from .transfer_learning import TransferLearningConfig -from .validation import ValidationConfig diff --git a/clinicadl/generate/generate_config.py b/clinicadl/generate/generate_config.py index e1ce2c500..a276d5a93 100644 --- a/clinicadl/generate/generate_config.py +++ b/clinicadl/generate/generate_config.py @@ -14,9 +14,6 @@ field_validator, ) -from clinicadl.caps_dataset.data_config import DataConfig as DataBaseConfig -from clinicadl.config.config import ModalityConfig -from clinicadl.preprocessing.config import PreprocessingConfig from clinicadl.utils.clinica_utils import ( RemoteFileStructure, clinicadl_file_reader, diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predict/predict_manager.py index eddd2bea8..0b0368ac2 100644 --- a/clinicadl/predict/predict_manager.py +++ b/clinicadl/predict/predict_manager.py @@ -93,8 +93,21 @@ def predict( assert isinstance(self._config, PredictConfig) self._config.check_output_saving_nifti(self.maps_manager.network_task) - self._config.adapt_data_with_maps_manager_info(self.maps_manager) - self._config.adapt_dataloader_with_maps_manager_info(self.maps_manager) + self._config.diagnoses = ( + self.maps_manager.diagnoses + if self._config.diagnoses is None or len(self._config.diagnoses) == 0 + else self._config.diagnoses + ) + + self._config.batch_size = ( + self.maps_manager.batch_size + if not self._config.batch_size + else self._config.batch_size + ) + self._config.n_proc = ( + self.maps_manager.n_proc if not self._config.n_proc else self._config.n_proc + ) + self._config.adapt_cross_val_with_maps_manager_info(self.maps_manager) self._config.check_output_saving_tensor(self.maps_manager.network_task) @@ -636,8 +649,20 @@ def interpret(self): """ assert isinstance(self._config, InterpretConfig) - self._config.adapt_data_with_maps_manager_info(self.maps_manager) - self._config.adapt_dataloader_with_maps_manager_info(self.maps_manager) + self._config.diagnoses = ( + self.maps_manager.diagnoses + if self._config.diagnoses is None or len(self._config.diagnoses) == 0 + else self._config.diagnoses + ) + self._config.batch_size = ( + self.maps_manager.batch_size + if not self._config.batch_size + else self._config.batch_size + ) + self._config.n_proc = ( + self.maps_manager.n_proc if not self._config.n_proc else self._config.n_proc + ) + self._config.adapt_cross_val_with_maps_manager_info(self.maps_manager) if self.maps_manager.multi_network: diff --git a/clinicadl/prepare_data/prepare_data.py b/clinicadl/prepare_data/prepare_data.py index 74f07f2f0..d8f88e044 100644 --- a/clinicadl/prepare_data/prepare_data.py +++ b/clinicadl/prepare_data/prepare_data.py @@ -5,11 +5,14 @@ from joblib import Parallel, delayed from torch import save as save_tensor -from clinicadl.prepare_data.prepare_data_config import ( - PrepareDataConfig, - PrepareDataPatchConfig, - PrepareDataROIConfig, - PrepareDataSliceConfig, +from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.caps_dataset.caps_dataset_utils import compute_folder_and_file_type +from clinicadl.preprocessing.config import ( + PreprocessingConfig, + PreprocessingImageConfig, + PreprocessingPatchConfig, + PreprocessingROIConfig, + PreprocessingSliceConfig, ) from clinicadl.preprocessing.preprocessing import write_preprocessing from clinicadl.utils.clinica_utils import ( @@ -21,11 +24,11 @@ from clinicadl.utils.enum import ExtractionMethod, Pattern, Preprocessing, Template from clinicadl.utils.exceptions import ClinicaDLArgumentError -from .prepare_data_utils import check_mask_list, compute_folder_and_file_type +from .prepare_data_utils import check_mask_list def DeepLearningPrepareData( - config: PrepareDataConfig, from_bids: Optional[Path] = None + config: CapsDatasetConfig, from_bids: Optional[Path] = None ): logger = getLogger("clinicadl.prepare_data") # Get subject and session list @@ -37,25 +40,25 @@ def DeepLearningPrepareData( logger.debug(f"BIDS directory: {input_directory}.") is_bids_dir = True else: - input_directory = config.caps_directory + input_directory = config.data.caps_directory check_caps_folder(input_directory) logger.debug(f"CAPS directory: {input_directory}.") is_bids_dir = False subjects, sessions = get_subject_session_list( - input_directory, config.tsv_file, is_bids_dir, False, None + input_directory, config.data.data_tsv, is_bids_dir, False, None ) - if config.save_features: + if config.preprocessing.save_features: logger.info( - f"{config.extract_method.value}s will be extracted in Pytorch tensor from {len(sessions)} images." + f"{config.preprocessing.extract_method.value}s will be extracted in Pytorch tensor from {len(sessions)} images." ) else: logger.info( f"Images will be extracted in Pytorch tensor from {len(sessions)} images." ) logger.info( - f"Information for {config.extract_method.value} will be saved in output JSON file and will be used " + f"Information for {config.preprocessing.extract_method.value} will be saved in output JSON file and will be used " f"during training for on-the-fly extraction." ) logger.debug(f"List of subjects: \n{subjects}.") @@ -80,7 +83,7 @@ def write_output_imgs(output_mode, container, subfolder): # Write the extracted tensor on a .pt file for filename, tensor in output_mode: output_file_dir = ( - config.caps_directory + config.data.caps_directory / container / "deeplearning_prepare_data" / subfolder @@ -91,7 +94,10 @@ def write_output_imgs(output_mode, container, subfolder): save_tensor(tensor, output_file) logger.debug(f"Output tensor saved at {output_file}") - if config.extract_method == ExtractionMethod.IMAGE or not config.save_features: + if ( + config.preprocessing.extract_method == ExtractionMethod.IMAGE + or not config.preprocessing.save_features + ): def prepare_image(file): from .prepare_data_utils import extract_images @@ -103,117 +109,122 @@ def prepare_image(file): logger.debug("Image extracted.") write_output_imgs(output_mode, container, subfolder) - Parallel(n_jobs=config.n_proc)( + Parallel(n_jobs=config.dataloader.n_proc)( delayed(prepare_image)(file) for file in input_files ) - elif config.save_features: - if config.extract_method == ExtractionMethod.SLICE: - assert isinstance(config, PrepareDataSliceConfig) + elif config.preprocessing.save_features: + if config.preprocessing.extract_method == ExtractionMethod.SLICE: + assert isinstance(config.preprocessing, PreprocessingSliceConfig) def prepare_slice(file): from .prepare_data_utils import extract_slices + assert isinstance(config.preprocessing, PreprocessingSliceConfig) logger.debug(f" Processing of {file}.") container = container_from_filename(file) subfolder = "slice_based" output_mode = extract_slices( Path(file), - slice_direction=config.slice_direction, - slice_mode=config.slice_mode, - discarded_slices=config.discarded_slices, + slice_direction=config.preprocessing.slice_direction, + slice_mode=config.preprocessing.slice_mode, + discarded_slices=config.preprocessing.discarded_slices, ) logger.debug(f" {len(output_mode)} slices extracted.") write_output_imgs(output_mode, container, subfolder) - Parallel(n_jobs=config.n_proc)( + Parallel(n_jobs=config.dataloader.n_proc)( delayed(prepare_slice)(file) for file in input_files ) - elif config.extract_method == ExtractionMethod.PATCH: - assert isinstance(config, PrepareDataPatchConfig) + elif config.preprocessing.extract_method == ExtractionMethod.PATCH: + assert isinstance(config.preprocessing, PreprocessingPatchConfig) def prepare_patch(file): from .prepare_data_utils import extract_patches + assert isinstance(config.preprocessing, PreprocessingPatchConfig) logger.debug(f" Processing of {file}.") container = container_from_filename(file) subfolder = "patch_based" output_mode = extract_patches( Path(file), - patch_size=config.patch_size, - stride_size=config.stride_size, + patch_size=config.preprocessing.patch_size, + stride_size=config.preprocessing.stride_size, ) logger.debug(f" {len(output_mode)} patches extracted.") write_output_imgs(output_mode, container, subfolder) - Parallel(n_jobs=config.n_proc)( + Parallel(n_jobs=config.dataloader.n_proc)( delayed(prepare_patch)(file) for file in input_files ) - elif config.extract_method == ExtractionMethod.ROI: - assert isinstance(config, PrepareDataROIConfig) + elif config.preprocessing.extract_method == ExtractionMethod.ROI: + assert isinstance(config.preprocessing, PreprocessingROIConfig) def prepare_roi(file): from .prepare_data_utils import extract_roi + assert isinstance(config.preprocessing, PreprocessingROIConfig) logger.debug(f" Processing of {file}.") container = container_from_filename(file) subfolder = "roi_based" if config.preprocessing == Preprocessing.CUSTOM: - if not config.roi_custom_template: + if not config.preprocessing.roi_custom_template: raise ClinicaDLArgumentError( "A custom template must be defined when the modality is set to custom." ) - roi_template = config.roi_custom_template - roi_mask_pattern = config.roi_custom_mask_pattern + roi_template = config.preprocessing.roi_custom_template + roi_mask_pattern = config.preprocessing.roi_custom_mask_pattern else: - if config.preprocessing == Preprocessing.T1_LINEAR: + if config.preprocessing.preprocessing == Preprocessing.T1_LINEAR: roi_template = Template.T1_LINEAR roi_mask_pattern = Pattern.T1_LINEAR - elif config.preprocessing == Preprocessing.PET_LINEAR: + elif config.preprocessing.preprocessing == Preprocessing.PET_LINEAR: roi_template = Template.PET_LINEAR roi_mask_pattern = Pattern.PET_LINEAR - elif config.preprocessing == Preprocessing.FLAIR_LINEAR: + elif ( + config.preprocessing.preprocessing == Preprocessing.FLAIR_LINEAR + ): roi_template = Template.FLAIR_LINEAR roi_mask_pattern = Pattern.FLAIR_LINEAR masks_location = input_directory / "masks" / f"tpl-{roi_template}" - if len(config.roi_list) == 0: + if len(config.preprocessing.roi_list) == 0: raise ClinicaDLArgumentError( "A list of regions of interest must be given." ) else: check_mask_list( masks_location, - config.roi_list, + config.preprocessing.roi_list, roi_mask_pattern, - config.use_uncropped_image, + config.preprocessing.use_uncropped_image, ) output_mode = extract_roi( Path(file), masks_location=masks_location, mask_pattern=roi_mask_pattern, - cropped_input=not config.use_uncropped_image, - roi_names=config.roi_list, - uncrop_output=config.roi_uncrop_output, + cropped_input=not config.preprocessing.use_uncropped_image, + roi_names=config.preprocessing.roi_list, + uncrop_output=config.preprocessing.roi_uncrop_output, ) logger.debug("ROI extracted.") write_output_imgs(output_mode, container, subfolder) - Parallel(n_jobs=config.n_proc)( + Parallel(n_jobs=config.dataloader.n_proc)( delayed(prepare_roi)(file) for file in input_files ) else: raise NotImplementedError( - f"Extraction is not implemented for mode {config.extract_method.value}." + f"Extraction is not implemented for mode {config.preprocessing.extract_method.value}." ) # Save parameters dictionary preprocessing_json_path = write_preprocessing( - config.model_dump(), config.caps_directory + config.preprocessing.model_dump(), config.data.caps_directory ) logger.info(f"Preprocessing JSON saved at {preprocessing_json_path}.") diff --git a/clinicadl/prepare_data/prepare_data_param/__init__.py b/clinicadl/prepare_data/prepare_data_param/__init__.py deleted file mode 100644 index 12b35b5d1..000000000 --- a/clinicadl/prepare_data/prepare_data_param/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from . import ( - argument, - option, - option_patch, - option_roi, - option_slice, -) diff --git a/clinicadl/prepare_data/prepare_data_param/argument.py b/clinicadl/prepare_data/prepare_data_param/argument.py deleted file mode 100644 index bce68821e..000000000 --- a/clinicadl/prepare_data/prepare_data_param/argument.py +++ /dev/null @@ -1,21 +0,0 @@ -from pathlib import Path - -import click - -from clinicadl.prepare_data.prepare_data_config import PrepareDataConfig -from clinicadl.utils.enum import ( - Preprocessing, - SUVRReferenceRegions, - Tracer, -) - -config = PrepareDataConfig.model_fields - -caps_directory = click.argument( - "caps_directory", - type=config["caps_directory"].annotation, -) -preprocessing = click.argument( - "preprocessing", - type=click.Choice(Preprocessing), -) diff --git a/clinicadl/prepare_data/prepare_data_param/option.py b/clinicadl/prepare_data/prepare_data_param/option.py deleted file mode 100644 index 51ea70c7a..000000000 --- a/clinicadl/prepare_data/prepare_data_param/option.py +++ /dev/null @@ -1,104 +0,0 @@ -from pathlib import Path -from typing import get_args - -import click - -from clinicadl.prepare_data.prepare_data_config import PrepareDataConfig -from clinicadl.utils.enum import ( - DTIMeasure, - DTISpace, - Preprocessing, - SUVRReferenceRegions, - Tracer, -) - -config = PrepareDataConfig.model_fields - -n_proc = click.option( - "-np", - "--n_proc", - type=config["n_proc"].annotation, - default=config["n_proc"].default, - show_default=True, - help="Number of cores used during the task.", -) -tsv_file = click.option( - "--participants_tsv", - type=get_args(config["tsv_file"].annotation)[0], - default=config["tsv_file"].default, - help="Path to a TSV file including a list of participants/sessions.", - show_default=True, -) -extract_json = click.option( - "-ej", - "--extract_json", - type=get_args(config["extract_json"].annotation)[0], - default=config["extract_json"].default, - help="Name of the JSON file created to describe the tensor extraction. " - "Default will use format extract_{time_stamp}.json", -) -use_uncropped_image = click.option( - "-uui", - "--use_uncropped_image", - is_flag=True, - help="Use the uncropped image instead of the cropped image generated by t1-linear or pet-linear.", - show_default=True, -) -tracer = click.option( - "--tracer", - type=click.Choice(Tracer), - default=config["tracer_cls"].default.value, - help=( - "Acquisition label if PREPROCESSING is `pet-linear`. " - "Name of the tracer used for the PET acquisition (trc-). " - "For instance it can be '18FFDG' for fluorodeoxyglucose or '18FAV45' for florbetapir." - ), - show_default=True, -) -suvr_reference_region = click.option( - "-suvr", - "--suvr_reference_region", - type=click.Choice(SUVRReferenceRegions), - default=config["suvr_reference_region_cls"].default.value, - help=( - "Regions used for normalization if PREPROCESSING is `pet-linear`. " - "Intensity normalization using the average PET uptake in reference regions resulting in a standardized uptake " - "value ratio (SUVR) map. It can be cerebellumPons or cerebellumPon2 (used for amyloid tracers) or pons or " - "pons2 (used for 18F-FDG tracers)." - ), - show_default=True, -) -custom_suffix = click.option( - "-cn", - "--custom_suffix", - type=config["custom_suffix"].annotation, - default=config["custom_suffix"].default, - help=( - "Suffix of output files if PREPROCESSING is `custom`. " - "Suffix to append to filenames, for instance " - "`graymatter_space-Ixi549Space_modulated-off_probability.nii.gz`, or " - "`segm-whitematter_probability.nii.gz`" - ), -) -dti_measure = click.option( - "--dti_measure", - "-dm", - type=click.Choice(DTIMeasure), - help="Possible DTI measures.", - default=config["dti_measure_cls"].default.value, - show_default=True, -) -dti_space = click.option( - "--dti_space", - "-ds", - type=click.Choice(DTISpace), - help="Possible DTI space.", - default=config["dti_space_cls"].default.value, - show_default=True, -) -save_features = click.option( - "--save_features", - is_flag=True, - help="""Extract the selected mode to save the tensor. By default, the pipeline only save images and the mode extraction - is done when images are loaded in the train.""", -) diff --git a/clinicadl/prepare_data/prepare_data_param/option_patch.py b/clinicadl/prepare_data/prepare_data_param/option_patch.py deleted file mode 100644 index 4e1c5ee5a..000000000 --- a/clinicadl/prepare_data/prepare_data_param/option_patch.py +++ /dev/null @@ -1,30 +0,0 @@ -from pathlib import Path -from typing import get_args - -import click - -from clinicadl.prepare_data.prepare_data_config import PrepareDataPatchConfig -from clinicadl.utils.enum import ( - DTIMeasure, - DTISpace, - Preprocessing, - SUVRReferenceRegions, - Tracer, -) - -config = PrepareDataPatchConfig.model_fields - -patch_size = click.option( - "-ps", - "--patch_size", - default=50, - show_default=True, - help="Patch size.", -) -stride_size = click.option( - "-ss", - "--stride_size", - default=50, - show_default=True, - help="Stride size.", -) diff --git a/clinicadl/prepare_data/prepare_data_param/option_roi.py b/clinicadl/prepare_data/prepare_data_param/option_roi.py deleted file mode 100644 index 58c6d575e..000000000 --- a/clinicadl/prepare_data/prepare_data_param/option_roi.py +++ /dev/null @@ -1,46 +0,0 @@ -from pathlib import Path -from typing import get_args - -import click - -from clinicadl.prepare_data.prepare_data_config import PrepareDataROIConfig -from clinicadl.utils.enum import ( - DTIMeasure, - DTISpace, - Preprocessing, - SUVRReferenceRegions, - Tracer, -) - -config = PrepareDataROIConfig.model_fields - -roi_list = click.option( - "--roi_list", - type=get_args(config["roi_list"].annotation)[0], - default=config["roi_list"].default, - multiple=True, - help="List of regions to be extracted", -) -roi_uncrop_output = click.option( - "--roi_uncrop_output", - is_flag=True, - help="Disable cropping option so the output tensors " - "have the same size than the whole image.", -) -roi_custom_template = click.option( - "--roi_custom_template", - "-ct", - type=config["roi_custom_template"].annotation, - default=config["roi_custom_template"].default, - help="""Template name if MODALITY is `custom`. - Name of the template used for registration during the preprocessing procedure.""", -) -roi_custom_mask_pattern = click.option( - "--roi_custom_mask_pattern", - "-cmp", - type=config["roi_custom_mask_pattern"].annotation, - default=config["roi_custom_mask_pattern"].default, - help="""Mask pattern if MODALITY is `custom`. - If given will select only the masks containing the string given. - The mask with the shortest name is taken.""", -) diff --git a/clinicadl/prepare_data/prepare_data_param/option_slice.py b/clinicadl/prepare_data/prepare_data_param/option_slice.py deleted file mode 100644 index 611625030..000000000 --- a/clinicadl/prepare_data/prepare_data_param/option_slice.py +++ /dev/null @@ -1,47 +0,0 @@ -from pathlib import Path -from typing import get_args - -import click - -from clinicadl.prepare_data.prepare_data_config import PrepareDataSliceConfig -from clinicadl.utils.enum import ( - DTIMeasure, - DTISpace, - Preprocessing, - SliceDirection, - SliceMode, - SUVRReferenceRegions, - Tracer, -) - -config = PrepareDataSliceConfig.model_fields - -slice_direction = click.option( - "-sd", - "--slice_direction", - type=click.Choice(SliceDirection), - default=config["slice_direction_cls"].default.value, - show_default=True, - help="Slice direction. 0: Sagittal plane, 1: Coronal plane, 2: Axial plane.", -) -slice_method = click.option( - "-sm", - "--slice_mode", - type=click.Choice(SliceMode), - default=config["slice_mode_cls"].default.value, - show_default=True, - help=( - "rgb: Save the slice in three identical channels, " - "single: Save the slice in a single channel." - ), -) -discarded_slice = click.option( - "-ds", - "--discarded_slices", - type=get_args(config["discarded_slices"].annotation)[0], - default=config["discarded_slices"].default, - multiple=2, - help="""Number of slices discarded from respectively the beginning and - the end of the MRI volume. If only one argument is given, it will be - used for both sides.""", -) diff --git a/clinicadl/prepare_data/prepare_data_utils.py b/clinicadl/prepare_data/prepare_data_utils.py index 4bb661618..2eb7c7048 100644 --- a/clinicadl/prepare_data/prepare_data_utils.py +++ b/clinicadl/prepare_data/prepare_data_utils.py @@ -6,73 +6,15 @@ import numpy as np import torch -from clinicadl.prepare_data.prepare_data_config import PrepareDataConfig +from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.utils.enum import ( - BIDSModality, LinearModality, Preprocessing, SliceDirection, SliceMode, - SUVRReferenceRegions, - Tracer, ) -def compute_folder_and_file_type( - config: PrepareDataConfig, from_bids: Optional[Path] = None -) -> Tuple[str, Dict[str, str]]: - from clinicadl.utils.clinica_utils import ( - bids_nii, - dwi_dti, - linear_nii, - pet_linear_nii, - ) - - preprocessing = Preprocessing(config.preprocessing) # replace("-", "_") - if from_bids is not None: - if preprocessing == Preprocessing.CUSTOM: - mod_subfolder = Preprocessing.CUSTOM.value - file_type = { - "pattern": f"*{config.custom_suffix}", - "description": "Custom suffix", - } - else: - mod_subfolder = preprocessing - file_type = bids_nii(preprocessing) - - elif preprocessing not in Preprocessing: - raise NotImplementedError( - f"Extraction of preprocessing {config.preprocessing} is not implemented from CAPS directory." - ) - else: - mod_subfolder = preprocessing.value.replace("-", "_") - if preprocessing == Preprocessing.T1_LINEAR: - file_type = linear_nii(LinearModality.T1W, config.use_uncropped_image) - - elif preprocessing == Preprocessing.FLAIR_LINEAR: - file_type = linear_nii(LinearModality.FLAIR, config.use_uncropped_image) - - elif preprocessing == Preprocessing.PET_LINEAR: - file_type = pet_linear_nii( - config.tracer, - config.suvr_reference_region, - config.use_uncropped_image, - ) - elif preprocessing == Preprocessing.DWI_DTI: - file_type = dwi_dti( - config.dti_measure, - config.dti_space, - ) - elif preprocessing == Preprocessing.CUSTOM: - file_type = { - "pattern": f"*{config.custom_suffix}", - "description": "Custom suffix", - } - # custom_suffix["use_uncropped_image"] = None - - return mod_subfolder, file_type - - ############ # SLICE # ############ diff --git a/clinicadl/quality_check/t1_linear/utils.py b/clinicadl/quality_check/t1_linear/utils.py index a8472ed3d..5130582c1 100755 --- a/clinicadl/quality_check/t1_linear/utils.py +++ b/clinicadl/quality_check/t1_linear/utils.py @@ -8,8 +8,8 @@ import torch from torch.utils.data import Dataset +from clinicadl.caps_dataset.caps_dataset_utils import compute_folder_and_file_type from clinicadl.prepare_data.prepare_data_config import PrepareDataImageConfig -from clinicadl.prepare_data.prepare_data_utils import compute_folder_and_file_type from clinicadl.utils.clinica_utils import clinicadl_file_reader, linear_nii from clinicadl.utils.enum import LinearModality, Preprocessing diff --git a/clinicadl/train/resume.py b/clinicadl/train/resume.py index f80f5791b..b4ec16ba8 100644 --- a/clinicadl/train/resume.py +++ b/clinicadl/train/resume.py @@ -6,9 +6,9 @@ from logging import getLogger from pathlib import Path -from clinicadl import MapsManager from clinicadl.train.tasks_utils import create_training_config from clinicadl.trainer.trainer import Trainer +from clinicadl.utils.maps_manager import MapsManager def replace_arg(options, key_name, value): diff --git a/clinicadl/transforms/transforms.py b/clinicadl/transforms/transforms.py index 8bab28960..ef8ba29eb 100644 --- a/clinicadl/transforms/transforms.py +++ b/clinicadl/transforms/transforms.py @@ -1,43 +1,15 @@ # coding: utf8 -import abc from logging import getLogger -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np -import pandas as pd import torch import torchio as tio import torchvision.transforms as transforms -from torch.utils.data import Dataset - -from clinicadl.prepare_data.prepare_data_config import ( - PrepareDataConfig, - PrepareDataImageConfig, - PrepareDataPatchConfig, - PrepareDataROIConfig, - PrepareDataSliceConfig, -) -from clinicadl.prepare_data.prepare_data_utils import ( - PATTERN_DICT, - TEMPLATE_DICT, - compute_discarded_slices, - compute_folder_and_file_type, - extract_patch_path, - extract_patch_tensor, - extract_roi_path, - extract_roi_tensor, - extract_slice_path, - extract_slice_tensor, - find_mask_path, -) -from clinicadl.utils.enum import Preprocessing, SliceDirection, SliceMode + from clinicadl.utils.exceptions import ( - ClinicaDLArgumentError, - ClinicaDLCAPSError, ClinicaDLConfigurationError, - ClinicaDLTSVError, ) logger = getLogger("clinicadl") diff --git a/clinicadl/utils/maps_manager/maps_manager.py b/clinicadl/utils/maps_manager/maps_manager.py index 4cbde6c0d..3662b695a 100644 --- a/clinicadl/utils/maps_manager/maps_manager.py +++ b/clinicadl/utils/maps_manager/maps_manager.py @@ -10,7 +10,7 @@ import torch.distributed as dist from torch.cuda.amp import autocast -from clinicadl.caps_dataset.data_utils import ( +from clinicadl.caps_dataset.data import ( return_dataset, ) from clinicadl.preprocessing.preprocessing import path_encoder diff --git a/clinicadl/utils/maps_manager/maps_manager_utils.py b/clinicadl/utils/maps_manager/maps_manager_utils.py index e711342e5..524659ce1 100644 --- a/clinicadl/utils/maps_manager/maps_manager_utils.py +++ b/clinicadl/utils/maps_manager/maps_manager_utils.py @@ -4,7 +4,7 @@ import toml -from clinicadl.prepare_data.prepare_data_utils import compute_folder_and_file_type +from clinicadl.caps_dataset.caps_dataset_utils import compute_folder_and_file_type from clinicadl.preprocessing.preprocessing import path_decoder, path_encoder diff --git a/clinicadl/utils/meta_maps/getter.py b/clinicadl/utils/meta_maps/getter.py index 2f400ffc3..42967c929 100644 --- a/clinicadl/utils/meta_maps/getter.py +++ b/clinicadl/utils/meta_maps/getter.py @@ -6,8 +6,8 @@ import pandas as pd -from clinicadl import MapsManager from clinicadl.utils.exceptions import MAPSError +from clinicadl.utils.maps_manager import MapsManager def meta_maps_analysis(launch_dir: Path, evaluation_metric="loss"): diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py index 633cdaf67..61bf50265 100644 --- a/tests/test_prepare_data.py +++ b/tests/test_prepare_data.py @@ -10,13 +10,9 @@ import pytest -from clinicadl.prepare_data.prepare_data_config import ( - PrepareDataConfig, - PrepareDataImageConfig, - PrepareDataPatchConfig, - PrepareDataROIConfig, - PrepareDataSliceConfig, -) +from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig, get_modality +from clinicadl.config.config.modality import CustomModalityConfig, PETModalityConfig +from clinicadl.preprocessing.config import PreprocessingROIConfig from clinicadl.utils.enum import ( ExtractionMethod, Preprocessing, @@ -60,9 +56,11 @@ def test_prepare_data(cmdopt, tmp_path, test_name): shutil.rmtree(tmp_out_dir / "caps_image_flair") shutil.copytree(input_caps_flair_directory, tmp_out_dir / "caps_image_flair") - config = PrepareDataImageConfig( + config = CapsDatasetConfig.from_preprocessing_and_extraction_method( + extraction=ExtractionMethod.IMAGE, + preprocessing_type=Preprocessing.T1_LINEAR, + preprocessing=Preprocessing.T1_LINEAR, caps_directory=tmp_out_dir / "caps_image", - preprocessing_cls=Preprocessing.T1_LINEAR, ) elif test_name == "patch": @@ -74,9 +72,11 @@ def test_prepare_data(cmdopt, tmp_path, test_name): shutil.rmtree(tmp_out_dir / "caps_patch_flair") shutil.copytree(input_caps_flair_directory, tmp_out_dir / "caps_patch_flair") - config = PrepareDataPatchConfig( + config = CapsDatasetConfig.from_preprocessing_and_extraction_method( + extraction=ExtractionMethod.PATCH, + preprocessing_type=Preprocessing.T1_LINEAR, + preprocessing=Preprocessing.T1_LINEAR, caps_directory=tmp_out_dir / "caps_patch", - preprocessing_cls=Preprocessing.T1_LINEAR, ) elif test_name == "slice": @@ -88,10 +88,13 @@ def test_prepare_data(cmdopt, tmp_path, test_name): shutil.rmtree(tmp_out_dir / "caps_slice_flair") shutil.copytree(input_caps_flair_directory, tmp_out_dir / "caps_slice_flair") - config = PrepareDataSliceConfig( + config = CapsDatasetConfig.from_preprocessing_and_extraction_method( + extraction=ExtractionMethod.SLICE, + preprocessing_type=Preprocessing.T1_LINEAR, + preprocessing=Preprocessing.T1_LINEAR, caps_directory=tmp_out_dir / "caps_slice", - preprocessing_cls=Preprocessing.T1_LINEAR, ) + elif test_name == "roi": if (tmp_out_dir / "caps_roi").is_dir(): shutil.rmtree(tmp_out_dir / "caps_roi") @@ -100,11 +103,15 @@ def test_prepare_data(cmdopt, tmp_path, test_name): if (tmp_out_dir / "caps_roi_flair").is_dir(): shutil.rmtree(tmp_out_dir / "caps_roi_flair") shutil.copytree(input_caps_flair_directory, tmp_out_dir / "caps_roi_flair") - config = PrepareDataROIConfig( - caps_directory=tmp_out_dir / "caps_roi", - preprocessing_cls=Preprocessing.T1_LINEAR, + + config = CapsDatasetConfig.from_preprocessing_and_extraction_method( + extraction=ExtractionMethod.ROI, + preprocessing_type=Preprocessing.T1_LINEAR, + preprocessing=Preprocessing.T1_LINEAR, + caps_directory=tmp_out_dir / "caps_image", roi_list=["rightHippocampusBox", "leftHippocampusBox"], ) + else: print(f"Test {test_name} not available.") assert 0 @@ -113,52 +120,59 @@ def test_prepare_data(cmdopt, tmp_path, test_name): def run_test_prepare_data( - input_dir, ref_dir, out_dir, test_name: str, config: PrepareDataConfig + input_dir, ref_dir, out_dir, test_name: str, config: CapsDatasetConfig ): modalities = ["t1-linear", "pet-linear", "flair-linear"] uncropped_image = [True, False] acquisition_label = ["18FAV45", "11CPIB"] - config.save_features = True + config.preprocessing.save_features = True for modality in modalities: - config.preprocessing = Preprocessing(modality) + config.preprocessing.preprocessing = Preprocessing(modality) + config.modality = get_modality(Preprocessing(modality))() if modality == "pet-linear": for acq in acquisition_label: - config.tracer = Tracer(acq) - config.suvr_reference_region = SUVRReferenceRegions("pons2") - config.use_uncropped_image = False - config.extract_json = f"{modality}-{acq}_mode-{test_name}.json" + assert isinstance(config.modality, PETModalityConfig) + config.modality.tracer = Tracer(acq) + config.modality.suvr_reference_region = SUVRReferenceRegions("pons2") + config.preprocessing.use_uncropped_image = False + config.preprocessing.extract_json = ( + f"{modality}-{acq}_mode-{test_name}.json" + ) tsv_file = join(input_dir, f"pet_{acq}.tsv") mode = test_name extract_generic(out_dir, mode, tsv_file, config) elif modality == "custom": - config.use_uncropped_image = True - config.custom_suffix = ( + assert isinstance(config.modality, CustomModalityConfig) + config.preprocessing.use_uncropped_image = True + config.modality.custom_suffix = ( "graymatter_space-Ixi549Space_modulated-off_probability.nii.gz" ) - if isinstance(config, PrepareDataROIConfig): - config.roi_custom_template = "Ixi549Space" - config.extract_json = f"{modality}_mode-{test_name}.json" + if isinstance(config.preprocessing, PreprocessingROIConfig): + config.preprocessing.roi_custom_template = "Ixi549Space" + config.preprocessing.extract_json = f"{modality}_mode-{test_name}.json" tsv_file = input_dir / "subjects.tsv" mode = test_name extract_generic(out_dir, mode, tsv_file, config) elif modality == "t1-linear": for flag in uncropped_image: - config.use_uncropped_image = flag - config.extract_json = ( + config.preprocessing.use_uncropped_image = flag + config.preprocessing.extract_json = ( f"{modality}_crop-{not flag}_mode-{test_name}.json" ) mode = test_name extract_generic(out_dir, mode, None, config) elif modality == "flair-linear": - config.caps_directory = Path(str(config.caps_directory) + "_flair") - config.save_features = False + config.data.caps_directory = Path( + str(config.data.caps_directory) + "_flair" + ) + config.preprocessing.save_features = False for flag in uncropped_image: - config.use_uncropped_image = flag - config.extract_json = ( + config.preprocessing.use_uncropped_image = flag + config.preprocessing.extract_json = ( f"{modality}_crop-{not flag}_mode-{test_name}.json" ) mode = f"{test_name}_flair" @@ -178,10 +192,10 @@ def run_test_prepare_data( ) -def extract_generic(out_dir, mode, tsv_file, config: PrepareDataConfig): +def extract_generic(out_dir, mode, tsv_file, config: CapsDatasetConfig): from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData - config.caps_directory = out_dir / f"caps_{mode}" - config.tsv_file = tsv_file - config.n_proc = 1 + config.data.caps_directory = out_dir / f"caps_{mode}" + config.data.data_tsv = tsv_file + config.dataloader.n_proc = 1 DeepLearningPrepareData(config) diff --git a/tests/test_resume.py b/tests/test_resume.py index 5827bda0f..44af2f6d5 100644 --- a/tests/test_resume.py +++ b/tests/test_resume.py @@ -6,7 +6,7 @@ import pytest -from clinicadl import MapsManager +from clinicadl.utils.maps_manager import MapsManager from .testing_tools import modify_maps