From e80ecc84346f099a098471d93e85c8410bbe0ad2 Mon Sep 17 00:00:00 2001 From: camillebrianceau Date: Thu, 28 Nov 2024 15:30:52 +0100 Subject: [PATCH 01/10] first commit --- clinicadl/API/dataset_test.py | 167 ++++ clinicadl/dataset/caps_dataset.py | 817 ------------------ clinicadl/dataset/caps_dataset_config.py | 127 --- clinicadl/dataset/caps_dataset_utils.py | 193 ----- clinicadl/dataset/caps_reader.py | 62 -- clinicadl/dataset/concat.py | 6 - clinicadl/dataset/config/__init__.py | 13 + clinicadl/dataset/config/data.py | 77 ++ clinicadl/dataset/config/extraction.py | 67 -- clinicadl/dataset/config/file_type.py | 47 + clinicadl/dataset/config/preprocessing.py | 190 +++- clinicadl/dataset/data_config.py | 164 ---- clinicadl/dataset/dataloader_config.py | 18 - .../__init__.py => datasets/___init__.py} | 0 clinicadl/dataset/datasets/caps_dataset.py | 530 ++++++++++++ clinicadl/dataset/datasets/concat.py | 51 ++ .../dataset/prepare_data/prepare_data.py | 230 ----- .../prepare_data/prepare_data_utils.py | 442 ---------- clinicadl/dataset/readers/__init__.py | 2 + clinicadl/dataset/readers/bids_reader.py | 157 ++++ clinicadl/dataset/readers/caps_reader.py | 311 +++++++ .../dataset/readers/multi_caps_reader.py | 51 ++ clinicadl/dataset/readers/reader.py | 181 ++++ clinicadl/dataset/utils.py | 337 +++++--- clinicadl/transforms/extraction/__init__.py | 5 + clinicadl/transforms/extraction/base.py | 149 ++++ clinicadl/transforms/extraction/image.py | 121 +++ clinicadl/transforms/extraction/patch.py | 168 ++++ clinicadl/transforms/extraction/roi.py | 357 ++++++++ clinicadl/transforms/extraction/slice.py | 153 ++++ clinicadl/transforms/transforms.py | 14 - 31 files changed, 2933 insertions(+), 2274 deletions(-) create mode 100644 clinicadl/API/dataset_test.py delete mode 100644 clinicadl/dataset/caps_dataset.py delete mode 100644 clinicadl/dataset/caps_dataset_config.py delete mode 100644 clinicadl/dataset/caps_dataset_utils.py delete mode 100644 clinicadl/dataset/caps_reader.py delete mode 100644 clinicadl/dataset/concat.py create mode 100644 clinicadl/dataset/config/data.py delete mode 100644 clinicadl/dataset/config/extraction.py create mode 100644 clinicadl/dataset/config/file_type.py delete mode 100644 clinicadl/dataset/data_config.py delete mode 100644 clinicadl/dataset/dataloader_config.py rename clinicadl/dataset/{prepare_data/__init__.py => datasets/___init__.py} (100%) create mode 100644 clinicadl/dataset/datasets/caps_dataset.py create mode 100644 clinicadl/dataset/datasets/concat.py delete mode 100644 clinicadl/dataset/prepare_data/prepare_data.py delete mode 100644 clinicadl/dataset/prepare_data/prepare_data_utils.py create mode 100644 clinicadl/dataset/readers/__init__.py create mode 100644 clinicadl/dataset/readers/bids_reader.py create mode 100644 clinicadl/dataset/readers/caps_reader.py create mode 100644 clinicadl/dataset/readers/multi_caps_reader.py create mode 100644 clinicadl/dataset/readers/reader.py create mode 100644 clinicadl/transforms/extraction/__init__.py create mode 100644 clinicadl/transforms/extraction/base.py create mode 100644 clinicadl/transforms/extraction/image.py create mode 100644 clinicadl/transforms/extraction/patch.py create mode 100644 clinicadl/transforms/extraction/roi.py create mode 100644 clinicadl/transforms/extraction/slice.py delete mode 100644 clinicadl/transforms/transforms.py diff --git a/clinicadl/API/dataset_test.py b/clinicadl/API/dataset_test.py new file mode 100644 index 000000000..017920867 --- /dev/null +++ b/clinicadl/API/dataset_test.py @@ 
-0,0 +1,167 @@
+from pathlib import Path
+
+import torchio.transforms as transforms
+
+from clinicadl.dataset.config.preprocessing import (
+    PreprocessingConfig,
+    PreprocessingFlair,
+    PreprocessingPET,
+    PreprocessingT1,
+)
+from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.dataset.datasets.concat import ConcatDataset
+from clinicadl.transforms.extraction import ROI, Image, Patch, Slice
+from clinicadl.transforms.transforms import Transforms
+from clinicadl.experiment_manager.experiment_manager import ExperimentManager
+from clinicadl.losses.config import CrossEntropyLossConfig
+from clinicadl.model.clinicadl_model import ClinicaDLModel
+from clinicadl.networks.factory import (
+    ConvEncoderOptions,
+    create_network_config,
+    get_network_from_config,
+)
+from clinicadl.splitter.kfold import KFolder
+from clinicadl.splitter.split import get_single_split, split_tsv
+
+sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv")
+sub_ses_pet_45 = Path(
+    "/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_pet_18FAV45.tsv"
+)
+sub_ses_flair = Path(
+    "/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_flair.tsv"
+)
+sub_ses_pet_11 = Path(
+    "/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_pet_11CPIB.tsv"
+)
+
+caps_directory = Path(
+    "/Users/camille.brianceau/aramis/CLINICADL/caps"
+)  # output of clinica pipelines
+
+preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2")
+preprocessing_pet_11 = PreprocessingPET(tracer="11CPIB", suvr_reference_region="pons2")
+
+preprocessing_t1 = PreprocessingT1()
+preprocessing_flair = PreprocessingFlair()
+
+
+transforms_patch = Transforms(
+    object_augmentation=[transforms.Ghosting(2, 1, 0.1, 0.1)],
+    image_augmentation=[transforms.RandomMotion()],
+    extraction=Patch(patch_size=60),
+    image_transforms=[transforms.Blur((0.5, 0.6, 0.3))],
+    object_transforms=[transforms.RandomMotion()],
+)  # not mandatory
+
+transforms_slice = Transforms(extraction=Slice())
+
+transforms_roi = Transforms(
+    object_augmentation=[transforms.Ghosting(2, 1, 0.1, 0.1)],
+    object_transforms=[transforms.RandomMotion()],
+    extraction=ROI(
+        roi_list=["leftHippocampusBox", "rightHippocampusBox"],
+        roi_mask_location=Path(
+            "/Users/camille.brianceau/aramis/CLINICADL/caps/masks/tpl-MNI152NLin2009cSym"
+        ),
+        roi_crop_input=True,
+    ),
+)
+
+transforms_image = Transforms(
+    image_augmentation=[transforms.RandomMotion()],
+    extraction=Image(),
+    image_transforms=[transforms.Blur((0.5, 0.6, 0.3))],
+)
+
+
+print("Pet 45 and Patch ")
+dataset_pet_45_patch = CapsDataset(
+    caps_directory=caps_directory,
+    data=sub_ses_pet_45,
+    preprocessing=preprocessing_pet_45,
+    transforms=transforms_patch,
+)
+dataset_pet_45_patch.prepare_data(n_proc=2)
+
+print(dataset_pet_45_patch)
+print(dataset_pet_45_patch.__len__())
+print(dataset_pet_45_patch._get_meta_data(3))
+print(dataset_pet_45_patch._get_meta_data(80))
+# print(dataset_pet_45_patch._get_full_image())
+print(dataset_pet_45_patch.__getitem__(80).elem_idx)
+print(dataset_pet_45_patch.elem_per_image)
+
+dataset_pet_45_patch.caps_reader._write_caps_json(
+    transforms_patch, preprocessing_pet_45, sub_ses_pet_45, name="test_extraction"
+)
+
+
+print("Pet 11 and ROI ")
+
+dataset_pet_11_roi = CapsDataset(
+    caps_directory=caps_directory,
+    data=sub_ses_pet_11,
+    preprocessing=preprocessing_pet_11,
+    transforms=transforms_roi,
+)
+dataset_pet_11_roi.prepare_data(
+    n_proc=2
+)  # to extract the tensor of the PET file this time
+
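+# Hypothetical usage sketch (not exercised by this script): CapsDataset
+# subclasses torch.utils.data.Dataset, so once prepare_data() has run it
+# can in principle be wrapped in a standard PyTorch DataLoader, e.g.:
+#
+#   from torch.utils.data import DataLoader
+#   loader = DataLoader(dataset_pet_11_roi, batch_size=8, shuffle=True)
+#   first_batch = next(iter(loader))
+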
+print(dataset_pet_11_roi)
+print(dataset_pet_11_roi.__len__())
+print(dataset_pet_11_roi._get_meta_data(0))
+print(dataset_pet_11_roi._get_meta_data(1))
+# print(dataset_pet_11_roi._get_full_image())
+print(dataset_pet_11_roi.__getitem__(1).elem_idx)
+print(dataset_pet_11_roi.elem_per_image)
+
+
+print("T1 and image ")
+
+dataset_t1_image = CapsDataset(
+    caps_directory=caps_directory,
+    data=sub_ses_t1,
+    preprocessing=preprocessing_t1,
+    transforms=transforms_image,
+)
+dataset_t1_image.prepare_data(
+    n_proc=2
+)  # to extract the tensor of the T1 file this time
+
+print(dataset_t1_image)
+print(dataset_t1_image.__len__())
+print(dataset_t1_image._get_meta_data(3))
+print(dataset_t1_image._get_meta_data(5))
+# print(dataset_t1_image._get_full_image())
+print(dataset_t1_image.__getitem__(5).elem_idx)
+print(dataset_t1_image.elem_per_image)
+
+
+print("Flair and slice ")
+
+dataset_flair_slice = CapsDataset(
+    caps_directory=caps_directory,
+    data=sub_ses_flair,
+    preprocessing=preprocessing_flair,
+    transforms=transforms_slice,
+)
+dataset_flair_slice.prepare_data(
+    n_proc=2
+)  # to extract the tensor of the FLAIR file this time
+
+print(dataset_flair_slice)
+print(dataset_flair_slice.__len__())
+print(dataset_flair_slice._get_meta_data(3))
+print(dataset_flair_slice._get_meta_data(80))
+# print(dataset_flair_slice._get_full_image())
+print(dataset_flair_slice.__getitem__(80).elem_idx)
+print(dataset_flair_slice.elem_per_image)
+
+
+multi_extract = ConcatDataset(
+    [
+        dataset_t1_image,
+        dataset_pet_45_patch,
+    ]
+)  # 3 train.tsv files as input that must be concatenated; same for the transforms, be careful
diff --git a/clinicadl/dataset/caps_dataset.py b/clinicadl/dataset/caps_dataset.py
deleted file mode 100644
index dec004b0a..000000000
--- a/clinicadl/dataset/caps_dataset.py
+++ /dev/null
@@ -1,817 +0,0 @@
-# coding: utf8
-# TODO: create a folder for generate/ prepare_data/ data to deal with capsDataset objects ?
-import abc
-from logging import getLogger
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import numpy as np
-import pandas as pd
-import torch
-from torch.utils.data import Dataset
-
-from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig
-from clinicadl.dataset.config.extraction import (
-    ExtractionImageConfig,
-    ExtractionPatchConfig,
-    ExtractionROIConfig,
-    ExtractionSliceConfig,
-)
-from clinicadl.dataset.prepare_data.prepare_data_utils import (
-    compute_discarded_slices,
-    extract_patch_path,
-    extract_patch_tensor,
-    extract_roi_path,
-    extract_roi_tensor,
-    extract_slice_path,
-    extract_slice_tensor,
-    find_mask_path,
-)
-from clinicadl.transforms.config import TransformsConfig
-from clinicadl.utils.enum import (
-    Pattern,
-    Preprocessing,
-    SliceDirection,
-    SliceMode,
-    Template,
-)
-from clinicadl.utils.exceptions import (
-    ClinicaDLCAPSError,
-    ClinicaDLTSVError,
-)
-
-logger = getLogger("clinicadl")
-
-
-#################################
-# Datasets loaders
-#################################
-class CapsDataset(Dataset):
-    """Abstract class for all derived CapsDatasets."""
-
-    def __init__(
-        self,
-        config: CapsDatasetConfig,
-        label_presence: bool,
-        preprocessing_dict: Dict[str, Any],
-    ):
-        self.label_presence = label_presence
-        self.eval_mode = False
-        self.config = config
-        self.preprocessing_dict = preprocessing_dict
-
-        if not hasattr(self, "elem_index"):
-            raise AttributeError(
-                "Child class of CapsDataset must set elem_index attribute."
- ) - if not hasattr(self, "mode"): - raise AttributeError("Child class of CapsDataset, must set mode attribute.") - - self.df = self.config.data.data_df - mandatory_col = { - "participant_id", - "session_id", - "cohort", - } - if label_presence and self.config.data.label is not None: - mandatory_col.add(self.config.data.label) - - if not mandatory_col.issubset(set(self.df.columns.values)): - raise ClinicaDLTSVError( - f"the data file is not in the correct format." - f"Columns should include {mandatory_col}" - ) - self.elem_per_image = self.num_elem_per_image() - self.size = self[0]["image"].size() - - @property - @abc.abstractmethod - def elem_index(self): - pass - - def label_fn(self, target: Union[str, float, int]) -> Union[float, int, None]: - """ - Returns the label value usable in criterion. - - Args: - target: value of the target. - Returns: - label: value of the label usable in criterion. - """ - # Reconstruction case (no label) - if self.config.data.label is None: - return None - # Regression case (no label code) - elif self.config.data.label_code is None: - return np.float32([target]) - # Classification case (label + label_code dict) - else: - return self.config.data.label_code[str(target)] - - def domain_fn(self, target: Union[str, float, int]) -> Union[float, int]: - """ - Returns the label value usable in criterion. - - """ - domain_code = {"t1": 0, "flair": 1} - return domain_code[str(target)] - - def __len__(self) -> int: - return len(self.df) * self.elem_per_image - - def _get_image_path(self, participant: str, session: str, cohort: str) -> Path: - """ - Gets the path to the tensor image (*.pt) - - Args: - participant: ID of the participant. - session: ID of the session. - cohort: Name of the cohort. - Returns: - image_path: path to the tensor containing the whole image. - """ - from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader - - # Try to find .nii.gz file - try: - folder, file_type = self.config.compute_folder_and_file_type() - - results = clinicadl_file_reader( - [participant], - [session], - self.config.data.caps_dict[cohort], - file_type, - ) - logger.debug(f"clinicadl_file_reader output: {results}") - filepath = Path(results[0][0]) - image_filename = filepath.name.replace(".nii.gz", ".pt") - - image_dir = ( - self.config.data.caps_dict[cohort] - / "subjects" - / participant - / session - / "deeplearning_prepare_data" - / "image_based" - / folder - ) - image_path = image_dir / image_filename - # Try to find .pt file - except ClinicaDLCAPSError: - folder, file_type = self.config.compute_folder_and_file_type() - file_type.pattern = file_type.pattern.replace(".nii.gz", ".pt") - results = clinicadl_file_reader( - [participant], - [session], - self.config.data.caps_dict[cohort], - file_type, - ) - filepath = results[0] - image_path = Path(filepath[0]) - - return image_path - - def _get_meta_data( - self, idx: int - ) -> Tuple[str, str, str, Union[float, int, None], int]: - """ - Gets all meta data necessary to compute the path with _get_image_path - - Args: - idx (int): row number of the meta-data contained in self.df - Returns: - participant (str): ID of the participant. - session (str): ID of the session. - cohort (str): Name of the cohort. - elem_index (int): Index of the part of the image. - label (str or float or int): value of the label to be used in criterion. 
- """ - image_idx = idx // self.elem_per_image - participant = self.df.at[image_idx, "participant_id"] - session = self.df.at[image_idx, "session_id"] - cohort = self.df.at[image_idx, "cohort"] - - if self.elem_index is None: - elem_idx = idx % self.elem_per_image - else: - elem_idx = self.elem_index - if self.label_presence and self.config.data.label is not None: - target = self.df.at[image_idx, self.config.data.label] - label = self.label_fn(target) - else: - label = -1 - - if "domain" in self.df.columns: - domain = self.df.at[image_idx, "domain"] - domain = self.domain_fn(domain) - else: - domain = "" # TO MODIFY - return participant, session, cohort, elem_idx, label, domain - - def _get_full_image(self) -> torch.Tensor: - """ - Allows to get the an example of the image mode corresponding to the dataset. - Useful to compute the number of elements if mode != image. - - Returns: - image tensor of the full image first image. - """ - import nibabel as nib - - from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader - - participant_id = self.df.at[0, "participant_id"] - session_id = self.df.at[0, "session_id"] - cohort = self.df.at[0, "cohort"] - - try: - image_path = self._get_image_path(participant_id, session_id, cohort) - image = torch.load(image_path, weights_only=True) - except IndexError: - file_type = self.config.extraction.file_type - results = clinicadl_file_reader( - [participant_id], - [session_id], - self.config.data.caps_dict[cohort], - file_type, - ) - image_nii = nib.loadsave.load(results[0]) - image_np = image_nii.get_fdata() - image = ToTensor()(image_np) - - return image - - @abc.abstractmethod - def __getitem__(self, idx: int) -> Dict[str, Any]: - """ - Gets the sample containing all the information needed for training and testing tasks. - - Args: - idx: row number of the meta-data contained in self.df - Returns: - dictionary with following items: - - "image" (torch.Tensor): the input given to the model, - - "label" (int or float): the label used in criterion, - - "participant_id" (str): ID of the participant, - - "session_id" (str): ID of the session, - - f"{self.mode}_id" (int): number of the element, - - "image_path": path to the image loaded in CAPS. - - """ - pass - - @abc.abstractmethod - def num_elem_per_image(self) -> int: - """Computes the number of elements per image based on the full image.""" - pass - - def eval(self): - """Put the dataset on evaluation mode (data augmentation is not performed).""" - self.eval_mode = True - return self - - def train(self): - """Put the dataset on training mode (data augmentation is performed).""" - self.eval_mode = False - return self - - -class CapsDatasetImage(CapsDataset): - """Dataset of MRI organized in a CAPS folder.""" - - def __init__( - self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], - label_presence: bool = True, - ): - """ - Args: - caps_directory: Directory of all the images. - data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - train_transformations: Optional transform to be applied only on training mode. - label_presence: If True the diagnosis will be extracted from the given DataFrame. - label: Name of the column in data_df containing the label. - label_code: label code that links the output node number to label value. - all_transformations: Optional transform to be applied during training and evaluation. 
- multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. - - """ - - self.mode = "image" - self.config = config - self.label_presence = label_presence - super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - @property - def elem_index(self): - return None - - def __getitem__(self, idx): - participant, session, cohort, _, label, domain = self._get_meta_data(idx) - - image_path = self._get_image_path(participant, session, cohort) - image = torch.load(image_path, weights_only=True) - - train_trf, trf = self.config.transforms.get_transforms() - - image = trf(image) - if self.config.transforms.train_transformations and not self.eval_mode: - image = train_trf(image) - - sample = { - "image": image, - "label": label, - "participant_id": participant, - "session_id": session, - "image_id": 0, - "image_path": image_path.as_posix(), - "domain": domain, - } - - return sample - - def num_elem_per_image(self): - return 1 - - -class CapsDatasetPatch(CapsDataset): - def __init__( - self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], - patch_index: Optional[int] = None, - label_presence: bool = True, - ): - """ - caps_directory: Directory of all the images. - data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - train_transformations: Optional transform to be applied only on training mode. - """ - self.patch_index = patch_index - self.mode = "patch" - self.config = config - self.label_presence = label_presence - - super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - @property - def elem_index(self): - return self.patch_index - - def __getitem__(self, idx): - participant, session, cohort, patch_idx, label, domain = self._get_meta_data( - idx - ) - image_path = self._get_image_path(participant, session, cohort) - - if self.config.extraction.save_features: - patch_dir = image_path.parent.as_posix().replace( - "image_based", f"{self.mode}_based" - ) - patch_filename = extract_patch_path( - image_path, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - patch_idx, - ) - patch_tensor = torch.load( - Path(patch_dir).resolve() / patch_filename, weights_only=True - ) - - else: - image = torch.load(image_path, weights_only=True) - patch_tensor = extract_patch_tensor( - image, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - patch_idx, - ) - - train_trf, trf = self.config.transforms.get_transforms() - patch_tensor = trf(patch_tensor) - - if self.config.transforms.train_transformations and not self.eval_mode: - patch_tensor = train_trf(patch_tensor) - - sample = { - "image": patch_tensor, - "label": label, - "participant_id": participant, - "session_id": session, - "patch_id": patch_idx, - } - - return sample - - def num_elem_per_image(self): - if self.elem_index is not None: - return 1 - - image = self._get_full_image() - - patches_tensor = ( - image.unfold( - 1, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - ) - .unfold( - 2, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - ) - .unfold( - 3, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - ) - .contiguous() - ) - patches_tensor = patches_tensor.view( - -1, - self.config.extraction.patch_size, - self.config.extraction.patch_size, - 
self.config.extraction.patch_size, - ) - num_patches = patches_tensor.shape[0] - return num_patches - - -class CapsDatasetRoi(CapsDataset): - def __init__( - self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], - roi_index: Optional[int] = None, - label_presence: bool = True, - ): - """ - Args: - caps_directory: Directory of all the images. - data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - roi_index: If a value is given the same region will be extracted for each image. - else the dataset will load all the regions possible for one image. - train_transformations: Optional transform to be applied only on training mode. - label_presence: If True the diagnosis will be extracted from the given DataFrame. - label: Name of the column in data_df containing the label. - label_code: label code that links the output node number to label value. - all_transformations: Optional transform to be applied during training and evaluation. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. - - """ - self.roi_index = roi_index - self.mode = "roi" - self.config = config - self.label_presence = label_presence - self.mask_paths, self.mask_arrays = self._get_mask_paths_and_tensors( - self.config.data.caps_directory, preprocessing_dict - ) - super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - @property - def elem_index(self): - return self.roi_index - - def __getitem__(self, idx): - participant, session, cohort, roi_idx, label, domain = self._get_meta_data(idx) - image_path = self._get_image_path(participant, session, cohort) - - if self.config.extraction.roi_list is None: - raise NotImplementedError( - "Default regions are not available anymore in ClinicaDL. " - "Please define appropriate masks and give a roi_list." - ) - - if self.config.extraction.save_features: - mask_path = self.mask_paths[roi_idx] - roi_dir = image_path.parent.as_posix().replace( - "image_based", f"{self.mode}_based" - ) - roi_filename = extract_roi_path( - image_path, mask_path, self.config.extraction.roi_uncrop_output - ) - roi_tensor = torch.load(Path(roi_dir) / roi_filename, weights_only=True) - - else: - image = torch.load(image_path, weights_only=True) - mask_array = self.mask_arrays[roi_idx] - roi_tensor = extract_roi_tensor( - image, mask_array, self.config.extraction.uncropped_roi - ) - - train_trf, trf = self.config.transforms.get_transforms() - - roi_tensor = trf(roi_tensor) - - if self.config.transforms.train_transformations and not self.eval_mode: - roi_tensor = train_trf(roi_tensor) - - sample = { - "image": roi_tensor, - "label": label, - "participant_id": participant, - "session_id": session, - "roi_id": roi_idx, - } - - return sample - - def num_elem_per_image(self): - if self.elem_index is not None: - return 1 - if self.config.extraction.roi_list is None: - return 2 - else: - return len(self.config.extraction.roi_list) - - def _get_mask_paths_and_tensors( - self, - caps_directory: Path, - preprocessing_dict: Dict[str, Any], - ) -> Tuple[List[str], List]: - """Loads the masks necessary to regions extraction""" - import nibabel as nib - - caps_dict = self.config.data.caps_dict - if len(caps_dict) > 1: - caps_directory = caps_dict[next(iter(caps_dict))] - logger.warning( - f"The equality of masks is not assessed for multi-cohort training. 
" - f"The masks stored in {caps_directory} will be used." - ) - - try: - preprocessing_ = Preprocessing(preprocessing_dict["preprocessing"]) - except NotImplementedError: - print( - f"Template of preprocessing {preprocessing_dict['preprocessing']} " - f"is not defined." - ) - # Find template name and pattern - if preprocessing_.value == "custom": - template_name = preprocessing_dict["roi_custom_template"] - if template_name is None: - raise ValueError( - "Please provide a name for the template when preprocessing is `custom`." - ) - - pattern = preprocessing_dict["roi_custom_mask_pattern"] - if pattern is None: - raise ValueError( - "Please provide a pattern for the masks when preprocessing is `custom`." - ) - - else: - for template_ in Template: - if preprocessing_.name == template_.name: - template_name = template_ - - for pattern_ in Pattern: - if preprocessing_.name == pattern_.name: - pattern = pattern_ - - mask_location = caps_directory / "masks" / f"tpl-{template_name}" - - mask_paths, mask_arrays = list(), list() - for roi in self.config.extraction.roi_list: - logger.info(f"Find mask for roi {roi}.") - mask_path, desc = find_mask_path(mask_location, roi, pattern, True) - if mask_path is None: - raise FileNotFoundError(desc) - mask_nii = nib.loadsave.load(mask_path) - mask_paths.append(Path(mask_path)) - mask_arrays.append(mask_nii.get_fdata()) - - return mask_paths, mask_arrays - - -class CapsDatasetSlice(CapsDataset): - def __init__( - self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], - slice_index: Optional[int] = None, - label_presence: bool = True, - ): - """ - Args: - caps_directory: Directory of all the images. - data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - slice_index: If a value is given the same slice will be extracted for each image. - else the dataset will load all the slices possible for one image. - train_transformations: Optional transform to be applied only on training mode. - label_presence: If True the diagnosis will be extracted from the given DataFrame. - label: Name of the column in data_df containing the label. - label_code: label code that links the output node number to label value. - all_transformations: Optional transform to be applied during training and evaluation. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. 
- """ - self.slice_index = slice_index - self.mode = "slice" - self.config = config - self.label_presence = label_presence - self.preprocessing_dict = preprocessing_dict - super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - @property - def elem_index(self): - return self.slice_index - - def __getitem__(self, idx): - participant, session, cohort, slice_idx, label, domain = self._get_meta_data( - idx - ) - slice_idx = slice_idx + self.config.extraction.discarded_slices[0] - image_path = self._get_image_path(participant, session, cohort) - - if self.config.extraction.save_features: - slice_dir = image_path.parent.as_posix().replace( - "image_based", f"{self.mode}_based" - ) - slice_filename = extract_slice_path( - image_path, - self.config.extraction.slice_direction, - self.config.extraction.slice_mode, - slice_idx, - ) - slice_tensor = torch.load( - Path(slice_dir) / slice_filename, weights_only=True - ) - - else: - image_path = self._get_image_path(participant, session, cohort) - image = torch.load(image_path, weights_only=True) - slice_tensor = extract_slice_tensor( - image, - self.config.extraction.slice_direction, - self.config.extraction.slice_mode, - slice_idx, - ) - - train_trf, trf = self.config.transforms.get_transforms() - - slice_tensor = trf(slice_tensor) - - if self.config.transforms.train_transformations and not self.eval_mode: - slice_tensor = train_trf(slice_tensor) - - sample = { - "image": slice_tensor, - "label": label, - "participant_id": participant, - "session_id": session, - "slice_id": slice_idx, - } - - return sample - - def num_elem_per_image(self): - if self.elem_index is not None: - return 1 - - if self.config.extraction.num_slices is not None: - return self.config.extraction.num_slices - - image = self._get_full_image() - return ( - image.size(int(self.config.extraction.slice_direction) + 1) - - self.config.extraction.discarded_slices[0] - - self.config.extraction.discarded_slices[1] - ) - - -def return_dataset( - input_dir: Path, - data_df: pd.DataFrame, - preprocessing_dict: Dict[str, Any], - transforms_config: TransformsConfig, - label: Optional[str] = None, - label_code: Optional[Dict[str, int]] = None, - cnn_index: Optional[int] = None, - label_presence: bool = True, - multi_cohort: bool = False, -) -> CapsDataset: - """ - Return appropriate Dataset according to given options. - Args: - input_dir: path to a directory containing a CAPS structure. - data_df: List subjects, sessions and diagnoses. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - train_transformations: Optional transform to be applied during training only. - all_transformations: Optional transform to be applied during training and evaluation. - label: Name of the column in data_df containing the label. - label_code: label code that links the output node number to label value. - cnn_index: Index of the CNN in a multi-CNN paradigm (optional). - label_presence: If True the diagnosis will be extracted from the given DataFrame. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. - - Returns: - the corresponding dataset. - """ - if cnn_index is not None and preprocessing_dict["mode"] == "image": - raise NotImplementedError( - f"Multi-CNN is not implemented for {preprocessing_dict['mode']} mode." 
- ) - - config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - preprocessing_type=preprocessing_dict["preprocessing"], - preprocessing=preprocessing_dict["preprocessing"], - extraction=preprocessing_dict["mode"], - caps_directory=input_dir, - data_df=data_df, - label=label, - label_code=label_code, - multi_cohort=multi_cohort, - ) - config.transforms = transforms_config - - if preprocessing_dict["mode"] == "image": - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetImage( - config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - elif preprocessing_dict["mode"] == "patch": - assert isinstance(config.extraction, ExtractionPatchConfig) - config.extraction.patch_size = preprocessing_dict["patch_size"] - config.extraction.stride_size = preprocessing_dict["stride_size"] - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetPatch( - config, - patch_index=cnn_index, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - elif preprocessing_dict["mode"] == "roi": - assert isinstance(config.extraction, ExtractionROIConfig) - config.extraction.roi_list = preprocessing_dict["roi_list"] - config.extraction.roi_uncrop_output = preprocessing_dict["uncropped_roi"] - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetRoi( - config, - roi_index=cnn_index, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - elif preprocessing_dict["mode"] == "slice": - assert isinstance(config.extraction, ExtractionSliceConfig) - config.extraction.slice_direction = SliceDirection( - str(preprocessing_dict["slice_direction"]) - ) - config.extraction.slice_mode = SliceMode(preprocessing_dict["slice_mode"]) - config.extraction.discarded_slices = compute_discarded_slices( - preprocessing_dict["discarded_slices"] - ) - config.extraction.num_slices = ( - None - if "num_slices" not in preprocessing_dict - else preprocessing_dict["num_slices"] - ) - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetSlice( - config, - slice_index=cnn_index, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - else: - raise NotImplementedError( - f"Mode {preprocessing_dict['mode']} is not implemented." 
- ) diff --git a/clinicadl/dataset/caps_dataset_config.py b/clinicadl/dataset/caps_dataset_config.py deleted file mode 100644 index 0eac3ffd3..000000000 --- a/clinicadl/dataset/caps_dataset_config.py +++ /dev/null @@ -1,127 +0,0 @@ -from pathlib import Path -from typing import Optional, Tuple, Union - -from pydantic import BaseModel, ConfigDict - -from clinicadl.dataset.config import extraction -from clinicadl.dataset.config.preprocessing import ( - CustomPreprocessingConfig, - DTIPreprocessingConfig, - FlairPreprocessingConfig, - PETPreprocessingConfig, - PreprocessingConfig, - T1PreprocessingConfig, -) -from clinicadl.dataset.data_config import DataConfig -from clinicadl.dataset.dataloader_config import DataLoaderConfig -from clinicadl.dataset.utils import ( - bids_nii, - dwi_dti, - linear_nii, - pet_linear_nii, -) -from clinicadl.transforms.config import TransformsConfig -from clinicadl.utils.enum import ExtractionMethod, Preprocessing -from clinicadl.utils.iotools.clinica_utils import FileType - - -def get_extraction(extract_method: ExtractionMethod): - if extract_method == ExtractionMethod.ROI: - return extraction.ExtractionROIConfig - elif extract_method == ExtractionMethod.SLICE: - return extraction.ExtractionSliceConfig - elif extract_method == ExtractionMethod.IMAGE: - return extraction.ExtractionImageConfig - elif extract_method == ExtractionMethod.PATCH: - return extraction.ExtractionPatchConfig - else: - raise ValueError(f"Preprocessing {extract_method.value} is not implemented.") - - -def get_preprocessing(preprocessing_type: Preprocessing): - if preprocessing_type == Preprocessing.T1_LINEAR: - return T1PreprocessingConfig - elif preprocessing_type == Preprocessing.PET_LINEAR: - return PETPreprocessingConfig - elif preprocessing_type == Preprocessing.FLAIR_LINEAR: - return FlairPreprocessingConfig - elif preprocessing_type == Preprocessing.CUSTOM: - return CustomPreprocessingConfig - elif preprocessing_type == Preprocessing.DWI_DTI: - return DTIPreprocessingConfig - else: - raise ValueError( - f"Preprocessing {preprocessing_type.value} is not implemented." - ) - - -class CapsDatasetConfig(BaseModel): - """Config class for CapsDataset object. - - caps_directory, preprocessing_json, extract_method, preprocessing - are arguments that must be passed by the user. 
- - transforms isn't optional because there is always at least one transform (NanRemoval) - """ - - data: DataConfig - dataloader: DataLoaderConfig - extraction: extraction.ExtractionConfig - preprocessing: PreprocessingConfig - transforms: TransformsConfig - - # pydantic config - model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) - - @classmethod - def from_preprocessing_and_extraction_method( - cls, - preprocessing_type: Union[str, Preprocessing], - extraction: Union[str, ExtractionMethod], - **kwargs, - ): - return cls( - data=DataConfig(**kwargs), - dataloader=DataLoaderConfig(**kwargs), - preprocessing=get_preprocessing(Preprocessing(preprocessing_type))( - **kwargs - ), - extraction=get_extraction(ExtractionMethod(extraction))(**kwargs), - transforms=TransformsConfig(**kwargs), - ) - - def compute_folder_and_file_type( - self, from_bids: Optional[Path] = None - ) -> Tuple[str, FileType]: - preprocessing = self.preprocessing.preprocessing - if from_bids is not None: - if isinstance(self.preprocessing, CustomPreprocessingConfig): - mod_subfolder = Preprocessing.CUSTOM.value - file_type = FileType( - pattern=f"*{self.preprocessing.custom_suffix}", - description="Custom suffix", - ) - else: - mod_subfolder = preprocessing - file_type = bids_nii(self.preprocessing) - - elif preprocessing not in Preprocessing: - raise NotImplementedError( - f"Extraction of preprocessing {preprocessing} is not implemented from CAPS directory." - ) - else: - mod_subfolder = preprocessing.value.replace("-", "_") - if isinstance(self.preprocessing, T1PreprocessingConfig) or isinstance( - self.preprocessing, FlairPreprocessingConfig - ): - file_type = linear_nii(self.preprocessing) - elif isinstance(self.preprocessing, PETPreprocessingConfig): - file_type = pet_linear_nii(self.preprocessing) - elif isinstance(self.preprocessing, DTIPreprocessingConfig): - file_type = dwi_dti(self.preprocessing) - elif isinstance(self.preprocessing, CustomPreprocessingConfig): - file_type = FileType( - pattern=f"*{self.preprocessing.custom_suffix}", - description="Custom suffix", - ) - return mod_subfolder, file_type diff --git a/clinicadl/dataset/caps_dataset_utils.py b/clinicadl/dataset/caps_dataset_utils.py deleted file mode 100644 index b54ba373d..000000000 --- a/clinicadl/dataset/caps_dataset_utils.py +++ /dev/null @@ -1,193 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.config.preprocessing import ( - CustomPreprocessingConfig, - DTIPreprocessingConfig, - FlairPreprocessingConfig, - PETPreprocessingConfig, - T1PreprocessingConfig, -) -from clinicadl.dataset.utils import ( - bids_nii, - dwi_dti, - linear_nii, - pet_linear_nii, -) -from clinicadl.utils.enum import Preprocessing -from clinicadl.utils.exceptions import ClinicaDLArgumentError -from clinicadl.utils.iotools.clinica_utils import FileType - - -def compute_folder_and_file_type( - config: CapsDatasetConfig, from_bids: Optional[Path] = None -) -> Tuple[str, FileType]: - preprocessing = config.preprocessing.preprocessing - if from_bids is not None: - if isinstance(config.preprocessing, CustomPreprocessingConfig): - mod_subfolder = Preprocessing.CUSTOM.value - file_type = FileType( - pattern=f"*{config.preprocessing.custom_suffix}", - description="Custom suffix", - ) - else: - mod_subfolder = preprocessing - file_type = bids_nii(config.preprocessing) - - elif preprocessing not in Preprocessing: - 
raise NotImplementedError( - f"Extraction of preprocessing {preprocessing} is not implemented from CAPS directory." - ) - else: - mod_subfolder = preprocessing.value.replace("-", "_") - if isinstance(config.preprocessing, T1PreprocessingConfig) or isinstance( - config.preprocessing, FlairPreprocessingConfig - ): - file_type = linear_nii(config.preprocessing) - elif isinstance(config.preprocessing, PETPreprocessingConfig): - file_type = pet_linear_nii(config.preprocessing) - elif isinstance(config.preprocessing, DTIPreprocessingConfig): - file_type = dwi_dti(config.preprocessing) - elif isinstance(config.preprocessing, CustomPreprocessingConfig): - file_type = FileType( - pattern=f"*{config.preprocessing.custom_suffix}", - description="Custom suffix", - ) - return mod_subfolder, file_type - - -def find_file_type(config: CapsDatasetConfig) -> FileType: - if isinstance(config.preprocessing, T1PreprocessingConfig): - file_type = linear_nii(config.preprocessing) - elif isinstance(config.preprocessing, PETPreprocessingConfig): - if ( - config.preprocessing.tracer is None - or config.preprocessing.suvr_reference_region is None - ): - raise ClinicaDLArgumentError( - "`tracer` and `suvr_reference_region` must be defined " - "when using `pet-linear` preprocessing." - ) - file_type = pet_linear_nii(config.preprocessing) - else: - raise NotImplementedError( - f"Generation of synthetic data is not implemented for preprocessing {config.preprocessing.preprocessing.value}" - ) - - return file_type - - -def read_json(json_path: Path) -> Dict[str, Any]: - """ - Ensures retro-compatibility between the different versions of ClinicaDL. - - Parameters - ---------- - json_path: Path - path to the JSON file summing the parameters of a MAPS. - - Returns - ------- - A dictionary of training parameters. 
- """ - from clinicadl.utils.iotools.utils import path_decoder - - with json_path.open(mode="r") as f: - parameters = json.load(f, object_hook=path_decoder) - # Types of retro-compatibility - # Change arg name: ex network --> model - # Change arg value: ex for preprocessing: mni --> t1-extensive - # New arg with default hard-coded value --> discarded_slice --> 20 - retro_change_name = { - "model": "architecture", - "multi": "multi_network", - "minmaxnormalization": "normalize", - "num_workers": "n_proc", - "mode": "extract_method", - } - - retro_add = { - "optimizer": "Adam", - "loss": None, - } - - for old_name, new_name in retro_change_name.items(): - if old_name in parameters: - parameters[new_name] = parameters[old_name] - del parameters[old_name] - - for name, value in retro_add.items(): - if name not in parameters: - parameters[name] = value - - if "extract_method" in parameters: - parameters["mode"] = parameters["extract_method"] - # Value changes - if "use_cpu" in parameters: - parameters["gpu"] = not parameters["use_cpu"] - del parameters["use_cpu"] - if "nondeterministic" in parameters: - parameters["deterministic"] = not parameters["nondeterministic"] - del parameters["nondeterministic"] - - # Build preprocessing_dict - if "preprocessing_dict" not in parameters: - parameters["preprocessing_dict"] = {"mode": parameters["mode"]} - preprocessing_options = [ - "preprocessing", - "use_uncropped_image", - "prepare_dl", - "custom_suffix", - "tracer", - "suvr_reference_region", - "patch_size", - "stride_size", - "slice_direction", - "slice_mode", - "discarded_slices", - "roi_list", - "uncropped_roi", - "roi_custom_suffix", - "roi_custom_template", - "roi_custom_mask_pattern", - ] - for preprocessing_var in preprocessing_options: - if preprocessing_var in parameters: - parameters["preprocessing_dict"][preprocessing_var] = parameters[ - preprocessing_var - ] - del parameters[preprocessing_var] - - # Add missing parameters in previous version of extract - if "use_uncropped_image" not in parameters["preprocessing_dict"]: - parameters["preprocessing_dict"]["use_uncropped_image"] = False - - if ( - "prepare_dl" not in parameters["preprocessing_dict"] - and parameters["mode"] != "image" - ): - parameters["preprocessing_dict"]["prepare_dl"] = False - - if ( - parameters["mode"] == "slice" - and "slice_mode" not in parameters["preprocessing_dict"] - ): - parameters["preprocessing_dict"]["slice_mode"] = "rgb" - - if "preprocessing" not in parameters: - parameters["preprocessing"] = parameters["preprocessing_dict"]["preprocessing"] - - from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig - - config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - extraction=parameters["mode"], - preprocessing_type=parameters["preprocessing"], - **parameters, - ) - if "file_type" not in parameters["preprocessing_dict"]: - _, file_type = compute_folder_and_file_type(config) - parameters["preprocessing_dict"]["file_type"] = file_type.model_dump() - - return parameters diff --git a/clinicadl/dataset/caps_reader.py b/clinicadl/dataset/caps_reader.py deleted file mode 100644 index 80435401a..000000000 --- a/clinicadl/dataset/caps_reader.py +++ /dev/null @@ -1,62 +0,0 @@ -from pathlib import Path -from typing import Optional - -from clinicadl.dataset.caps_dataset import CapsDataset -from clinicadl.dataset.config.extraction import ( - ExtractionConfig, - ExtractionImageConfig, - ExtractionPatchConfig, - ExtractionROIConfig, - ExtractionSliceConfig, -) -from 
clinicadl.dataset.config.preprocessing import PreprocessingConfig -from clinicadl.experiment_manager.experiment_manager import ExperimentManager -from clinicadl.transforms.config import TransformsConfig - - -class CapsReader: - def __init__(self, caps_directory: Path): - """TO COMPLETE""" - pass - - def get_dataset( - self, - extraction: ExtractionConfig, - preprocessing: PreprocessingConfig, - sub_ses_tsv: Path, - transforms: TransformsConfig, - ) -> CapsDataset: - return CapsDataset(extraction, preprocessing, sub_ses_tsv, transforms) - - def get_preprocessing(self, preprocessing: str) -> PreprocessingConfig: - """TO COMPLETE""" - - return PreprocessingConfig() - - def extract_slice( - self, preprocessing: PreprocessingConfig, arg_slice: Optional[int] = None - ) -> ExtractionSliceConfig: - """TO COMPLETE""" - - return ExtractionSliceConfig() - - def extract_patch( - self, preprocessing: PreprocessingConfig, arg_patch: Optional[int] = None - ) -> ExtractionPatchConfig: - """TO COMPLETE""" - - return ExtractionPatchConfig() - - def extract_roi( - self, preprocessing: PreprocessingConfig, arg_roi: Optional[int] = None - ) -> ExtractionROIConfig: - """TO COMPLETE""" - - return ExtractionROIConfig() - - def extract_image( - self, preprocessing: PreprocessingConfig, arg_image: Optional[int] = None - ) -> ExtractionImageConfig: - """TO COMPLETE""" - - return ExtractionImageConfig() diff --git a/clinicadl/dataset/concat.py b/clinicadl/dataset/concat.py deleted file mode 100644 index f0b420dfe..000000000 --- a/clinicadl/dataset/concat.py +++ /dev/null @@ -1,6 +0,0 @@ -from clinicadl.dataset.caps_dataset import CapsDataset - - -class ConcatDataset(CapsDataset): - def __init__(self, list_: list[CapsDataset]): - """TO COMPLETE""" diff --git a/clinicadl/dataset/config/__init__.py b/clinicadl/dataset/config/__init__.py index e69de29bb..f1e6c253f 100644 --- a/clinicadl/dataset/config/__init__.py +++ b/clinicadl/dataset/config/__init__.py @@ -0,0 +1,13 @@ +from .file_type import FileType +from .preprocessing import ( + PreprocessingConfig, + PreprocessingCustom, + PreprocessingFlair, + PreprocessingPET, + PreprocessingT1, + PreprocessingT2, +) +from .utils import ( + get_extraction, + get_preprocessing, +) diff --git a/clinicadl/dataset/config/data.py b/clinicadl/dataset/config/data.py new file mode 100644 index 000000000..b0f7b758e --- /dev/null +++ b/clinicadl/dataset/config/data.py @@ -0,0 +1,77 @@ +from logging import getLogger +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import pandas as pd +from pydantic import field_validator + +from clinicadl.utils.config import ClinicaDLConfig + +# from clinicadl.dataset.utils import load_data_test +from clinicadl.utils.exceptions import ( + ClinicaDLArgumentError, + ClinicaDLTSVError, +) + +logger = getLogger("clinicadl.data_config") + + +# TODO: check if this file is still useful + + +class DataConfig(ClinicaDLConfig): # TODO : put in data module + """Config class to specify the data. + + caps_directory and preprocessing_json are arguments + that must be passed by the user. 
+    """
+
+    caps_directory: Optional[Path] = None
+    baseline: bool = False
+    mask_path: Optional[Path] = None
+    data_tsv: Optional[Path] = None
+    n_subjects: int = 300
+
+    def create_groupe_df(self):
+        group_df = None
+        # if self.data_tsv is not None and self.data_tsv.is_file():
+        #     group_df = load_data_test(
+        #         self.data_tsv,
+        #         multi_cohort=False,
+        #     )
+        return group_df
+
+    def is_given_label_code(self, _label: str, _label_code: Union[str, Dict[str, int]]):
+        return (
+            self.label is not None
+            and self.label != ""
+            and self.label != _label
+            and _label_code == "default"
+        )
+
+    def check_label(self, _label: str):
+        if not self.label:
+            self.label = _label
+
+    @field_validator("data_tsv", mode="before")
+    @classmethod
+    def check_data_tsv(cls, v) -> Path:
+        if v is not None:
+            if not isinstance(v, Path):
+                v = Path(v)
+            if not v.is_file():
+                raise ClinicaDLTSVError(
+                    "The participants_list you gave is not a file. Please give an existing file."
+                )
+            if v.stat().st_size == 0:
+                raise ClinicaDLTSVError(
+                    "The participants_list you gave is empty. Please give a non-empty file."
+                )
+        return v
diff --git a/clinicadl/dataset/config/extraction.py b/clinicadl/dataset/config/extraction.py
deleted file mode 100644
index f3619590f..000000000
--- a/clinicadl/dataset/config/extraction.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from logging import getLogger
-from time import time
-from typing import List, Optional, Tuple
-
-from pydantic import BaseModel, ConfigDict, field_validator
-from pydantic.types import NonNegativeInt
-
-from clinicadl.utils.enum import (
-    ExtractionMethod,
-    SliceDirection,
-    SliceMode,
-)
-from clinicadl.utils.iotools.clinica_utils import FileType
-
-logger = getLogger("clinicadl.preprocessing_config")
-
-
-class ExtractionConfig(BaseModel):
-    """
-    Abstract config class for the Extraction procedure.
-    """
-
-    extract_method: ExtractionMethod
-    file_type: Optional[FileType] = None
-    save_features: bool = False
-    extract_json: Optional[str] = None
-
-    # pydantic config
-    model_config = ConfigDict(validate_assignment=True)
-
-    @field_validator("extract_json", mode="before")
-    def compute_extract_json(cls, v: str):
-        if v is None:
-            return f"extract_{int(time())}.json"
-        elif not v.endswith(".json"):
-            return f"{v}.json"
-        else:
-            return v
-
-
-class ExtractionImageConfig(ExtractionConfig):
-    extract_method: ExtractionMethod = ExtractionMethod.IMAGE
-
-
-class ExtractionPatchConfig(ExtractionConfig):
-    patch_size: int = 50
-    stride_size: int = 50
-    extract_method: ExtractionMethod = ExtractionMethod.PATCH
-
-
-class ExtractionSliceConfig(ExtractionConfig):
-    slice_direction: SliceDirection = SliceDirection.SAGITTAL
-    slice_mode: SliceMode = SliceMode.RGB
-    num_slices: Optional[NonNegativeInt] = None
-    discarded_slices: Tuple[NonNegativeInt, NonNegativeInt] = (0, 0)
-    extract_method: ExtractionMethod = ExtractionMethod.SLICE
-
-
-class ExtractionROIConfig(ExtractionConfig):
-    roi_list: List[str] = []
-    roi_uncrop_output: bool = False
-    roi_custom_template: str = ""
-    roi_custom_pattern: str = ""
-    roi_custom_suffix: str = ""
-    roi_custom_mask_pattern: str = ""
-    roi_background_value: int = 0
-    extract_method: ExtractionMethod = ExtractionMethod.ROI
diff --git a/clinicadl/dataset/config/file_type.py b/clinicadl/dataset/config/file_type.py
new file mode 100644
index 000000000..8a1249ca6
--- /dev/null
+++ b/clinicadl/dataset/config/file_type.py
@@ -0,0 +1,47 @@
+from enum import Enum
+from typing import Optional
+
+from pydantic import field_validator
+
+from clinicadl.utils.config import ClinicaDLConfig
+from clinicadl.utils.enum import Preprocessing
+
+
+class FileType(ClinicaDLConfig):
+    """
+    Represents a file type with a pattern, description, and optional pipeline requirement.
+    """
+
+    pattern: str
+    description: str
+    needed_pipeline: Optional[str] = None
+
+    @field_validator("pattern", mode="before")
+    def check_pattern(cls, v):
+        if not v:
+            raise ValueError("A pattern must be specified")
+
+        elif v[0] == "/":
+            raise ValueError(
+                "pattern argument cannot start with char: / (does not work in os.path.join function). "
+                "If you want to indicate the exact name of the file, use the format "
+                "directory_name/filename.extension or filename.extension in the pattern argument."
+            )
+        return v
+
+    @field_validator("description", mode="before")
+    def check_description(cls, v):
+        if not v:
+            raise ValueError("A description must be specified")
+        return v
+
+    @field_validator("needed_pipeline", mode="before")
+    def check_needed_pipeline(cls, v):
+        if v:
+            try:
+                v = Preprocessing(v)
+            except ValueError:
+                raise ValueError(
+                    f"Invalid pipeline: {v}. 
Choose from {[e.value for e in Preprocessing]}" + ) + return v diff --git a/clinicadl/dataset/config/preprocessing.py b/clinicadl/dataset/config/preprocessing.py index ad8db765e..5889ca92c 100644 --- a/clinicadl/dataset/config/preprocessing.py +++ b/clinicadl/dataset/config/preprocessing.py @@ -1,57 +1,221 @@ +import abc from logging import getLogger from pathlib import Path -from typing import Optional +from typing import Optional, Tuple, Union -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, computed_field, field_validator from clinicadl.utils.enum import ( DTIMeasure, DTISpace, + ImageModality, + LinearModality, Preprocessing, SUVRReferenceRegions, Tracer, ) +from clinicadl.utils.iotools.clinica_utils import FileType logger = getLogger("clinicadl.modality_config") -class PreprocessingConfig(BaseModel): +class PreprocessingConfig(BaseModel, abc.ABC): """ Abstract config class for the preprocessing procedure. """ - tsv_file: Optional[Path] = None preprocessing: Preprocessing use_uncropped_image: bool = False - # pydantic config - model_config = ConfigDict(validate_assignment=True) - + model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) + + def get_filetype(self, bids: bool = False) -> FileType: + return self.get_bids_filetype() if bids else self.get_caps_filetype() + + @abc.abstractmethod + def get_bids_filetype(self, reconstruction: Optional[str] = None) -> FileType: + """Abstract method to get the BIDS filetype.""" + pass + + @abc.abstractmethod + def get_caps_filetype(self) -> FileType: + """Abstract method to obtain FileType details.""" + pass + + @computed_field + @property + def file_type(self) -> FileType: + if self.preprocessing not in Preprocessing: + raise NotImplementedError( + f"Extraction of preprocessing {self.preprocessing.value} is not implemented from CAPS directory." 
+            )
+        else:
+            return self.get_filetype()
+
+    def linear_nii(
+        self, modality: LinearModality, needed_pipeline: Preprocessing
+    ) -> FileType:
+        """
+        Constructs the file type for linear caps image data
+        """
+        desc_crop = "" if self.use_uncropped_image else "_desc-Crop"
+
+        file_type = FileType(
+            pattern=f"{self.preprocessing.value.replace('-', '_')}/*space-MNI152NLin2009cSym{desc_crop}_res-1x1x1_{modality.value}.nii.gz",
+            description=f"{modality.value} Image registered in MNI152NLin2009cSym space using {needed_pipeline.value} pipeline "
+            + (
+                ""
+                if self.use_uncropped_image
+                else "and cropped (matrix size 169×208×179, 1 mm isotropic voxels)"
+            ),
+            needed_pipeline=needed_pipeline,
+        )
+        return file_type
+
+
+class PreprocessingPET(PreprocessingConfig):
+    """
+    Configuration for PET image preprocessing
+    """
-class PETPreprocessingConfig(PreprocessingConfig):
     tracer: Tracer = Tracer.FFDG
     suvr_reference_region: SUVRReferenceRegions = SUVRReferenceRegions.CEREBELLUMPONS2
     preprocessing: Preprocessing = Preprocessing.PET_LINEAR
+    @field_validator("tracer", mode="before")
+    def check_tracer(cls, v: Union[str, Tracer]):
+        return Tracer(v)
+
+    @field_validator("suvr_reference_region", mode="before")
+    def check_suvr_reference_region(cls, v: Union[str, SUVRReferenceRegions]):
+        return SUVRReferenceRegions(v)
+
+    def get_bids_filetype(self, reconstruction: Optional[str] = None) -> FileType:
+        trc, rec, description = "", "", "PET data"
+        if self.tracer:
+            description += f" with {self.tracer.value} tracer"
+            trc = f"_trc-{self.tracer.value}"
+        if reconstruction:
+            description += f" and reconstruction method {reconstruction}"
+            rec = f"_rec-{reconstruction}"
+
+        return FileType(pattern=f"pet/*{trc}{rec}_pet.nii*", description=description)
+
+    def get_caps_filetype(self) -> FileType:
+        des_crop = "" if self.use_uncropped_image else "_desc-Crop"
+
+        return FileType(
+            pattern=f"pet_linear/*_trc-{self.tracer.value}_space-MNI152NLin2009cSym{des_crop}_res-1x1x1_suvr-{self.suvr_reference_region.value}_pet.nii.gz",
+            description=f"PET image with tracer {self.tracer.value} registered in MNI152NLin2009cSym space, obtained with the pet-linear pipeline",
+            needed_pipeline="pet-linear",
+        )
+
+    def __str__(self):
+        return f"Preprocessing of {'uncropped' if self.use_uncropped_image else 'cropped'} PET images with tracer {self.tracer.value} and suvr reference region {self.suvr_reference_region.value}. "
+
+
+class PreprocessingCustom(PreprocessingConfig):
+    """
+    Configuration for custom preprocessing with a user-defined suffix. 
+    """
-class CustomPreprocessingConfig(PreprocessingConfig):
     custom_suffix: str = ""
     preprocessing: Preprocessing = Preprocessing.CUSTOM
+    def get_bids_filetype(self, reconstruction: Optional[str] = None) -> FileType:
+        return FileType(
+            pattern=f"*{self.custom_suffix}",
+            description="Custom suffix",
+        )
+
+    def get_caps_filetype(self) -> FileType:
+        return FileType(
+            pattern=f"custom/*{self.custom_suffix}",
+            description="Custom suffix",
+        )
+
+    def __str__(self):
+        return f"Preprocessing of {'uncropped' if self.use_uncropped_image else 'cropped'} custom images with suffix {self.custom_suffix} "
+
+
+class PreprocessingDTI(PreprocessingConfig):
+    """
+    Configuration for DTI-based preprocessing
+    """
-class DTIPreprocessingConfig(PreprocessingConfig):
     dti_measure: DTIMeasure = DTIMeasure.FRACTIONAL_ANISOTROPY
     dti_space: DTISpace = DTISpace.ALL
     preprocessing: Preprocessing = Preprocessing.DWI_DTI
+    def get_bids_filetype(self, reconstruction: Optional[str] = None) -> FileType:
+        return FileType(pattern="dwi/sub-*_ses-*_dwi.nii*", description="DWI NIfTI")
+
+    def get_caps_filetype(self) -> FileType:
+        """Return the file type required to capture DWI DTI images.
+
+        Returns
+        -------
+        FileType
+        """
+        measure = self.dti_measure
+        space = self.dti_space
-class T1PreprocessingConfig(PreprocessingConfig):
+        return FileType(
+            pattern=f"dwi/dti_based_processing/*/*_space-{space}_{measure.value}.nii.gz",
+            description=f"DTI-based {measure.value} in space {space}.",
+            needed_pipeline="dwi_dti",
+        )
+
+    def __str__(self):
+        return f"Preprocessing of {'uncropped' if self.use_uncropped_image else 'cropped'} DTI images with measure {self.dti_measure.value} and space {self.dti_space.value}. "
+
+
+class PreprocessingT1(PreprocessingConfig):
     preprocessing: Preprocessing = Preprocessing.T1_LINEAR
+    def get_bids_filetype(self, reconstruction: Optional[str] = None) -> FileType:
+        return FileType(pattern="anat/sub-*_ses-*_T1w.nii*", description="T1w MRI")
+
+    def get_caps_filetype(self) -> FileType:
+        return self.linear_nii(
+            modality=LinearModality.T1W, needed_pipeline=Preprocessing.T1_LINEAR
+        )
-class FlairPreprocessingConfig(PreprocessingConfig):
+    def __str__(self):
+        return f"Preprocessing of {'uncropped' if self.use_uncropped_image else 'cropped'} T1 images with t1-linear pipeline"
+
+
+class PreprocessingFlair(PreprocessingConfig):
     preprocessing: Preprocessing = Preprocessing.FLAIR_LINEAR
+    def get_bids_filetype(self, reconstruction: Optional[str] = None) -> FileType:
+        return FileType(pattern="sub-*_ses-*_flair.nii*", description="FLAIR T2w MRI")
+
+    def get_caps_filetype(self) -> FileType:
+        return self.linear_nii(
+            modality=LinearModality.FLAIR, needed_pipeline=Preprocessing.FLAIR_LINEAR
+        )
+
+    def __str__(self):
+        return f"Preprocessing of {'uncropped' if self.use_uncropped_image else 'cropped'} Flair images with flair-linear pipeline"
-class T2PreprocessingConfig(PreprocessingConfig):
+
+class PreprocessingT2(PreprocessingConfig):
     preprocessing: Preprocessing = Preprocessing.T2_LINEAR
+
+    def get_bids_filetype(self, reconstruction: Optional[str] = None) -> FileType:
+        raise NotImplementedError(
+            f"Extraction of preprocessing {self.preprocessing.value} is not implemented from BIDS directory." 
 
-class T2PreprocessingConfig(PreprocessingConfig):
+
+class PreprocessingT2(PreprocessingConfig):
     preprocessing: Preprocessing = Preprocessing.T2_LINEAR
 
+    def get_bids_filetype(self, reconstruction: Optional[str] = None) -> FileType:
+        raise NotImplementedError(
+            f"Extraction of preprocessing {self.preprocessing.value} is not implemented from BIDS directory."
+        )
+
+    def get_caps_filetype(self) -> FileType:
+        return self.linear_nii(
+            modality=LinearModality.T2W, needed_pipeline=Preprocessing.T2_LINEAR
+        )
+
+    def __str__(self):
+        return f"Preprocessing of {'uncropped' if self.use_uncropped_image else 'cropped'} T2 images with t2-linear pipeline"
diff --git a/clinicadl/dataset/data_config.py b/clinicadl/dataset/data_config.py
deleted file mode 100644
index 39e6a6254..000000000
--- a/clinicadl/dataset/data_config.py
+++ /dev/null
@@ -1,164 +0,0 @@
-from logging import getLogger
-from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
-
-import pandas as pd
-from pydantic import BaseModel, ConfigDict, computed_field, field_validator
-
-from clinicadl.utils.enum import Mode
-from clinicadl.utils.exceptions import (
-    ClinicaDLArgumentError,
-    ClinicaDLTSVError,
-)
-from clinicadl.utils.iotools.clinica_utils import check_caps_folder
-from clinicadl.utils.iotools.data_utils import check_multi_cohort_tsv, load_data_test
-from clinicadl.utils.iotools.utils import read_preprocessing
-
-logger = getLogger("clinicadl.data_config")
-
-
-class DataConfig(BaseModel):  # TODO : put in data module
-    """Config class to specify the data.
-
-    caps_directory and preprocessing_json are arguments
-    that must be passed by the user.
-    """
-
-    caps_directory: Optional[Path] = None
-    baseline: bool = False
-    diagnoses: Tuple[str, ...] = ("AD", "CN")
-    data_df: Optional[pd.DataFrame] = None
-    label: Optional[str] = None
-    label_code: Union[str, Dict[str, int], None] = {}
-    multi_cohort: bool = False
-    mask_path: Optional[Path] = None
-    preprocessing_json: Optional[Path] = None
-    data_tsv: Optional[Path] = None
-    n_subjects: int = 300
-    # pydantic config
-    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
-
-    @field_validator("diagnoses", mode="before")
-    def validator_diagnoses(cls, v):
-        """Transforms a list to a tuple."""
-        if isinstance(v, list):
-            return tuple(v)
-        return v  # TODO : check if columns are in tsv
-
-    def create_groupe_df(self):
-        group_df = None
-        if self.data_tsv is not None and self.data_tsv.is_file():
-            group_df = load_data_test(
-                self.data_tsv,
-                self.diagnoses,
-                multi_cohort=self.multi_cohort,
-            )
-        return group_df
-
-    def is_given_label_code(self, _label: str, _label_code: Union[str, Dict[str, int]]):
-        return (
-            self.label is not None
-            and self.label != ""
-            and self.label != _label
-            and _label_code == "default"
-        )
-
-    def check_label(self, _label: str):
-        if not self.label:
-            self.label = _label
-
-    @field_validator("data_tsv", mode="before")
-    @classmethod
-    def check_data_tsv(cls, v) -> Path:
-        if v is not None:
-            if not isinstance(v, Path):
-                v = Path(v)
-            if not v.is_file():
-                raise ClinicaDLTSVError(
-                    "The participants_list you gave is not a file. Please give an existing file."
-                )
-            if v.stat().st_size == 0:
-                raise ClinicaDLTSVError(
-                    "The participants_list you gave is empty. Please give a non-empty file."
-                )
-        return v
-
-    @computed_field
-    @property
-    def caps_dict(self) -> Dict[str, Path]:
-        if self.multi_cohort:
-            if self.caps_directory.suffix != ".tsv":
-                raise ClinicaDLArgumentError(
-                    "If multi_cohort is True, the CAPS_DIRECTORY argument should be a path to a TSV file."
-                )
-            else:
-                caps_df = pd.read_csv(self.caps_directory, sep="\t")
-                check_multi_cohort_tsv(caps_df, "CAPS")
-                caps_dict = dict()
-                for idx in range(len(caps_df)):
-                    cohort = caps_df.loc[idx, "cohort"]
-                    caps_path = Path(caps_df.at[idx, "path"])
-                    check_caps_folder(caps_path)
-                    caps_dict[cohort] = caps_path
-        else:
-            check_caps_folder(self.caps_directory)
-            caps_dict = {"single": self.caps_directory}
-
-        return caps_dict
-
-    @computed_field
-    @property
-    def preprocessing_dict(self) -> Dict[str, Any]:
-        """
-        Gets the preprocessing dictionary from a preprocessing json file.
-
-        Returns
-        -------
-        Dict[str, Any]
-            The preprocessing dictionary.
-
-        Raises
-        ------
-        ValueError
-            In case of multi-cohort dataset, if no preprocessing file is found in any CAPS.
-        """
-
-        if self.preprocessing_json is not None:
-            if not self.multi_cohort:
-                preprocessing_json = (
-                    self.caps_directory / "tensor_extraction" / self.preprocessing_json
-                )
-            else:
-                caps_dict = self.caps_dict
-                json_found = False
-                for caps_name, caps_path in caps_dict.items():
-                    preprocessing_json = (
-                        caps_path / "tensor_extraction" / self.preprocessing_json
-                    )
-                    if preprocessing_json.is_file():
-                        logger.info(
-                            f"Preprocessing JSON {preprocessing_json} found in CAPS {caps_name}."
-                        )
-                        json_found = True
-                if not json_found:
-                    raise ValueError(
-                        f"Preprocessing JSON {self.preprocessing_json} was not found for any CAPS "
-                        f"in {caps_dict}."
-                    )
-
-            preprocessing_dict = read_preprocessing(preprocessing_json)
-
-            if (
-                preprocessing_dict["mode"] == "roi"
-                and "roi_background_value" not in preprocessing_dict
-            ):
-                preprocessing_dict["roi_background_value"] = 0
-
-            return preprocessing_dict
-        else:
-            return None
-
-    @computed_field
-    @property
-    def mode(self) -> Mode:
-        return Mode(self.preprocessing_dict["mode"])
diff --git a/clinicadl/dataset/dataloader_config.py b/clinicadl/dataset/dataloader_config.py
deleted file mode 100644
index cc01ba9a9..000000000
--- a/clinicadl/dataset/dataloader_config.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from logging import getLogger
-
-from pydantic import BaseModel, ConfigDict
-from pydantic.types import PositiveInt
-
-from clinicadl.utils.enum import Sampler
-
-logger = getLogger("clinicadl.dataloader_config")
-
-
-class DataLoaderConfig(BaseModel):  # TODO : put in data/splitter module
-    """Config class to configure the DataLoader."""
-
-    batch_size: PositiveInt = 8
-    n_proc: PositiveInt = 2
-    sampler: Sampler = Sampler.RANDOM
-    # pydantic config
-    model_config = ConfigDict(validate_assignment=True)
diff --git a/clinicadl/dataset/prepare_data/__init__.py b/clinicadl/dataset/datasets/___init__.py
similarity index 100%
rename from clinicadl/dataset/prepare_data/__init__.py
rename to clinicadl/dataset/datasets/___init__.py
diff --git a/clinicadl/dataset/datasets/caps_dataset.py b/clinicadl/dataset/datasets/caps_dataset.py
new file mode 100644
index 000000000..4e3a9c9c1
--- /dev/null
+++ b/clinicadl/dataset/datasets/caps_dataset.py
@@ -0,0 +1,530 @@
+# coding: utf8
+from logging import getLogger
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import nibabel as nib
+import pandas as pd
+import torch
+from joblib import Parallel, delayed
+from pydantic import NonNegativeInt, PositiveInt
+from torch import save as save_tensor
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+from clinicadl.dataset.config.preprocessing import PreprocessingConfig
+from clinicadl.dataset.readers.caps_reader import CapsReader
+from clinicadl.transforms.extraction import Image
+from clinicadl.dataset.utils import (
+    CapsDatasetSample,
+    check_df,
+    get_infos_from_json,
+    tsv_to_df,
+)
+from clinicadl.transforms.transforms import Transforms
+from clinicadl.utils.exceptions import (
+    ClinicaDLCAPSError,
+    ClinicaDLConfigurationError,
+    ClinicaDLTSVError,
+)
+from clinicadl.utils.iotools.clinica_utils import create_subs_sess_list
+
+logger = getLogger("clinicadl.caps_dataset")
+
+PARTICIPANT_ID = "participant_id"
+SESSION_ID = "session_id"
+
+
+class CapsDataset(Dataset):
+    """
+    CapsDataset is a custom PyTorch Dataset class for working with neuroimaging data in CAPS format.
+
+    The dataset supports preprocessing, data augmentation, extraction of specific image
+    features (e.g., slices, patches, ROIs), and parallelized preparation of tensor files.
+
+    Attributes
+    ----------
+    caps_reader: CapsReader
+        Reader object for handling CAPS directories.
+    preprocessing: PreprocessingConfig
+        Configuration of the preprocessing applied to the data.
+    transforms: Transforms
+        Transformation pipeline to apply to the data.
+    df: pd.DataFrame
+        DataFrame containing participant/session information.
+    elem_per_image: int
+        Number of elements per image, determined by the extraction mode.
+    eval_mode: bool
+        Flag indicating whether the dataset is in evaluation mode.
+    """
+
+    def __init__(
+        self,
+        caps_directory: Path,
+        preprocessing: PreprocessingConfig,
+        transforms: Transforms,
+        data: Optional[Union[pd.DataFrame, Path]] = None,
+    ):
+        """
+        Initializes the CapsDataset.
+
+        Parameters
+        ----------
+        caps_directory : Path
+            Path to the CAPS directory containing the neuroimaging data.
+        preprocessing : PreprocessingConfig
+            Configuration for the preprocessing steps applied to the data.
+        transforms : Transforms
+            Transformation pipeline to apply to the data during loading.
+        data : Union[pd.DataFrame, Path], optional
+            Data source, either a TSV file or a pre-loaded DataFrame with participant/session information.
+        """
+
+        self.eval_mode = False
+        self.caps_reader = CapsReader(caps_directory)
+        self.preprocessing = preprocessing
+        self.transforms = transforms
+        self.extraction = transforms.extraction
+        self.df = self._get_df_from_input(data)
+
+        # self.size = self[0].elem.size()
+
+    @property
+    def elem_per_image(self):
+        """
+        Returns the number of elements per image based on the extraction mode.
+
+        The value is determined by extracting the first image in the dataset and checking how many
+        elements are present in that image according to the extraction method.
+
+        Returns
+        -------
+        int
+            Number of elements per image.
+        """
+        if not hasattr(self, "_elem_per_image"):
+            self._elem_per_image = self.extraction.num_elem_per_image(
+                image=self._get_full_image()[0]
+            )
+        return self._elem_per_image
+
+    @classmethod
+    def from_json(cls, json_path: Path):
+        """
+        Creates a CapsDataset instance from a JSON configuration file.
+
+        This method loads the preprocessing configuration, transformation pipeline, CAPS directory,
+        and data source (TSV or DataFrame) from the provided JSON file, and returns an instance
+        of the CapsDataset.
+
+        Parameters
+        ----------
+        json_path : Path
+            Path to the JSON file containing the necessary configuration for creating the dataset.
+
+        Returns
+        -------
+        CapsDataset
+            The initialized CapsDataset instance.
+
+        Raises
+        ------
+        FileNotFoundError
+            If the provided JSON file does not exist.
+        """
+
+        if not json_path.is_file():
+            raise FileNotFoundError(
+                f"The provided preprocessing JSON file {json_path} does not exist."
+            )
+
+        preprocessing, transforms, caps_dir, data_tsv = get_infos_from_json(json_path)
+        return cls(
+            caps_dir,
+            preprocessing,
+            transforms,
+            data_tsv,
+        )
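+
+    # Usage sketch (the JSON path below is hypothetical):
+    #
+    #     dataset = CapsDataset.from_json(Path("caps/tensor_extraction/extract.json"))
+    #     dataset.describe()
+    #
+    # Rebuilding from JSON is meant to reproduce the exact preprocessing and
+    # transforms configuration that was serialized at extraction time.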
+
+    def describe(self):
+        """To complete/merge later with the dataset_description from clinica"""
+        return {
+            "total_samples": len(self),
+            "elem_per_image": self.elem_per_image,
+            "participants": self.df[PARTICIPANT_ID].nunique(),
+            "sessions": self.df[SESSION_ID].nunique(),
+            "preprocessing": self.preprocessing.model_dump(),
+            "extraction": self.extraction.model_dump(),
+        }
+
+    def _get_df_from_input(
+        self, data: Optional[Union[pd.DataFrame, Path]] = None
+    ) -> pd.DataFrame:
+        """
+        Generates or validates the DataFrame from the input data.
+
+        Parameters
+        ----------
+        data : Union[pd.DataFrame, Path], optional
+            Path to the TSV file or a DataFrame containing participant/session pairs.
+
+        Returns
+        -------
+        pd.DataFrame
+            Validated DataFrame containing participant/session information.
+
+        Raises
+        ------
+        ClinicaDLTSVError
+            If the provided TSV file does not exist, or if the input is neither a TSV file nor a DataFrame.
+        ClinicaDLCAPSError
+            If the data does not match the preprocessing configuration.
+        """
+
+        if data is None:
+            data = create_subs_sess_list(
+                self.caps_reader.input_directory, self.caps_reader.input_directory
+            )
+            logger.info(f"Creating a subject session TSV file at {data}")
+
+        if isinstance(data, Path):
+            if not data.is_file():
+                raise ClinicaDLTSVError(
+                    f"The data file does not exist: {data}. "
+                    "Please ensure the file path is correct and accessible."
+                )
+            df = tsv_to_df(data)
+        elif isinstance(data, pd.DataFrame):
+            df = check_df(data)
+        else:
+            raise ClinicaDLTSVError(
+                f"'data' must be a pandas DataFrame or a path to a TSV file, got {type(data)}."
+            )
+
+        self.df = df
+        if not self._check_preprocessing_config():
+            raise ClinicaDLCAPSError(
+                f"The DataFrame does not match the preprocessing configuration: {self.preprocessing.preprocessing.value}"
+            )
+
+        return df
+
+    def _check_preprocessing_config(self) -> bool:
+        """
+        Validates that the preprocessing configuration matches the data.
+
+        Returns
+        -------
+        bool
+            True if the configuration is valid, otherwise raises an error.
+
+        Raises
+        ------
+        ClinicaDLConfigurationError
+            If the preprocessing configuration does not match the data.
+        """
+        pattern = self.preprocessing.file_type.pattern
+        for participant, session in self._get_participants_sessions_couple():
+            folder = self.caps_reader.get_session_path(
+                participant=participant, session=session
+            )
+            if not list(folder.glob(pattern)):
+                raise ClinicaDLConfigurationError(
+                    f"Could not find preprocessing {self.preprocessing.preprocessing.value} for participant "
+                    f"{participant} and session {session} with pattern: {pattern}"
+                )
+        return True
+
+    def __len__(self) -> int:
+        """
+        Computes the total number of samples in the dataset.
+
+        Returns
+        -------
+        int
+            Total number of elements in the dataset.
+        """
+        return len(self.df) * self.elem_per_image
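+
+    # Worked example of the index arithmetic used by _get_meta_data below (derived
+    # from the code, not an API guarantee): with len(df) == 2 and elem_per_image == 5,
+    # len(dataset) == 10, and idx == 7 maps to img_idx == 7 // 5 == 1 and
+    # elem_idx == 7 % 5 == 2, i.e. the third element of the second image.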
+
+    def _get_meta_data(
+        self, idx: NonNegativeInt
+    ) -> Tuple[str, str, NonNegativeInt, NonNegativeInt]:
+        """
+        Retrieves metadata for a given sample index.
+
+        Parameters
+        ----------
+        idx : NonNegativeInt
+            Index of the sample.
+
+        Returns
+        -------
+        tuple
+            - participant (str): ID of the participant.
+            - session (str): ID of the session.
+            - img_index (NonNegativeInt): Index of the image.
+            - elem_index (NonNegativeInt): Index of the extracted element.
+
+        Raises
+        ------
+        IndexError
+            If the index is out of range.
+        """
+        if idx >= len(self):
+            raise IndexError(
+                f"Index out of range, there are only {len(self)} elements in your dataset."
+            )
+
+        img_idx = idx // self.elem_per_image
+        elem_idx = idx % self.elem_per_image
+
+        participant = self._get_participant(img_idx)
+        session = self._get_session(img_idx)
+
+        return participant, session, img_idx, elem_idx
+
+    def _get_participant(self, idx: NonNegativeInt) -> str:
+        """
+        Retrieves the participant ID for a given row index.
+
+        Parameters
+        ----------
+        idx : NonNegativeInt
+            Row index.
+
+        Returns
+        -------
+        str
+            Participant ID.
+        """
+        return self.df.at[idx, PARTICIPANT_ID]
+
+    def _get_session(self, idx: NonNegativeInt) -> str:
+        """
+        Retrieves the session ID for a given row index.
+
+        Parameters
+        ----------
+        idx : NonNegativeInt
+            Row index.
+
+        Returns
+        -------
+        str
+            Session ID.
+        """
+        return self.df.at[idx, SESSION_ID]
+
+    def _get_participants_sessions_couple(self) -> List[Tuple[str, str]]:
+        """
+        Retrieves all participant-session pairs in the dataset.
+
+        Returns
+        -------
+        List[Tuple[str, str]]
+            A list of tuples where each tuple contains a participant ID and a session ID.
+        """
+        return list(zip(self.df[PARTICIPANT_ID], self.df[SESSION_ID]))
+
+    def _get_full_image(
+        self, idx: NonNegativeInt = 0, weights_only: bool = True
+    ) -> Tuple[torch.Tensor, Path]:
+        """
+        Retrieves the full image tensor and its path for a given index.
+
+        Parameters
+        ----------
+        idx : NonNegativeInt, optional
+            Index of the image (default is 0).
+        weights_only : bool, optional
+            If True, only the tensor's data weights are loaded (default is True).
+
+        Returns
+        -------
+        tuple
+            A tuple containing:
+            - torch.Tensor: The full image tensor.
+            - Path: The path to the image file.
+
+        Raises
+        ------
+        FileNotFoundError
+            If the image file does not exist in the CAPS directory.
+        """
+
+        participant_id = self._get_participant(idx)
+        session_id = self._get_session(idx)
+
+        image_path = self.caps_reader.get_tensor_path(
+            participant_id, session_id, self.preprocessing
+        )
+        if image_path.is_file():
+            image = torch.load(image_path, weights_only=weights_only)
+        else:
+            image_path = self.caps_reader.get_image_path(
+                participant_id, session_id, self.preprocessing
+            )
+            image_nii = nib.loadsave.load(image_path)  # type: ignore
+            image_np = image_nii.get_fdata()  # type: ignore
+            image = torch.from_numpy(image_np).unsqueeze(0).float()
+
+        return image, image_path
+ """ + + if not isinstance(idx, int) or idx < 0: + raise ValueError(f"Index must be a non-negative integer, got {idx}.") + + participant, session, img_index, elem_index = self._get_meta_data(idx) + image, image_path = self._get_full_image(img_index, True) + + ( + image_trf, + object_trf, + image_augmentation, + object_augmentation, + ) = self.transforms.get_transforms() + + image = image_trf(image) + + if image_augmentation and not self.eval_mode: + image = image_augmentation(image) + + if not isinstance(self.extraction, Image): + tensor = self.transforms.extraction.extract_tensor( + image, + elem_index, + ) + if object_trf: + tensor = object_trf(tensor) + + if object_augmentation and not self.eval_mode: + tensor = object_augmentation(tensor) + + out = tensor + + else: + out = image + + sample = CapsDatasetSample( + elem=out, + # label=label, + participant_id=participant, + session_id=session, + img_idx=img_index, + elem_idx=elem_index, + image_path=image_path, + mode=self.extraction.extract_method, + ) + + return sample + + def eval(self): + """ + Sets the dataset to evaluation mode. + + This disables data augmentation in the transformation pipeline. + + Returns + ------- + CapsDataset + The dataset instance with evaluation mode enabled. + """ + self.eval_mode = True + return self + + def train(self): + """ + Sets the dataset to training mode. + + This enables data augmentation in the transformation pipeline. + + Returns + ------- + CapsDataset + The dataset instance with training mode enabled. + """ + self.eval_mode = False + return self + + def prepare_data( + self, + n_proc: PositiveInt = 2, + use_uncropped_images: bool = False, + ): + """ + Prepares tensor files from the neuroimaging data. + + This method processes the raw neuroimaging data (NIfTI format) into PyTorch tensors + and stores them for faster data loading during training and evaluation. + + Parameters + ---------- + n_proc : PositiveInt, optional + Number of processes to use for parallelization (default is 2). + use_uncropped_images : bool, optional + Whether to use uncropped images during preprocessing (default is False). + + Notes + ----- + - If the tensor file for a participant/session already exists, it will not be reprocessed. + - This method saves tensor files and image statistics (mean, std, min, max) for each image. + """ + + def prepare_image(participant, session): + image_path = self.caps_reader.get_image_path( + participant, session, self.preprocessing + ) + output_file_dir = self.caps_reader.get_tensor_dir( + participant, session, preprocessing=self.preprocessing + ) + + output_file_dir.mkdir(parents=True, exist_ok=True) + output_file = output_file_dir / Path(image_path).name.replace( + ".nii.gz", ".pt" + ) + + if output_file.is_file(): + logger.info( + f"The file '{output_file}' already exists, the tensor has already been extracted." 
+
+    def prepare_data(
+        self,
+        n_proc: PositiveInt = 2,
+        use_uncropped_images: bool = False,
+    ):
+        """
+        Prepares tensor files from the neuroimaging data.
+
+        This method processes the raw neuroimaging data (NIfTI format) into PyTorch tensors
+        and stores them for faster data loading during training and evaluation.
+
+        Parameters
+        ----------
+        n_proc : PositiveInt, optional
+            Number of processes to use for parallelization (default is 2).
+        use_uncropped_images : bool, optional
+            Whether to use uncropped images during preprocessing (default is False;
+            currently not used by this method).
+
+        Notes
+        -----
+        - If the tensor file for a participant/session already exists, it will not be reprocessed.
+        - This method saves tensor files and image statistics (mean, std, min, max) for each image.
+        """
+
+        def prepare_image(participant, session):
+            image_path = self.caps_reader.get_image_path(
+                participant, session, self.preprocessing
+            )
+            output_file_dir = self.caps_reader.get_tensor_dir(
+                participant, session, preprocessing=self.preprocessing
+            )
+
+            output_file_dir.mkdir(parents=True, exist_ok=True)
+            output_file = output_file_dir / Path(image_path).name.replace(
+                ".nii.gz", ".pt"
+            )
+
+            if output_file.is_file():
+                logger.info(
+                    f"The file '{output_file}' already exists, the tensor has already been extracted."
+                )
+            else:
+                logger.debug(f"Processing of {image_path}.")
+                image_array = nib.loadsave.load(image_path).get_fdata(dtype="float32")  # type: ignore
+
+                # get some important information about the image
+                info_df = pd.DataFrame(
+                    [
+                        {
+                            "mean": image_array.mean(),
+                            "std": image_array.std(),
+                            "max": image_array.max(),
+                            "min": image_array.min(),
+                        }
+                    ]
+                )
+                info_df.to_csv(
+                    output_file_dir / "image_info.tsv", sep="\t", index=False
+                )
+
+                # extract and save the image tensor
+                image_tensor = torch.from_numpy(image_array).unsqueeze(0).float()
+                save_tensor(image_tensor.clone(), output_file)
+                logger.debug(f"Output tensor saved at {output_file}")
+
+        Parallel(n_jobs=n_proc)(
+            delayed(prepare_image)(participant, session)
+            for participant, session in tqdm(
+                self._get_participants_sessions_couple(), desc="Preparing data"
+            )
+        )
diff --git a/clinicadl/dataset/datasets/concat.py b/clinicadl/dataset/datasets/concat.py
new file mode 100644
index 000000000..80e748dac
--- /dev/null
+++ b/clinicadl/dataset/datasets/concat.py
@@ -0,0 +1,51 @@
+# coding: utf8
+from logging import getLogger
+from typing import List
+
+from clinicadl.dataset.datasets.caps_dataset import CapsDataset
+from clinicadl.dataset.utils import CapsDatasetSample
+
+logger = getLogger("clinicadl")
+
+
+class ConcatDataset(CapsDataset):
+    def __init__(self, datasets: List[CapsDataset]):
+        self._datasets = datasets
+        self._len = sum(len(dataset) for dataset in datasets)
+        self._indexes = []
+
+        # Calculate the distribution of indexes over all datasets
+        cumulative_index = 0
+        for idx, dataset in enumerate(datasets):
+            next_cumulative_index = cumulative_index + len(dataset)
+            self._indexes.append((cumulative_index, next_cumulative_index, idx))
+            cumulative_index = next_cumulative_index
+
+        logger.debug(f"Datasets summary length: {self._len}")
+        logger.debug(f"Datasets indexes: {self._indexes}")
+
+        self.check_extraction()
+
+        self.eval_mode = False
+
+    def __getitem__(self, index: int) -> CapsDatasetSample:
+        for start, stop, dataset_index in self._indexes:
+            if start <= index < stop:
+                dataset = self._datasets[dataset_index]
+                return dataset[index - start]
+        raise IndexError(
+            f"Index out of range, there are only {self._len} elements in this dataset."
+        )
+
+    def __len__(self) -> int:
+        return self._len
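+
+    # Worked example (follows directly from the ranges computed in __init__): with
+    # len(ds_a) == 4 and len(ds_b) == 6, _indexes == [(0, 4, 0), (4, 10, 1)], so
+    # ConcatDataset([ds_a, ds_b])[5] returns ds_b[1] and len(concat) == 10.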
+
+    def check_extraction(self):
+        extractions = [d.extraction for d in self._datasets]
+        if all(
+            i == extractions[0] for i in extractions
+        ):  # check that all the CapsDataset have the same extraction method
+            self.extraction = extractions[0]
+        else:
+            raise AttributeError(
+                "All the CapsDataset must have the same extraction method: 'image', 'patch', 'roi', 'slice', etc."
+            )
diff --git a/clinicadl/dataset/prepare_data/prepare_data.py b/clinicadl/dataset/prepare_data/prepare_data.py
deleted file mode 100644
index e702bb066..000000000
--- a/clinicadl/dataset/prepare_data/prepare_data.py
+++ /dev/null
@@ -1,230 +0,0 @@
-from logging import getLogger
-from pathlib import Path
-from typing import Optional
-
-from joblib import Parallel, delayed
-from torch import save as save_tensor
-
-from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig
-from clinicadl.dataset.caps_dataset_utils import compute_folder_and_file_type
-from clinicadl.dataset.config.extraction import (
-    ExtractionConfig,
-    ExtractionImageConfig,
-    ExtractionPatchConfig,
-    ExtractionROIConfig,
-    ExtractionSliceConfig,
-)
-from clinicadl.utils.enum import ExtractionMethod, Pattern, Preprocessing, Template
-from clinicadl.utils.exceptions import ClinicaDLArgumentError
-from clinicadl.utils.iotools.clinica_utils import (
-    check_caps_folder,
-    clinicadl_file_reader,
-    container_from_filename,
-    determine_caps_or_bids,
-    get_subject_session_list,
-)
-from clinicadl.utils.iotools.utils import write_preprocessing
-
-from .prepare_data_utils import check_mask_list
-
-
-def DeepLearningPrepareData(
-    config: CapsDatasetConfig, from_bids: Optional[Path] = None
-):
-    logger = getLogger("clinicadl.prepare_data")
-    # Get subject and session list
-    if from_bids is not None:
-        try:
-            input_directory = Path(from_bids)
-        except ClinicaDLArgumentError:
-            logger.warning("Your BIDS directory doesn't exist.")
-        logger.debug(f"BIDS directory: {input_directory}.")
-        is_bids_dir = True
-    else:
-        input_directory = config.data.caps_directory
-        check_caps_folder(input_directory)
-        logger.debug(f"CAPS directory: {input_directory}.")
-        is_bids_dir = False
-
-    subjects, sessions = get_subject_session_list(
-        input_directory, config.data.data_tsv, is_bids_dir, False, None
-    )
-
-    if config.extraction.save_features:
-        logger.info(
-            f"{config.extraction.extract_method.value}s will be extracted in Pytorch tensor from {len(sessions)} images."
-        )
-    else:
-        logger.info(
-            f"Images will be extracted in Pytorch tensor from {len(sessions)} images."
-        )
-        logger.info(
-            f"Information for {config.extraction.extract_method.value} will be saved in output JSON file and will be used "
-            f"during training for on-the-fly extraction."
-        )
-    logger.debug(f"List of subjects: \n{subjects}.")
-    logger.debug(f"List of sessions: \n{sessions}.")
-
-    # Select the correct filetype corresponding to modality
-    # and select the right folder output name corresponding to modality
-    logger.debug(
-        f"Selected images are preprocessed with {config.preprocessing} pipeline`."
-    )
-
-    mod_subfolder, file_type = compute_folder_and_file_type(config, from_bids)
-
-    # Input file:
-    input_files = clinicadl_file_reader(subjects, sessions, input_directory, file_type)[
-        0
-    ]
-    logger.debug(f"Selected image file name list: {input_files}.")
-
-    def write_output_imgs(output_mode, container, subfolder):
-        # Write the extracted tensor on a .pt file
-        for filename, tensor in output_mode:
-            output_file_dir = (
-                config.data.caps_directory
-                / container
-                / "deeplearning_prepare_data"
-                / subfolder
-                / mod_subfolder
-            )
-            output_file_dir.mkdir(parents=True, exist_ok=True)
-            output_file = output_file_dir / filename
-            save_tensor(tensor, output_file)
-            logger.debug(f"Output tensor saved at {output_file}")
-
-    if (
-        config.extraction.extract_method == ExtractionMethod.IMAGE
-        or not config.extraction.save_features
-    ):
-
-        def prepare_image(file):
-            from .prepare_data_utils import extract_images
-
-            logger.debug(f"Processing of {file}.")
-            container = container_from_filename(file)
-            subfolder = "image_based"
-            output_mode = extract_images(Path(file))
-            logger.debug("Image extracted.")
-            write_output_imgs(output_mode, container, subfolder)
-
-        Parallel(n_jobs=config.dataloader.n_proc)(
-            delayed(prepare_image)(file) for file in input_files
-        )
-
-    elif config.extraction.save_features:
-        if config.extraction.extract_method == ExtractionMethod.SLICE:
-            assert isinstance(config.extraction, ExtractionSliceConfig)
-
-            def prepare_slice(file):
-                from .prepare_data_utils import extract_slices
-
-                assert isinstance(config.extraction, ExtractionSliceConfig)
-                logger.debug(f"  Processing of {file}.")
-                container = container_from_filename(file)
-                subfolder = "slice_based"
-                output_mode = extract_slices(
-                    Path(file),
-                    slice_direction=config.extraction.slice_direction,
-                    slice_mode=config.extraction.slice_mode,
-                    discarded_slices=config.extraction.discarded_slices,
-                )
-                logger.debug(f"  {len(output_mode)} slices extracted.")
-                write_output_imgs(output_mode, container, subfolder)
-
-            Parallel(n_jobs=config.dataloader.n_proc)(
-                delayed(prepare_slice)(file) for file in input_files
-            )
-
-        elif config.extraction.extract_method == ExtractionMethod.PATCH:
-            assert isinstance(config.extraction, ExtractionPatchConfig)
-
-            def prepare_patch(file):
-                from .prepare_data_utils import extract_patches
-
-                assert isinstance(config.extraction, ExtractionPatchConfig)
-                logger.debug(f"  Processing of {file}.")
-                container = container_from_filename(file)
-                subfolder = "patch_based"
-                output_mode = extract_patches(
-                    Path(file),
-                    patch_size=config.extraction.patch_size,
-                    stride_size=config.extraction.stride_size,
-                )
-                logger.debug(f"  {len(output_mode)} patches extracted.")
-                write_output_imgs(output_mode, container, subfolder)
-
-            Parallel(n_jobs=config.dataloader.n_proc)(
-                delayed(prepare_patch)(file) for file in input_files
-            )
-
-        elif config.extraction.extract_method == ExtractionMethod.ROI:
-            assert isinstance(config.extraction, ExtractionROIConfig)
-
-            def prepare_roi(file):
-                from .prepare_data_utils import extract_roi
-
-                assert isinstance(config.extraction, ExtractionROIConfig)
-                logger.debug(f"  Processing of {file}.")
-                container = container_from_filename(file)
-                subfolder = "roi_based"
-                if config.preprocessing == Preprocessing.CUSTOM:
-                    if not config.extraction.roi_custom_template:
-                        raise ClinicaDLArgumentError(
-                            "A custom template must be defined when the modality is set to custom."
-                        )
-                    roi_template = config.extraction.roi_custom_template
-                    roi_mask_pattern = config.extraction.roi_custom_mask_pattern
-                else:
-                    if config.preprocessing.preprocessing == Preprocessing.T1_LINEAR:
-                        roi_template = Template.T1_LINEAR
-                        roi_mask_pattern = Pattern.T1_LINEAR
-                    elif config.preprocessing.preprocessing == Preprocessing.PET_LINEAR:
-                        roi_template = Template.PET_LINEAR
-                        roi_mask_pattern = Pattern.PET_LINEAR
-                    elif (
-                        config.preprocessing.preprocessing == Preprocessing.FLAIR_LINEAR
-                    ):
-                        roi_template = Template.FLAIR_LINEAR
-                        roi_mask_pattern = Pattern.FLAIR_LINEAR
-
-                masks_location = input_directory / "masks" / f"tpl-{roi_template}"
-
-                if len(config.extraction.roi_list) == 0:
-                    raise ClinicaDLArgumentError(
-                        "A list of regions of interest must be given."
-                    )
-                else:
-                    check_mask_list(
-                        masks_location,
-                        config.extraction.roi_list,
-                        roi_mask_pattern,
-                        config.preprocessing.use_uncropped_image,
-                    )
-
-                output_mode = extract_roi(
-                    Path(file),
-                    masks_location=masks_location,
-                    mask_pattern=roi_mask_pattern,
-                    cropped_input=not config.preprocessing.use_uncropped_image,
-                    roi_names=config.extraction.roi_list,
-                    uncrop_output=config.extraction.roi_uncrop_output,
-                )
-                logger.debug("ROI extracted.")
-                write_output_imgs(output_mode, container, subfolder)
-
-            Parallel(n_jobs=config.dataloader.n_proc)(
-                delayed(prepare_roi)(file) for file in input_files
-            )
-
-        else:
-            raise NotImplementedError(
-                f"Extraction is not implemented for mode {config.extraction.extract_method.value}."
-            )
-
-    # Save parameters dictionary
-    preprocessing_json_path = write_preprocessing(
-        config.extraction.model_dump(), config.data.caps_directory
-    )
-    logger.info(f"Preprocessing JSON saved at {preprocessing_json_path}.")
diff --git a/clinicadl/dataset/prepare_data/prepare_data_utils.py b/clinicadl/dataset/prepare_data/prepare_data_utils.py
deleted file mode 100644
index 0acd2ec25..000000000
--- a/clinicadl/dataset/prepare_data/prepare_data_utils.py
+++ /dev/null
@@ -1,442 +0,0 @@
-# coding: utf8
-from pathlib import Path
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-
-from clinicadl.utils.enum import SliceDirection, SliceMode
-
-
-############
-#  SLICE   #
-############
-def compute_discarded_slices(discarded_slices: Union[int, tuple]) -> Tuple[int, int]:
-    if isinstance(discarded_slices, int):
-        begin_discard, end_discard = discarded_slices, discarded_slices
-    elif len(discarded_slices) == 1:
-        begin_discard, end_discard = discarded_slices[0], discarded_slices[0]
-    elif len(discarded_slices) == 2:
-        begin_discard, end_discard = discarded_slices[0], discarded_slices[1]
-    else:
-        raise IndexError(
-            f"Maximum two number of discarded slices can be defined. "
-            f"You gave discarded slices = {discarded_slices}."
-        )
-    return begin_discard, end_discard
-
-
-def extract_slices(
-    nii_path: Path,
-    slice_direction: SliceDirection = SliceDirection.SAGITTAL,
-    slice_mode: SliceMode = SliceMode.SINGLE,
-    discarded_slices: Union[int, tuple] = 0,
-) -> List[Tuple[str, torch.Tensor]]:
-    """Extracts the slices from three directions
-    This function extracts slices form the preprocessed nifti image.
-
-    The direction of extraction can be defined either on sagittal direction (0),
-    coronal direction (1) or axial direction (other).
-
-    The output slices can be stored following two modes:
-    single (1 channel) or rgb (3 channels, all the same).
-
-    Args:
-        nii_path: path to the NifTi input image.
-        slice_direction: along which axis slices are extracted.
-        slice_mode: 'single' or 'rgb'.
-        discarded_slices: Number of slices to discard at the beginning and the end of the image.
-            Will be a tuple of two integers if the number of slices to discard at the beginning
-            and at the end differ.
-    Returns:
-        list of tuples containing the path to the extracted slice
-        and the tensor of the corresponding slice.
-    """
-    import nibabel as nib
-
-    image_array = nib.loadsave.load(nii_path).get_fdata(dtype="float32")
-    image_tensor = torch.from_numpy(image_array).unsqueeze(0).float()
-
-    begin_discard, end_discard = compute_discarded_slices(discarded_slices)
-    index_list = range(
-        begin_discard, image_tensor.shape[int(slice_direction.value) + 1] - end_discard
-    )
-
-    slice_list = []
-    for slice_index in index_list:
-        slice_tensor = extract_slice_tensor(
-            image_tensor, slice_direction, slice_mode, slice_index
-        )
-        slice_path = extract_slice_path(
-            nii_path, slice_direction, slice_mode, slice_index
-        )
-
-        slice_list.append((slice_path, slice_tensor))
-
-    return slice_list
-
-
-def extract_slice_tensor(
-    image_tensor: torch.Tensor,
-    slice_direction: SliceDirection,
-    slice_mode: SliceMode,
-    slice_index: int,
-) -> torch.Tensor:
-    # Allow to select the slice `slice_index` in dimension `slice_direction`
-    idx_tuple = tuple(
-        [slice(None)] * (int(slice_direction.value) + 1)
-        + [slice_index]
-        + [slice(None)] * (2 - int(slice_direction.value))
-    )
-    slice_tensor = image_tensor[idx_tuple]  # shape is 1 * W * L
-
-    if slice_mode == "rgb":
-        slice_tensor = torch.cat(
-            (slice_tensor, slice_tensor, slice_tensor)
-        )  # shape is 3 * W * L
-
-    return slice_tensor.clone()
-
-
-def extract_slice_path(
-    img_path: Path,
-    slice_direction: SliceDirection,
-    slice_mode: SliceMode,
-    slice_index: int,
-) -> str:
-    slice_dict = {0: "sag", 1: "cor", 2: "axi"}
-    input_img_filename = img_path.name
-    txt_idx = input_img_filename.rfind("_")
-    it_filename_prefix = input_img_filename[0:txt_idx]
-    it_filename_suffix = input_img_filename[txt_idx:]
-    it_filename_suffix = it_filename_suffix.replace(".nii.gz", ".pt")
-    return (
-        f"{it_filename_prefix}_axis-{slice_dict[int(slice_direction.value)]}"
-        f"_channel-{slice_mode.value}_slice-{slice_index}{it_filename_suffix}"
-    )
-
-
-############
-#  PATCH   #
-############
-def extract_patches(
-    nii_path: Path,
-    patch_size: int,
-    stride_size: int,
-) -> List[Tuple[str, torch.Tensor]]:
-    """Extracts the patches
-    This function extracts patches form the preprocessed nifti image. Patch size
-    if provided as input and also the stride size. If stride size is smaller
-    than the patch size an overlap exist between consecutive patches. If stride
-    size is equal to path size there is no overlap. Otherwise, unprocessed
-    zones can exits.
-    Args:
-        nii_path: path to the NifTi input image.
-        patch_size: size of a single patch.
-        stride_size: size of the stride leading to next patch.
-    Returns:
-        list of tuples containing the path to the extracted patch
-        and the tensor of the corresponding patch.
- """ - import nibabel as nib - - image_array = nib.loadsave.load(nii_path).get_fdata(dtype="float32") - image_tensor = torch.from_numpy(image_array).unsqueeze(0).float() - - patches_tensor = ( - image_tensor.unfold(1, patch_size, stride_size) - .unfold(2, patch_size, stride_size) - .unfold(3, patch_size, stride_size) - .contiguous() - ) - patches_tensor = patches_tensor.view(-1, patch_size, patch_size, patch_size) - - patch_list = [] - for patch_index in range(patches_tensor.shape[0]): - patch_tensor = extract_patch_tensor( - image_tensor, patch_size, stride_size, patch_index, patches_tensor - ) - patch_path = extract_patch_path(nii_path, patch_size, stride_size, patch_index) - - patch_list.append((patch_path, patch_tensor)) - - return patch_list - - -def extract_patch_tensor( - image_tensor: torch.Tensor, - patch_size: int, - stride_size: int, - patch_index: int, - patches_tensor: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """Extracts a single patch from image_tensor""" - - if patches_tensor is None: - patches_tensor = ( - image_tensor.unfold(1, patch_size, stride_size) - .unfold(2, patch_size, stride_size) - .unfold(3, patch_size, stride_size) - .contiguous() - ) - - # the dimension of patches_tensor is [1, patch_num1, patch_num2, patch_num3, patch_size1, patch_size2, patch_size3] - patches_tensor = patches_tensor.view(-1, patch_size, patch_size, patch_size) - - return patches_tensor[patch_index, ...].unsqueeze_(0).clone() - - -def extract_patch_path( - img_path: Path, patch_size: int, stride_size: int, patch_index: int -) -> str: - input_img_filename = img_path.name - txt_idx = input_img_filename.rfind("_") - it_filename_prefix = input_img_filename[0:txt_idx] - it_filename_suffix = input_img_filename[txt_idx:] - it_filename_suffix = it_filename_suffix.replace(".nii.gz", ".pt") - - return f"{it_filename_prefix}_patchsize-{patch_size}_stride-{stride_size}_patch-{patch_index}{it_filename_suffix}" - - -############ -# IMAGE # -############ -def extract_images(input_img: Path) -> List[Tuple[str, torch.Tensor]]: - """Extract the images - This function convert nifti image to tensor (.pt) version of the image. - Tensor version is saved at the same location than input_img. - Args: - input_img: path to the NifTi input image. - Returns: - filename (str): single tensor file saved on the disk. Same location than input file. - """ - import nibabel as nib - import torch - - image_array = nib.loadsave.load(input_img).get_fdata(dtype="float32") - image_tensor = torch.from_numpy(image_array).unsqueeze(0).float() - # make sure the tensor type is torch.float32 - output_file = ( - Path(input_img.name.replace(".nii.gz", ".pt")), - image_tensor.clone(), - ) - - return [output_file] - - -############ -# ROI # -############ -def check_mask_list( - masks_location: Path, roi_list: List[str], mask_pattern: str, cropping: bool -) -> None: - import nibabel as nib - import numpy as np - - for roi in roi_list: - roi_path, desc = find_mask_path(masks_location, roi, mask_pattern, cropping) - if roi_path is None: - raise FileNotFoundError( - f"The ROI '{roi}' does not correspond to a mask in the CAPS directory. {desc}" - ) - roi_mask = nib.loadsave.load(roi_path).get_fdata() - mask_values = set(np.unique(roi_mask)) - if mask_values != {0, 1}: - raise ValueError( - "The ROI masks used should be binary (composed of 0 and 1 only)." 
-            )
-
-
-def find_mask_path(
-    masks_location: Path, roi: str, mask_pattern: str, cropping: bool
-) -> Tuple[Union[None, str], str]:
-    """
-    Finds masks corresponding to the pattern asked and containing the adequate cropping description
-
-    Parameters
-    ----------
-    masks_location: Path
-        Directory containing the masks.
-    roi: str
-        Name of the region.
-    mask_pattern: str
-        Pattern which should be found in the filename of the mask.
-    cropping: bool
-        If True the original image should contain the substring 'desc-Crop'.
-
-    Returns
-    -------
-    path of the mask or None if nothing was found.
-    a human-friendly description of the pattern looked for.
-    """
-
-    # Check that pattern begins and ends with _ to avoid mixing keys
-    if mask_pattern is None:
-        mask_pattern = ""
-
-    candidates_pattern = f"*{mask_pattern}*_roi-{roi}_mask.nii*"
-
-    desc = f"The mask should follow the pattern {candidates_pattern}. "
-    candidates = [e for e in masks_location.glob(candidates_pattern)]
-    if cropping is None:
-        # pass
-        candidates2 = candidates
-    elif cropping:
-        candidates2 = [mask for mask in candidates if "_desc-Crop_" in mask.name]
-        desc += "and contain '_desc-Crop_' string."
-    else:
-        candidates2 = [mask for mask in candidates if "_desc-Crop_" not in mask.name]
-        desc += "and not contain '_desc-Crop_' string."
-
-    if len(candidates2) == 0:
-        return None, desc
-    else:
-        return min(candidates2), desc
-
-
-def compute_output_pattern(mask_path: Path, crop_output: bool):
-    """
-    Computes the output pattern of the region cropped (without the source file prefix)
-    Parameters
-    ----------
-    mask_path: Path
-        Path to the masks
-    crop_output: bool
-        If True the output is cropped, and the descriptor CropRoi must exist
-
-    Returns
-    -------
-    the output pattern
-    """
-
-    mask_filename = mask_path.name
-    template_id = mask_filename.split("_")[0].split("-")[1]
-    mask_descriptors = mask_filename.split("_")[1:-2:]
-    roi_id = mask_filename.split("_")[-2].split("-")[1]
-    if "desc-Crop" not in mask_descriptors and crop_output:
-        mask_descriptors = ["desc-CropRoi"] + mask_descriptors
-    elif "desc-Crop" in mask_descriptors:
-        mask_descriptors = [
-            descriptor for descriptor in mask_descriptors if descriptor != "desc-Crop"
-        ]
-        if crop_output:
-            mask_descriptors = ["desc-CropRoi"] + mask_descriptors
-        else:
-            mask_descriptors = ["desc-CropImage"] + mask_descriptors
-
-    mask_pattern = "_".join(mask_descriptors)
-
-    if mask_pattern == "":
-        output_pattern = f"space-{template_id}_roi-{roi_id}"
-    else:
-        output_pattern = f"space-{template_id}_{mask_pattern}_roi-{roi_id}"
-
-    return output_pattern
-
-
-def extract_roi(
-    nii_path: Path,
-    masks_location: Path,
-    mask_pattern: str,
-    cropped_input: bool,
-    roi_names: List[str],
-    uncrop_output: bool,
-) -> List[Tuple[str, torch.Tensor]]:
-    """Extracts regions of interest defined by masks
-    This function extracts regions of interest from preprocessed nifti images.
-    The regions are defined using binary masks that must be located in the CAPS
-    at `masks/tpl-