diff --git a/clinicadl/caps_dataset/caps_dataset_utils.py b/clinicadl/caps_dataset/caps_dataset_utils.py index 89e868934..40d5feca7 100644 --- a/clinicadl/caps_dataset/caps_dataset_utils.py +++ b/clinicadl/caps_dataset/caps_dataset_utils.py @@ -56,53 +56,6 @@ def read_json(json_path: Path) -> Dict[str, Any]: parameters["deterministic"] = not parameters["nondeterministic"] del parameters["nondeterministic"] - # Build preprocessing_dict - if "preprocessing_dict" not in parameters: - parameters["preprocessing_dict"] = {"mode": parameters["mode"]} - preprocessing_options = [ - "preprocessing", - "use_uncropped_image", - "prepare_dl", - "custom_suffix", - "tracer", - "suvr_reference_region", - "patch_size", - "stride_size", - "slice_direction", - "slice_mode", - "discarded_slices", - "roi_list", - "uncropped_roi", - "roi_custom_suffix", - "roi_custom_template", - "roi_custom_mask_pattern", - ] - for preprocessing_var in preprocessing_options: - if preprocessing_var in parameters: - parameters["preprocessing_dict"][preprocessing_var] = parameters[ - preprocessing_var - ] - del parameters[preprocessing_var] - - # Add missing parameters in previous version of extract - if "use_uncropped_image" not in parameters["preprocessing_dict"]: - parameters["preprocessing_dict"]["use_uncropped_image"] = False - - if ( - "prepare_dl" not in parameters["preprocessing_dict"] - and parameters["mode"] != "image" - ): - parameters["preprocessing_dict"]["prepare_dl"] = False - - if ( - parameters["mode"] == "slice" - and "slice_mode" not in parameters["preprocessing_dict"] - ): - parameters["preprocessing_dict"]["slice_mode"] = "rgb" - - if "preprocessing" not in parameters: - parameters["preprocessing"] = parameters["preprocessing_dict"]["preprocessing"] - from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig config = CapsDatasetConfig.from_preprocessing_and_extraction_method( @@ -110,8 +63,7 @@ def read_json(json_path: Path) -> Dict[str, Any]: preprocessing_type=parameters["preprocessing"], **parameters, ) - if "file_type" not in parameters["preprocessing_dict"]: - file_type = config.preprocessing.get_filetype() - parameters["preprocessing_dict"]["file_type"] = file_type.model_dump() + + file_type = config.preprocessing.get_filetype() return parameters diff --git a/clinicadl/caps_dataset/data.py b/clinicadl/caps_dataset/data.py index 84cac4212..817c638e6 100644 --- a/clinicadl/caps_dataset/data.py +++ b/clinicadl/caps_dataset/data.py @@ -12,11 +12,13 @@ from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.caps_dataset.extraction.config import ( + ExtractionConfig, ExtractionImageConfig, ExtractionPatchConfig, ExtractionROIConfig, ExtractionSliceConfig, ) +from clinicadl.caps_dataset.preprocessing.config import PreprocessingConfig from clinicadl.prepare_data.prepare_data_utils import ( compute_discarded_slices, extract_patch_path, @@ -52,13 +54,18 @@ class CapsDataset(Dataset): def __init__( self, config: CapsDatasetConfig, + preprocessing: PreprocessingConfig, + extraction: ExtractionConfig, + transforms: TransformsConfig, label_presence: bool, - preprocessing_dict: Dict[str, Any], ): self.label_presence = label_presence self.eval_mode = False self.config = config - self.preprocessing_dict = preprocessing_dict + + self.preprocessing = preprocessing + self.extraction = extraction + self.transforms = transforms if not hasattr(self, "elem_index"): raise AttributeError( @@ -282,14 +289,15 @@ class CapsDatasetImage(CapsDataset): def __init__( self, config: 
CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], + preprocessing: PreprocessingConfig, + extraction: ExtractionImageConfig, + transforms: TransformsConfig, label_presence: bool = True, ): """ Args: caps_directory: Directory of all the images. data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. train_transformations: Optional transform to be applied only on training mode. label_presence: If True the diagnosis will be extracted from the given DataFrame. label: Name of the column in data_df containing the label. @@ -301,11 +309,18 @@ def __init__( self.mode = "image" self.config = config + + self.preprocessing = preprocessing + self.extraction = extraction + self.transforms = transforms + self.label_presence = label_presence super().__init__( config=config, + preprocessing=preprocessing, + extraction=extraction, + transforms=transforms, label_presence=label_presence, - preprocessing_dict=preprocessing_dict, ) @property @@ -344,25 +359,33 @@ class CapsDatasetPatch(CapsDataset): def __init__( self, config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], + preprocessing: PreprocessingConfig, + extraction: ExtractionPatchConfig, + transforms: TransformsConfig, patch_index: Optional[int] = None, label_presence: bool = True, ): """ caps_directory: Directory of all the images. data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. train_transformations: Optional transform to be applied only on training mode. """ self.patch_index = patch_index self.mode = "patch" self.config = config + + self.preprocessing = preprocessing + self.extraction = extraction + self.transforms = transforms + self.label_presence = label_presence super().__init__( config=config, + preprocessing=preprocessing, + extraction=extraction, + transforms=transforms, label_presence=label_presence, - preprocessing_dict=preprocessing_dict, ) @property @@ -375,14 +398,14 @@ def __getitem__(self, idx): ) image_path = self._get_image_path(participant, session, cohort) - if self.config.extraction.save_features: + if self.extraction.save_features: patch_dir = image_path.parent.as_posix().replace( "image_based", f"{self.mode}_based" ) patch_filename = extract_patch_path( image_path, - self.config.extraction.patch_size, - self.config.extraction.stride_size, + self.extraction.patch_size, + self.extraction.stride_size, patch_idx, ) patch_tensor = torch.load( @@ -393,8 +416,8 @@ def __getitem__(self, idx): image = torch.load(image_path, weights_only=True) patch_tensor = extract_patch_tensor( image, - self.config.extraction.patch_size, - self.config.extraction.stride_size, + self.extraction.patch_size, + self.extraction.stride_size, patch_idx, ) @@ -423,26 +446,26 @@ def num_elem_per_image(self): patches_tensor = ( image.unfold( 1, - self.config.extraction.patch_size, - self.config.extraction.stride_size, + self.extraction.patch_size, + self.extraction.stride_size, ) .unfold( 2, - self.config.extraction.patch_size, - self.config.extraction.stride_size, + self.extraction.patch_size, + self.extraction.stride_size, ) .unfold( 3, - self.config.extraction.patch_size, - self.config.extraction.stride_size, + self.extraction.patch_size, + self.extraction.stride_size, ) .contiguous() ) patches_tensor = patches_tensor.view( -1, - self.config.extraction.patch_size, - self.config.extraction.patch_size, - 
self.config.extraction.patch_size, + self.extraction.patch_size, + self.extraction.patch_size, + self.extraction.patch_size, ) num_patches = patches_tensor.shape[0] return num_patches @@ -452,7 +475,9 @@ class CapsDatasetRoi(CapsDataset): def __init__( self, config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], + preprocessing: PreprocessingConfig, + extraction: ExtractionROIConfig, + transforms: TransformsConfig, roi_index: Optional[int] = None, label_presence: bool = True, ): @@ -460,7 +485,6 @@ def __init__( Args: caps_directory: Directory of all the images. data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. roi_index: If a value is given the same region will be extracted for each image. else the dataset will load all the regions possible for one image. train_transformations: Optional transform to be applied only on training mode. @@ -474,14 +498,21 @@ def __init__( self.roi_index = roi_index self.mode = "roi" self.config = config + + self.preprocessing = preprocessing + self.extraction = extraction + self.transforms = transforms + self.label_presence = label_presence self.mask_paths, self.mask_arrays = self._get_mask_paths_and_tensors( - self.config.data.caps_directory, preprocessing_dict + self.config.data.caps_directory ) super().__init__( config=config, + preprocessing=preprocessing, + extraction=extraction, + transforms=transforms, label_presence=label_presence, - preprocessing_dict=preprocessing_dict, ) @property @@ -492,19 +523,19 @@ def __getitem__(self, idx): participant, session, cohort, roi_idx, label, domain = self._get_meta_data(idx) image_path = self._get_image_path(participant, session, cohort) - if self.config.extraction.roi_list is None: + if self.extraction.roi_list is None: raise NotImplementedError( "Default regions are not available anymore in ClinicaDL. " "Please define appropriate masks and give a roi_list." 
            )

-        if self.config.extraction.save_features:
+        if self.extraction.save_features:
             mask_path = self.mask_paths[roi_idx]
             roi_dir = image_path.parent.as_posix().replace(
                 "image_based", f"{self.mode}_based"
             )
             roi_filename = extract_roi_path(
-                image_path, mask_path, self.config.extraction.roi_uncrop_output
+                image_path, mask_path, self.extraction.roi_uncrop_output
             )
             roi_tensor = torch.load(Path(roi_dir) / roi_filename, weights_only=True)
@@ -512,7 +543,7 @@ def __getitem__(self, idx):
             image = torch.load(image_path, weights_only=True)
             mask_array = self.mask_arrays[roi_idx]
             roi_tensor = extract_roi_tensor(
-                image, mask_array, self.config.extraction.uncropped_roi
+                image, mask_array, self.extraction.roi_uncrop_output
             )

         train_trf, trf = self.config.transforms.get_transforms()
@@ -535,15 +566,14 @@ def __getitem__(self, idx):
     def num_elem_per_image(self):
         if self.elem_index is not None:
             return 1
-        if self.config.extraction.roi_list is None:
+        if self.extraction.roi_list is None:
             return 2
         else:
-            return len(self.config.extraction.roi_list)
+            return len(self.extraction.roi_list)

     def _get_mask_paths_and_tensors(
         self,
         caps_directory: Path,
-        preprocessing_dict: Dict[str, Any],
     ) -> Tuple[List[str], List]:
         """Loads the masks necessary to regions extraction"""
         import nibabel as nib
@@ -557,21 +587,21 @@ def _get_mask_paths_and_tensors(
             )

         try:
-            preprocessing_ = Preprocessing(preprocessing_dict["preprocessing"])
+            preprocessing_ = self.preprocessing.preprocessing
         except NotImplementedError:
             print(
-                f"Template of preprocessing {preprocessing_dict['preprocessing']} "
+                f"Template of preprocessing {self.preprocessing.preprocessing.value} "
                 f"is not defined."
             )
         # Find template name and pattern
         if preprocessing_.value == "custom":
-            template_name = preprocessing_dict["roi_custom_template"]
+            template_name = self.extraction.roi_custom_template
             if template_name is None:
                 raise ValueError(
                     "Please provide a name for the template when preprocessing is `custom`."
                 )

-            pattern = preprocessing_dict["roi_custom_mask_pattern"]
+            pattern = self.extraction.roi_custom_mask_pattern
             if pattern is None:
                 raise ValueError(
                     "Please provide a pattern for the masks when preprocessing is `custom`."
@@ -589,7 +619,7 @@ def _get_mask_paths_and_tensors(
         mask_location = caps_directory / "masks" / f"tpl-{template_name}"

         mask_paths, mask_arrays = list(), list()
-        for roi in self.config.extraction.roi_list:
+        for roi in self.extraction.roi_list:
             logger.info(f"Find mask for roi {roi}.")
             mask_path, desc = find_mask_path(mask_location, roi, pattern, True)
             if mask_path is None:
@@ -605,7 +635,9 @@ class CapsDatasetSlice(CapsDataset):
     def __init__(
         self,
         config: CapsDatasetConfig,
-        preprocessing_dict: Dict[str, Any],
+        preprocessing: PreprocessingConfig,
+        extraction: ExtractionSliceConfig,
+        transforms: TransformsConfig,
         slice_index: Optional[int] = None,
         label_presence: bool = True,
     ):
@@ -613,7 +645,6 @@ def __init__(
         Args:
             caps_directory: Directory of all the images.
             data_file: Path to the tsv file or DataFrame containing the subject/session list.
-            preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data.
             slice_index: If a value is given the same slice will be extracted for each image.
             else the dataset will load all the slices possible for one image.
             train_transformations: Optional transform to be applied only on training mode.
@@ -626,12 +657,18 @@ def __init__( self.slice_index = slice_index self.mode = "slice" self.config = config + + self.preprocessing = preprocessing + self.extraction = extraction + self.transforms = transforms + self.label_presence = label_presence - self.preprocessing_dict = preprocessing_dict super().__init__( config=config, + preprocessing=preprocessing, + extraction=extraction, + transforms=transforms, label_presence=label_presence, - preprocessing_dict=preprocessing_dict, ) @property @@ -642,17 +679,17 @@ def __getitem__(self, idx): participant, session, cohort, slice_idx, label, domain = self._get_meta_data( idx ) - slice_idx = slice_idx + self.config.extraction.discarded_slices[0] + slice_idx = slice_idx + self.extraction.discarded_slices[0] image_path = self._get_image_path(participant, session, cohort) - if self.config.extraction.save_features: + if self.extraction.save_features: slice_dir = image_path.parent.as_posix().replace( "image_based", f"{self.mode}_based" ) slice_filename = extract_slice_path( image_path, - self.config.extraction.slice_direction, - self.config.extraction.slice_mode, + self.extraction.slice_direction, + self.extraction.slice_mode, slice_idx, ) slice_tensor = torch.load( @@ -664,8 +701,8 @@ def __getitem__(self, idx): image = torch.load(image_path, weights_only=True) slice_tensor = extract_slice_tensor( image, - self.config.extraction.slice_direction, - self.config.extraction.slice_mode, + self.extraction.slice_direction, + self.extraction.slice_mode, slice_idx, ) @@ -690,22 +727,23 @@ def num_elem_per_image(self): if self.elem_index is not None: return 1 - if self.config.extraction.num_slices is not None: - return self.config.extraction.num_slices + if self.extraction.num_slices is not None: + return self.extraction.num_slices image = self._get_full_image() return ( - image.size(int(self.config.extraction.slice_direction) + 1) - - self.config.extraction.discarded_slices[0] - - self.config.extraction.discarded_slices[1] + image.size(int(self.extraction.slice_direction) + 1) + - self.extraction.discarded_slices[0] + - self.extraction.discarded_slices[1] ) def return_dataset( input_dir: Path, data_df: pd.DataFrame, - preprocessing_dict: Dict[str, Any], - transforms_config: TransformsConfig, + preprocessing: PreprocessingConfig, + extraction: ExtractionConfig, + transforms: TransformsConfig, label: Optional[str] = None, label_code: Optional[Dict[str, int]] = None, cnn_index: Optional[int] = None, @@ -717,7 +755,6 @@ def return_dataset( Args: input_dir: path to a directory containing a CAPS structure. data_df: List subjects, sessions and diagnoses. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. train_transformations: Optional transform to be applied during training only. all_transformations: Optional transform to be applied during training and evaluation. label: Name of the column in data_df containing the label. @@ -729,15 +766,15 @@ def return_dataset( Returns: the corresponding dataset. """ - if cnn_index is not None and preprocessing_dict["mode"] == "image": + if cnn_index is not None and isinstance(extraction, ExtractionImageConfig): raise NotImplementedError( - f"Multi-CNN is not implemented for {preprocessing_dict['mode']} mode." + f"Multi-CNN is not implemented for {extraction.extract_method} mode." 
         )

    config = CapsDatasetConfig.from_preprocessing_and_extraction_method(
-        preprocessing_type=preprocessing_dict["preprocessing"],
-        preprocessing=preprocessing_dict["preprocessing"],
-        extraction=preprocessing_dict["mode"],
+        preprocessing_type=preprocessing.preprocessing,
+        preprocessing=preprocessing.preprocessing,
+        extraction=extraction.extract_method,
         caps_directory=input_dir,
         data_df=data_df,
         label=label,
@@ -746,72 +783,45 @@
     )
-    config.transforms = transforms_config
+    config.transforms = transforms

-    if preprocessing_dict["mode"] == "image":
-        config.extraction.save_features = preprocessing_dict["prepare_dl"]
-        config.preprocessing.use_uncropped_image = preprocessing_dict[
-            "use_uncropped_image"
-        ]
+    if isinstance(extraction, ExtractionImageConfig):
         return CapsDatasetImage(
-            config,
+            config=config,
+            extraction=extraction,
+            preprocessing=preprocessing,
+            transforms=transforms,
             label_presence=label_presence,
-            preprocessing_dict=preprocessing_dict,
         )
-    elif preprocessing_dict["mode"] == "patch":
-        assert isinstance(config.extraction, ExtractionPatchConfig)
-        config.extraction.patch_size = preprocessing_dict["patch_size"]
-        config.extraction.stride_size = preprocessing_dict["stride_size"]
-        config.extraction.save_features = preprocessing_dict["prepare_dl"]
-        config.preprocessing.use_uncropped_image = preprocessing_dict[
-            "use_uncropped_image"
-        ]
+    elif isinstance(extraction, ExtractionPatchConfig):
         return CapsDatasetPatch(
             config,
             patch_index=cnn_index,
             label_presence=label_presence,
-            preprocessing_dict=preprocessing_dict,
+            extraction=extraction,
+            preprocessing=preprocessing,
+            transforms=transforms,
         )
-    elif preprocessing_dict["mode"] == "roi":
-        assert isinstance(config.extraction, ExtractionROIConfig)
-        config.extraction.roi_list = preprocessing_dict["roi_list"]
-        config.extraction.roi_uncrop_output = preprocessing_dict["uncropped_roi"]
-        config.extraction.save_features = preprocessing_dict["prepare_dl"]
-        config.preprocessing.use_uncropped_image = preprocessing_dict[
-            "use_uncropped_image"
-        ]
+    elif isinstance(extraction, ExtractionROIConfig):
         return CapsDatasetRoi(
             config,
             roi_index=cnn_index,
             label_presence=label_presence,
-            preprocessing_dict=preprocessing_dict,
+            extraction=extraction,
+            preprocessing=preprocessing,
+            transforms=transforms,
         )
-    elif preprocessing_dict["mode"] == "slice":
-        assert isinstance(config.extraction, ExtractionSliceConfig)
-        config.extraction.slice_direction = SliceDirection(
-            str(preprocessing_dict["slice_direction"])
-        )
-        config.extraction.slice_mode = SliceMode(preprocessing_dict["slice_mode"])
-        config.extraction.discarded_slices = compute_discarded_slices(
-            preprocessing_dict["discarded_slices"]
-        )
-        config.extraction.num_slices = (
-            None
-            if "num_slices" not in preprocessing_dict
-            else preprocessing_dict["num_slices"]
-        )
-        config.extraction.save_features = preprocessing_dict["prepare_dl"]
-        config.preprocessing.use_uncropped_image = preprocessing_dict[
-            "use_uncropped_image"
-        ]
+    elif isinstance(extraction, ExtractionSliceConfig):
         return CapsDatasetSlice(
             config,
             slice_index=cnn_index,
             label_presence=label_presence,
-            preprocessing_dict=preprocessing_dict,
+            extraction=extraction,
+            preprocessing=preprocessing,
+            transforms=transforms,
         )
     else:
         raise NotImplementedError(
-            f"Mode {preprocessing_dict['mode']} is not implemented."
+            f"Mode {extraction.extract_method.value} is not implemented."
         )
diff --git a/clinicadl/caps_dataset/data_2.py b/clinicadl/caps_dataset/data_2.py
new file mode 100644
index 000000000..9c854f345
--- /dev/null
+++ b/clinicadl/caps_dataset/data_2.py
@@ -0,0 +1,825 @@
+# coding: utf8
+# TODO: create a folder for generate/ prepare_data/ data to deal with capsDataset objects ?
+import abc
+from logging import getLogger
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+import torch
+from torch.utils.data import Dataset
+
+from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
+from clinicadl.caps_dataset.data_config import DataConfig
+from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig
+from clinicadl.caps_dataset.extraction.config import (
+    ExtractionConfig,
+    ExtractionImageConfig,
+    ExtractionPatchConfig,
+    ExtractionROIConfig,
+    ExtractionSliceConfig,
+)
+from clinicadl.caps_dataset.preprocessing.config import PreprocessingConfig
+from clinicadl.caps_dataset.utils import (
+    CapsDatasetOutput,
+    get_preprocessing_and_mode_from_json,
+)
+from clinicadl.prepare_data.prepare_data_utils import (
+    compute_discarded_slices,
+    extract_patch_path,
+    extract_patch_tensor,
+    extract_roi_path,
+    extract_roi_tensor,
+    extract_slice_path,
+    extract_slice_tensor,
+    find_mask_path,
+)
+from clinicadl.transforms.config import TransformsConfig
+from clinicadl.utils.enum import (
+    ExtractionMethod,
+    Pattern,
+    Preprocessing,
+    SliceDirection,
+    SliceMode,
+    Template,
+)
+from clinicadl.utils.exceptions import (
+    ClinicaDLCAPSError,
+    ClinicaDLConfigurationError,
+    ClinicaDLTSVError,
+)
+
+logger = getLogger("clinicadl")
+
+
+#################################
+# Datasets loaders
+#################################
+class CapsDataset(Dataset):
+    """Abstract class for all derived CapsDatasets."""
+
+    def __init__(
+        self,
+        data: DataConfig,
+        preprocessing: PreprocessingConfig,
+        extraction: ExtractionConfig,
+        transforms: TransformsConfig,
+        label_presence: bool,
+    ):
+        self.label_presence = label_presence
+        self.eval_mode = False
+
+        self.preprocessing = preprocessing
+        self.extraction = extraction
+        self.transforms = transforms
+        self.data = data
+        self.caps_dict = data.caps_dict
+        # derive mode from the extraction config so that every child class
+        # (not only CapsDatasetImage) passes the attribute check below
+        self.mode = extraction.extract_method
+
+        if not hasattr(self, "elem_index"):
+            raise AttributeError(
+                "Child class of CapsDataset must set elem_index attribute."
+            )
+        if not hasattr(self, "mode"):
+            raise AttributeError("Child class of CapsDataset must set mode attribute.")
+
+        self.df = pd.read_csv(data.tsv_path, sep="\t")
+
+        mandatory_col = {
+            "participant_id",
+            "session_id",
+            "cohort",
+        }
+
+        # if label_presence and self.config.data.label is not None:
+        #     mandatory_col.add(self.config.data.label)
+
+        if not mandatory_col.issubset(set(self.df.columns.values)):
+            raise ClinicaDLTSVError(
+                f"the data file is not in the correct format."
+ f"Columns should include {mandatory_col}" + ) + + @classmethod + def from_extract_json( + cls, + data: DataConfig, + extract_json: str, + transforms: TransformsConfig, + label_presence: bool, + ): + extract_json_path = data.caps_directory / "tensor_extraction" / extract_json + if not extract_json_path.is_file(): + raise ClinicaDLConfigurationError(f"Could not find {extract_json_path}") + + preprocessing, extraction = get_preprocessing_and_mode_from_json( + extract_json_path + ) + + return cls( + data=data, + preprocessing=preprocessing, + extraction=extraction, + transforms=transforms, + label_presence=label_presence, + ) + + @property + @abc.abstractmethod + def elem_index(self): + pass + + def label_fn(self, target: Union[str, float, int]) -> Union[float, int, None]: + """ + Returns the label value usable in criterion. + + Args: + target: value of the target. + Returns: + label: value of the label usable in criterion. + """ + # Reconstruction case (no label) + if self.data.label is None: + return None + # Regression case (no label code) + elif self.data.label_code is None: + return np.float32([target]) + # Classification case (label + label_code dict) + else: + return self.data.label_code[str(target)] + + def domain_fn(self, target: Union[str, float, int]) -> Union[float, int]: + """ + Returns the label value usable in criterion. + + """ + domain_code = {"t1": 0, "flair": 1} + return domain_code[str(target)] + + def __len__(self) -> int: + return len(self.df) * self.elem_per_image + + def _get_image_path(self, participant: str, session: str, cohort: str) -> Path: + """ + Gets the path to the tensor image (*.pt) + + Args: + participant: ID of the participant. + session: ID of the session. + cohort: Name of the cohort. + Returns: + image_path: path to the tensor containing the whole image. + """ + from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader + + # Try to find .nii.gz file + try: + folder, file_type = self.preprocessing.compute_folder_and_file_type() + + results = clinicadl_file_reader( + [participant], + [session], + self.caps_dict[cohort], + file_type.model_dump(), + ) + logger.debug(f"clinicadl_file_reader output: {results}") + filepath = Path(results[0][0]) + image_filename = filepath.name.replace(".nii.gz", ".pt") + + image_dir = ( + self.caps_dict[cohort] + / "subjects" + / participant + / session + / "deeplearning_prepare_data" + / "image_based" + / folder + ) + image_path = image_dir / image_filename + # Try to find .pt file + except ClinicaDLCAPSError: + folder, file_type = self.preprocessing.compute_folder_and_file_type() + file_type.pattern = file_type.pattern.replace(".nii.gz", ".pt") + results = clinicadl_file_reader( + [participant], + [session], + self.caps_dict[cohort], + file_type.model_dump(), + ) + filepath = results[0] + image_path = Path(filepath[0]) + + return image_path + + def _get_meta_data( + self, idx: int + ) -> Tuple[str, str, str, int, Union[float, int, None]]: + """ + Gets all meta data necessary to compute the path with _get_image_path + + Args: + idx (int): row number of the meta-data contained in self.df + Returns: + participant (str): ID of the participant. + session (str): ID of the session. + cohort (str): Name of the cohort. + elem_index (int): Index of the part of the image. + label (str or float or int): value of the label to be used in criterion. 
+ """ + image_idx = idx // self.elem_per_image + participant = self.df.at[image_idx, "participant_id"] + session = self.df.at[image_idx, "session_id"] + cohort = self.df.at[image_idx, "cohort"] + + if self.elem_index is None: + elem_idx = idx % self.elem_per_image + else: + elem_idx = self.elem_index + if self.label_presence and self.data.label is not None: + target = self.df.at[image_idx, self.data.label] + label = self.label_fn(target) + else: + label = -1 + + # if "domain" in self.df.columns: + # domain = self.df.at[image_idx, "domain"] + # domain = self.domain_fn(domain) + # else: + # domain = "" # TO MODIFY + return participant, session, cohort, elem_idx, label # , domain + + def _get_full_image(self) -> torch.Tensor: + """ + Allows to get the an example of the image mode corresponding to the dataset. + Useful to compute the number of elements if mode != image. + + Returns: + image tensor of the full image first image. + """ + import nibabel as nib + + from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader + + participant_id = self.df.at[0, "participant_id"] + session_id = self.df.at[0, "session_id"] + cohort = self.df.at[0, "cohort"] + + try: + image_path = self._get_image_path(participant_id, session_id, cohort) + image = torch.load(image_path, weights_only=True) + except IndexError: + file_type = self.preprocessing.file_type + results = clinicadl_file_reader( + [participant_id], + [session_id], + self.caps_dict[cohort], + file_type.model_dump(), + ) + image_nii = nib.loadsave.load(results[0]) + image_np = image_nii.get_fdata() + image = ToTensor()(image_np) + + return image + + @abc.abstractmethod + def __getitem__(self, idx: int) -> Dict[str, Any]: + """ + Gets the sample containing all the information needed for training and testing tasks. + + Args: + idx: row number of the meta-data contained in self.df + Returns: + dictionary with following items: + - "image" (torch.Tensor): the input given to the model, + - "label" (int or float): the label used in criterion, + - "participant_id" (str): ID of the participant, + - "session_id" (str): ID of the session, + - f"{self.mode}_id" (int): number of the element, + - "image_path": path to the image loaded in CAPS. + + """ + pass + + @abc.abstractmethod + def num_elem_per_image(self) -> int: + """Computes the number of elements per image based on the full image.""" + pass + + def eval(self): + """Put the dataset on evaluation mode (data augmentation is not performed).""" + self.eval_mode = True + return self + + def train(self): + """Put the dataset on training mode (data augmentation is performed).""" + self.eval_mode = False + return self + + +class CapsDatasetImage(CapsDataset): + """Dataset of MRI organized in a CAPS folder.""" + + def __init__( + self, + data: DataConfig, + preprocessing: PreprocessingConfig, + extraction: ExtractionImageConfig, + transforms: TransformsConfig, + label_presence: bool = True, + ): + """ + Args: + caps_directory: Directory of all the images. + data_file: Path to the tsv file or DataFrame containing the subject/session list. + train_transformations: Optional transform to be applied only on training mode. + label_presence: If True the diagnosis will be extracted from the given DataFrame. + label: Name of the column in data_df containing the label. + label_code: label code that links the output node number to label value. + all_transformations: Optional transform to be applied during training and evaluation. 
+ multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. + + """ + + self.mode = "image" + self.label_presence = label_presence + super().__init__( + data=data, + preprocessing=preprocessing, + extraction=extraction, + transforms=transforms, + label_presence=label_presence, + ) + + @property + def elem_index(self): + return None + + def __getitem__(self, idx): + participant, session, cohort, _, label = self._get_meta_data(idx) + + image_path = self._get_image_path(participant, session, cohort) + image = torch.load(image_path, weights_only=True) + + train_trf, trf = self.transforms.get_transforms() + + image = trf(image) + if self.transforms.train_transformations and not self.eval_mode: + image = train_trf(image) + + sample = CapsDatasetOutput( + image=image, + label=label, + participant_id=participant, + session_id=session, + image_id=0, + image_path=image_path, + # domain= domain, + mode=ExtractionMethod.IMAGE, + ) + + return sample + + def num_elem_per_image(self): + return 1 + + +class CapsDatasetPatch(CapsDataset): + def __init__( + self, + data: DataConfig, + preprocessing: PreprocessingConfig, + extraction: ExtractionPatchConfig, + transforms: TransformsConfig, + patch_index: Optional[int] = None, + label_presence: bool = True, + ): + """ + caps_directory: Directory of all the images. + data_file: Path to the tsv file or DataFrame containing the subject/session list. + train_transformations: Optional transform to be applied only on training mode. + """ + self.patch_index = patch_index + self.label_presence = label_presence + self.extraction = extraction + self.preprocessing = preprocessing + self.transforms = transforms + super().__init__( + data=data, + preprocessing=preprocessing, + extraction=extraction, + transforms=transforms, + label_presence=label_presence, + ) + + @property + def elem_index(self): + return self.patch_index + + def __getitem__(self, idx): + participant, session, cohort, patch_idx, label = self._get_meta_data(idx) + image_path = self._get_image_path(participant, session, cohort) + + if self.extraction.save_features: + patch_dir = image_path.parent.as_posix().replace( + "image_based", f"{self.extraction.extract_method}_based" + ) + patch_filename = extract_patch_path( + image_path, + self.extraction.patch_size, + self.extraction.stride_size, + patch_idx, + ) + patch_tensor = torch.load( + Path(patch_dir).resolve() / patch_filename, weights_only=True + ) + + else: + image = torch.load(image_path, weights_only=True) + patch_tensor = extract_patch_tensor( + image, + self.extraction.patch_size, + self.extraction.stride_size, + patch_idx, + ) + + train_trf, trf = self.transforms.get_transforms() + patch_tensor = trf(patch_tensor) + + if self.transforms.train_transformations and not self.eval_mode: + patch_tensor = train_trf(patch_tensor) + + sample = CapsDatasetOutput( + image=patch_tensor, + label=label, + participant_id=participant, + session_id=session, + image_id=patch_idx, + mode=ExtractionMethod.PATCH, + ) + + return sample + + def num_elem_per_image(self): + if self.elem_index is not None: + return 1 + + image = self._get_full_image() + + patches_tensor = ( + image.unfold( + 1, + self.extraction.patch_size, + self.extraction.stride_size, + ) + .unfold( + 2, + self.extraction.patch_size, + self.extraction.stride_size, + ) + .unfold( + 3, + self.extraction.patch_size, + self.extraction.stride_size, + ) + .contiguous() + ) + patches_tensor = patches_tensor.view( + -1, + self.extraction.patch_size, + 
self.extraction.patch_size,
+            self.extraction.patch_size,
+        )
+        num_patches = patches_tensor.shape[0]
+        return num_patches
+
+
+class CapsDatasetRoi(CapsDataset):
+    def __init__(
+        self,
+        data: DataConfig,
+        preprocessing: PreprocessingConfig,
+        extraction: ExtractionROIConfig,
+        transforms: TransformsConfig,
+        roi_index: Optional[int] = None,
+        label_presence: bool = True,
+    ):
+        """
+        Args:
+            caps_directory: Directory of all the images.
+            data_file: Path to the tsv file or DataFrame containing the subject/session list.
+            roi_index: If a value is given the same region will be extracted for each image.
+            else the dataset will load all the regions possible for one image.
+            train_transformations: Optional transform to be applied only on training mode.
+            label_presence: If True the diagnosis will be extracted from the given DataFrame.
+            label: Name of the column in data_df containing the label.
+            label_code: label code that links the output node number to label value.
+            all_transformations: Optional transform to be applied during training and evaluation.
+            multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths.
+
+        """
+        self.roi_index = roi_index
+        self.label_presence = label_presence
+
+        self.extraction = extraction
+        self.preprocessing = preprocessing
+        self.transforms = transforms
+
+        # data and caps_dict are needed by _get_mask_paths_and_tensors,
+        # which runs before the base __init__ sets them
+        self.data = data
+        self.caps_dict = data.caps_dict
+
+        self.mask_paths, self.mask_arrays = self._get_mask_paths_and_tensors()
+        super().__init__(
+            data=data,
+            preprocessing=preprocessing,
+            extraction=extraction,
+            transforms=transforms,
+            label_presence=label_presence,
+        )
+
+    @property
+    def elem_index(self):
+        return self.roi_index
+
+    def __getitem__(self, idx):
+        participant, session, cohort, roi_idx, label = self._get_meta_data(idx)
+        image_path = self._get_image_path(participant, session, cohort)
+
+        if self.extraction.roi_list is None:
+            raise NotImplementedError(
+                "Default regions are not available anymore in ClinicaDL. "
+                "Please define appropriate masks and give a roi_list."
+            )
+
+        if self.extraction.save_features:
+            mask_path = self.mask_paths[roi_idx]
+            roi_dir = image_path.parent.as_posix().replace(
+                "image_based", f"{self.extraction.extract_method}_based"
+            )
+            roi_filename = extract_roi_path(
+                image_path, mask_path, self.extraction.roi_uncrop_output
+            )
+            roi_tensor = torch.load(Path(roi_dir) / roi_filename, weights_only=True)
+
+        else:
+            image = torch.load(image_path, weights_only=True)
+            mask_array = self.mask_arrays[roi_idx]
+            roi_tensor = extract_roi_tensor(
+                image, mask_array, self.extraction.roi_uncrop_output
+            )
+
+        train_trf, trf = self.transforms.get_transforms()
+
+        roi_tensor = trf(roi_tensor)
+
+        if self.transforms.train_transformations and not self.eval_mode:
+            roi_tensor = train_trf(roi_tensor)
+
+        sample = CapsDatasetOutput(
+            image=roi_tensor,
+            label=label,
+            participant_id=participant,
+            session_id=session,
+            image_id=roi_idx,
+            mode=ExtractionMethod.ROI,
+        )
+
+        return sample
+
+    def num_elem_per_image(self):
+        if self.elem_index is not None:
+            return 1
+        if self.extraction.roi_list is None:
+            return 2
+        else:
+            return len(self.extraction.roi_list)
+
+    def _get_mask_paths_and_tensors(
+        self,
+    ) -> Tuple[List[Path], List]:
+        """Loads the masks necessary to regions extraction"""
+        import nibabel as nib
+
+        caps_dict = self.caps_dict
+        if len(caps_dict) > 1:
+            caps_directory = Path(caps_dict[next(iter(caps_dict))])
+            logger.warning(
+                f"The equality of masks is not assessed for multi-cohort training. "
+                f"The masks stored in {caps_directory} will be used."
+ ) + + try: + preprocessing_ = self.preprocessing.preprocessing + except NotImplementedError: + print( + f"Template of preprocessing {self.preprocessing.preprocessing.value} " + f"is not defined." + ) + # Find template name and pattern + if preprocessing_ == Preprocessing.CUSTOM: + template_name = self.extraction.roi_custom_template + if template_name is None: + raise ValueError( + "Please provide a name for the template when preprocessing is `custom`." + ) + + pattern = self.extraction.roi_custom_mask_pattern + if pattern is None: + raise ValueError( + "Please provide a pattern for the masks when preprocessing is `custom`." + ) + + else: + for template_ in Template: + if preprocessing_.name == template_.name: + template_name = template_ + + for pattern_ in Pattern: + if preprocessing_.name == pattern_.name: + pattern = pattern_ + + mask_location = self.data.caps_directory / "masks" / f"tpl-{template_name}" + + mask_paths, mask_arrays = list(), list() + for roi in self.extraction.roi_list: + logger.info(f"Find mask for roi {roi}.") + mask_path, desc = find_mask_path(mask_location, roi, pattern, True) + if mask_path is None: + raise FileNotFoundError(desc) + mask_nii = nib.loadsave.load(mask_path) + mask_paths.append(Path(mask_path)) + mask_arrays.append(mask_nii.get_fdata()) + + return mask_paths, mask_arrays + + +class CapsDatasetSlice(CapsDataset): + def __init__( + self, + data: DataConfig, + preprocessing: PreprocessingConfig, + extraction: ExtractionSliceConfig, + transforms: TransformsConfig, + slice_index: Optional[int] = None, + label_presence: bool = True, + ): + """ + Args: + caps_directory: Directory of all the images. + data_file: Path to the tsv file or DataFrame containing the subject/session list. + slice_index: If a value is given the same slice will be extracted for each image. + else the dataset will load all the slices possible for one image. + train_transformations: Optional transform to be applied only on training mode. + label_presence: If True the diagnosis will be extracted from the given DataFrame. + label: Name of the column in data_df containing the label. + label_code: label code that links the output node number to label value. + all_transformations: Optional transform to be applied during training and evaluation. + multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. 
+ """ + self.slice_index = slice_index + self.label_presence = label_presence + + self.extraction = extraction + self.preprocessing = preprocessing + self.transforms = transforms + + super().__init__( + data=data, + preprocessing=preprocessing, + extraction=extraction, + transforms=transforms, + label_presence=label_presence, + ) + + @property + def elem_index(self): + return self.slice_index + + def __getitem__(self, idx): + participant, session, cohort, slice_idx, label = self._get_meta_data(idx) + slice_idx = slice_idx + self.extraction.discarded_slices[0] + image_path = self._get_image_path(participant, session, cohort) + + if self.extraction.save_features: + slice_dir = image_path.parent.as_posix().replace( + "image_based", f"{self.extraction.extract_method}_based" + ) + slice_filename = extract_slice_path( + image_path, + self.extraction.slice_direction, + self.extraction.slice_mode, + slice_idx, + ) + slice_tensor = torch.load( + Path(slice_dir) / slice_filename, weights_only=True + ) + + else: + image_path = self._get_image_path(participant, session, cohort) + image = torch.load(image_path, weights_only=True) + slice_tensor = extract_slice_tensor( + image, + self.extraction.slice_direction, + self.extraction.slice_mode, + slice_idx, + ) + + train_trf, trf = self.transforms.get_transforms() + + slice_tensor = trf(slice_tensor) + + if self.transforms.train_transformations and not self.eval_mode: + slice_tensor = train_trf(slice_tensor) + + sample = CapsDatasetOutput( + image=slice_tensor, + label=label, + participant_id=participant, + session_id=session, + image_id=slice_idx, + mode=ExtractionMethod.SLICE, + ) + + return sample + + def num_elem_per_image(self): + if self.elem_index is not None: + return 1 + + if self.extraction.num_slices is not None: + return self.extraction.num_slices + + image = self._get_full_image() + return ( + image.size(int(self.extraction.slice_direction) + 1) + - self.extraction.discarded_slices[0] + - self.extraction.discarded_slices[1] + ) + + +def return_dataset( + data: DataConfig, + preprocessing: PreprocessingConfig, + extraction: ExtractionConfig, + transforms_config: TransformsConfig, + cnn_index: Optional[int] = None, + label_presence: bool = True, +) -> CapsDataset: + """ + Return appropriate Dataset according to given options. + Args: + input_dir: path to a directory containing a CAPS structure. + data_df: List subjects, sessions and diagnoses. + train_transformations: Optional transform to be applied during training only. + all_transformations: Optional transform to be applied during training and evaluation. + label: Name of the column in data_df containing the label. + label_code: label code that links the output node number to label value. + cnn_index: Index of the CNN in a multi-CNN paradigm (optional). + label_presence: If True the diagnosis will be extracted from the given DataFrame. + multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. + + Returns: + the corresponding dataset. + """ + if cnn_index is not None and extraction.extract_method == ExtractionMethod.IMAGE: + raise NotImplementedError( + f"Multi-CNN is not implemented for {extraction.extract_method.value} mode." 
+ ) + + if isinstance(extraction, ExtractionImageConfig): + return CapsDatasetImage( + data=data, + extraction=extraction, + preprocessing=preprocessing, + transforms=transforms_config, + label_presence=label_presence, + ) + + elif isinstance(extraction, ExtractionPatchConfig): + return CapsDatasetPatch( + data=data, + extraction=extraction, + preprocessing=preprocessing, + transforms=transforms_config, + label_presence=label_presence, + ) + + elif isinstance(extraction, ExtractionROIConfig): + return CapsDatasetRoi( + data=data, + extraction=extraction, + preprocessing=preprocessing, + transforms=transforms_config, + label_presence=label_presence, + ) + + elif isinstance(extraction, ExtractionSliceConfig): + return CapsDatasetSlice( + data=data, + extraction=extraction, + preprocessing=preprocessing, + transforms=transforms_config, + label_presence=label_presence, + ) + else: + raise NotImplementedError( + f"Mode {extraction.extract_method.value} is not implemented." + ) diff --git a/clinicadl/caps_dataset/data_config.py b/clinicadl/caps_dataset/data_config.py index 80694fcd0..5b62373eb 100644 --- a/clinicadl/caps_dataset/data_config.py +++ b/clinicadl/caps_dataset/data_config.py @@ -5,6 +5,7 @@ import pandas as pd from pydantic import BaseModel, ConfigDict, computed_field, field_validator +from clinicadl.caps_dataset.utils import get_preprocessing_and_mode_from_json from clinicadl.utils.enum import Mode from clinicadl.utils.exceptions import ( ClinicaDLArgumentError, @@ -12,7 +13,6 @@ ) from clinicadl.utils.iotools.clinica_utils import check_caps_folder from clinicadl.utils.iotools.data_utils import check_multi_cohort_tsv, load_data_test -from clinicadl.utils.iotools.utils import read_preprocessing logger = getLogger("clinicadl.data_config") @@ -105,61 +105,3 @@ def caps_dict(self) -> Dict[str, Path]: caps_dict = {"single": self.caps_directory} return caps_dict - - @computed_field - @property - def preprocessing_dict(self) -> Dict[str, Any]: - """ - Gets the preprocessing dictionary from a preprocessing json file. - - Returns - ------- - Dict[str, Any] - The preprocessing dictionary. - - Raises - ------ - ValueError - In case of multi-cohort dataset, if no preprocessing file is found in any CAPS. - """ - from clinicadl.caps_dataset.data import CapsDataset - - if self.preprocessing_json is not None: - if not self.multi_cohort: - preprocessing_json = ( - self.caps_directory / "tensor_extraction" / self.preprocessing_json - ) - else: - caps_dict = self.caps_dict - json_found = False - for caps_name, caps_path in caps_dict.items(): - preprocessing_json = ( - caps_path / "tensor_extraction" / self.preprocessing_json - ) - if preprocessing_json.is_file(): - logger.info( - f"Preprocessing JSON {preprocessing_json} found in CAPS {caps_name}." - ) - json_found = True - if not json_found: - raise ValueError( - f"Preprocessing JSON {self.preprocessing_json} was not found for any CAPS " - f"in {caps_dict}." 
- ) - - preprocessing_dict = read_preprocessing(preprocessing_json) - - if ( - preprocessing_dict["mode"] == "roi" - and "roi_background_value" not in preprocessing_dict - ): - preprocessing_dict["roi_background_value"] = 0 - - return preprocessing_dict - else: - return None - - @computed_field - @property - def mode(self) -> Mode: - return Mode(self.preprocessing_dict["mode"]) diff --git a/clinicadl/caps_dataset/extraction/config.py b/clinicadl/caps_dataset/extraction/config.py index 484b140e0..a52451fd1 100644 --- a/clinicadl/caps_dataset/extraction/config.py +++ b/clinicadl/caps_dataset/extraction/config.py @@ -22,7 +22,8 @@ class ExtractionConfig(BaseModel): extract_method: ExtractionMethod save_features: bool = False - extract_json: Optional[str] = None + extract_json: str + use_uncropped_image: bool = True # pydantic config model_config = ConfigDict(validate_assignment=True) diff --git a/clinicadl/caps_dataset/preprocessing/config.py b/clinicadl/caps_dataset/preprocessing/config.py index 8a1718448..447d1986f 100644 --- a/clinicadl/caps_dataset/preprocessing/config.py +++ b/clinicadl/caps_dataset/preprocessing/config.py @@ -24,6 +24,7 @@ class PreprocessingConfig(BaseModel): Abstract config class for the preprocessing procedure. """ + # prepare_dl: bool = False tsv_file: Optional[Path] = None preprocessing: Preprocessing file_type: Optional[FileType] = None diff --git a/clinicadl/caps_dataset/utils.py b/clinicadl/caps_dataset/utils.py index be1be0953..207af033f 100644 --- a/clinicadl/caps_dataset/utils.py +++ b/clinicadl/caps_dataset/utils.py @@ -42,25 +42,16 @@ def get_preprocessing_and_mode_from_json(json_path: Path): Tuple[Preprocessing, SliceMode] The preprocessing and mode extracted from the json file. """ - from clinicadl.utils.iotools.utils import read_preprocessing + from clinicadl.utils.iotools.utils import read_json - preprocessing_dict = read_preprocessing(json_path) - preprocessing = Preprocessing(preprocessing_dict["preprocessing"]) - mode = ExtractionMethod(preprocessing_dict["mode"]) - return get_preprocessing(preprocessing)(**preprocessing_dict), get_extraction(mode)( - **preprocessing_dict - ) + dict_ = read_json(json_path) + return get_preprocessing_and_mode_from_parameters(**dict_) def get_preprocessing_and_mode_from_parameters(**kwargs): """ Extracts the preprocessing and mode from a json file. - Parameters - ---------- - json_path : Path - Path to the json file containing the preprocessing and mode. 
-
     Returns
     -------
     Tuple[Preprocessing, SliceMode]
diff --git a/clinicadl/commandline/pipelines/generate/shepplogan/cli.py b/clinicadl/commandline/pipelines/generate/shepplogan/cli.py
index e9d20d9dc..37a9209c8 100644
--- a/clinicadl/commandline/pipelines/generate/shepplogan/cli.py
+++ b/clinicadl/commandline/pipelines/generate/shepplogan/cli.py
@@ -6,6 +6,8 @@
 import torch
 from joblib import Parallel, delayed

+from clinicadl.caps_dataset.extraction.config import ExtractionSliceConfig
+from clinicadl.caps_dataset.preprocessing.config import CustomPreprocessingConfig
 from clinicadl.commandline import arguments
 from clinicadl.commandline.modules_options import data, dataloader
 from clinicadl.commandline.pipelines.generate.shepplogan import options as shepplogan
@@ -14,6 +16,11 @@
     generate_shepplogan_phantom,
     write_missing_mods,
 )
+from clinicadl.utils.enum import (
+    ExtractionMethod,
+    SliceDirection,
+    SliceMode,
+)
 from clinicadl.utils.iotools.clinica_utils import FileType
 from clinicadl.utils.iotools.iotools import check_and_clean, commandline_to_json
 from clinicadl.utils.iotools.utils import write_preprocessing
@@ -119,24 +126,30 @@ def create_shepplogan_image(
     # Save data
     data_df.to_csv(generated_caps_directory / "data.tsv", sep="\t", index=False)

-    # Save preprocessing JSON file
-    preprocessing_dict = {
-        "preprocessing": "custom",
-        "mode": "slice",
-        "use_uncropped_image": False,
-        "prepare_dl": True,
-        "extract_json": generate_config.extract_json,
-        "slice_direction": 2,
-        "slice_mode": "single",
-        "discarded_slices": 0,
-        "num_slices": 1,
-        "file_type": FileType(
+    preprocessing = CustomPreprocessingConfig(
+        file_type=FileType(
             pattern="*_space-SheppLogan_phantom.nii.gz",
             description="Custom suffix",
             needed_pipeline="shepplogan",
-        ).model_dump(),
-    }
-    write_preprocessing(preprocessing_dict, generated_caps_directory)
+        ),
+        use_uncropped_image=False,
+    )
+    extraction = ExtractionSliceConfig(
+        extract_json=generate_config.extract_json,  # extract_json is now a required field
+        slice_direction=SliceDirection.AXIAL,
+        slice_mode=SliceMode.SINGLE,
+        discarded_slices=0,
+        num_slices=1,
+    )
+
+    # prepare_dl = True
+
+    # Save preprocessing JSON file
+
+    write_preprocessing(
+        preprocessing=preprocessing,
+        extraction=extraction,
+        caps_directory=generated_caps_directory,
+    )

     write_missing_mods(generated_caps_directory, data_df)
     logger.info(f"Shepplogan dataset was generated at {generated_caps_directory}")
diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py
index 10550a021..1da78b29a 100644
--- a/clinicadl/maps_manager/maps_manager.py
+++ b/clinicadl/maps_manager/maps_manager.py
@@ -154,7 +154,6 @@ def _check_args(self, parameters):
         mandatory_arguments = [
             "caps_directory",
             "tsv_path",
-            "preprocessing_dict",
             "mode",
             "network_task",
         ]
@@ -216,14 +215,22 @@ def _check_args(self, parameters):
             self.network_task, train_df, self.label
         )

+        from clinicadl.caps_dataset.utils import (
+            get_preprocessing_and_mode_from_parameters,
+        )
+
+        preprocessing, extraction = get_preprocessing_and_mode_from_parameters(
+            **parameters
+        )
         full_dataset = return_dataset(
-            self.caps_directory,
-            train_df,
-            self.preprocessing_dict,
+            input_dir=self.caps_directory,
+            data_df=train_df,
+            preprocessing=preprocessing,
+            extraction=extraction,
+            transforms=transfo_config,
             multi_cohort=self.multi_cohort,
             label=self.label,
             label_code=self.parameters["label_code"],
-            transforms_config=transfo_config,
         )
         self.parameters.update(
             {
diff --git a/clinicadl/predictor/predictor.py b/clinicadl/predictor/predictor.py
index 30fbbe5b8..2a08e8daa 100644
---
a/clinicadl/predictor/predictor.py +++ b/clinicadl/predictor/predictor.py @@ -152,11 +152,17 @@ def _predict_single( assert isinstance(self._config, PredictConfig) # assert self._config.data.label + from clinicadl.caps_dataset.utils import get_preprocessing_and_mode_from_json + + preprocessing, extraction = get_preprocessing_and_mode_from_json( + self._config.maps_manager.maps_dir / "maps.json" + ) data_test = return_dataset( - group_parameters["caps_directory"], - group_df, - self.maps_manager.preprocessing_dict, - transforms_config=self._config.transforms, + input_dir=group_parameters["caps_directory"], + data_df=group_df, + preprocessing=preprocessing, + extraction=extraction, + transforms=self._config.transforms, multi_cohort=group_parameters["multi_cohort"], label_presence=self._config.data.use_labels, label=self._config.data.label, @@ -398,11 +404,20 @@ def interpret(self): df_group, parameters_group = self.get_group_info( self._config.maps_manager.data_group, split ) + from clinicadl.caps_dataset.utils import ( + get_preprocessing_and_mode_from_json, + ) + + preprocessing, extraction = get_preprocessing_and_mode_from_json( + self._config.maps_manager.maps_dir / "maps.json" + ) + data_test = return_dataset( - parameters_group["caps_directory"], - df_group, - self.maps_manager.preprocessing_dict, - transforms_config=transforms, + input_dir=parameters_group["caps_directory"], + data_df=df_group, + preprocessing=preprocessing, + extraction=extraction, + transforms=transforms, multi_cohort=parameters_group["multi_cohort"], label_presence=False, label_code=self.maps_manager.label_code, diff --git a/clinicadl/prepare_data/prepare_data.py b/clinicadl/prepare_data/prepare_data.py index 41df526c3..d60537868 100644 --- a/clinicadl/prepare_data/prepare_data.py +++ b/clinicadl/prepare_data/prepare_data.py @@ -226,6 +226,8 @@ def prepare_roi(file): # Save parameters dictionary preprocessing_json_path = write_preprocessing( - config.extraction.model_dump(), config.data.caps_directory + preprocessing=config.preprocessing, + extraction=config.extraction, + caps_directory=config.data.caps_directory, ) logger.info(f"Preprocessing JSON saved at {preprocessing_json_path}.") diff --git a/clinicadl/random_search/random_search_utils.py b/clinicadl/random_search/random_search_utils.py index f8f3bca9a..57a06ef90 100644 --- a/clinicadl/random_search/random_search_utils.py +++ b/clinicadl/random_search/random_search_utils.py @@ -7,7 +7,7 @@ from clinicadl.utils.enum import Task from clinicadl.utils.exceptions import ClinicaDLConfigurationError from clinicadl.utils.iotools.train_utils import extract_config_from_toml_file -from clinicadl.utils.iotools.utils import path_decoder, read_preprocessing +from clinicadl.utils.iotools.utils import path_decoder, read_json def get_space_dict(launch_directory: Path) -> Dict[str, Any]: @@ -61,7 +61,7 @@ def get_space_dict(launch_directory: Path) -> Dict[str, Any]: / space_dict.pop("preprocessing_json") ) - preprocessing_dict = read_preprocessing(preprocessing_json) + preprocessing_dict = read_json(preprocessing_json) train_default["preprocessing_dict"] = preprocessing_dict train_default["mode"] = preprocessing_dict["mode"] diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 775ecd2c6..7551f0241 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -358,11 +358,18 @@ def get_dataloader( num_replicas: Optional[int] = None, homemade_sampler: bool = False, ): + from clinicadl.caps_dataset.utils import 
get_preprocessing_and_mode_from_json + + preprocessing, extraction = get_preprocessing_and_mode_from_json( + self.config.maps_manager.maps_dir / "maps.json" + ) + dataset = return_dataset( input_dir=self.config.data.caps_directory, data_df=data_df, - preprocessing_dict=self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, + preprocessing=preprocessing, + extraction=extraction, + transforms=self.config.transforms, multi_cohort=self.config.data.multi_cohort, label=self.config.data.label, label_code=self.config.data.label_code, diff --git a/clinicadl/utils/iotools/__init__.py b/clinicadl/utils/iotools/__init__.py index f2c3432c4..b97757ff1 100644 --- a/clinicadl/utils/iotools/__init__.py +++ b/clinicadl/utils/iotools/__init__.py @@ -39,4 +39,4 @@ merge_cli_and_config_file_options, ) from .trainer_utils import create_parameters_dict, patch_to_read_json -from .utils import path_decoder, path_encoder, read_preprocessing, write_preprocessing +from .utils import path_decoder, path_encoder, write_preprocessing diff --git a/clinicadl/utils/iotools/utils.py b/clinicadl/utils/iotools/utils.py index a8aec041e..2eab4a3d8 100644 --- a/clinicadl/utils/iotools/utils.py +++ b/clinicadl/utils/iotools/utils.py @@ -1,8 +1,12 @@ import errno import json from copy import copy +from datetime import datetime from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Optional + +from clinicadl.caps_dataset.extraction.config import ExtractionConfig +from clinicadl.caps_dataset.preprocessing.config import PreprocessingConfig def path_encoder(obj): @@ -74,12 +78,14 @@ def path_decoder(obj): def write_preprocessing( - preprocessing_dict: Dict[str, Any], caps_directory: Path + preprocessing: PreprocessingConfig, + extraction: ExtractionConfig, + caps_directory: Path, ) -> Path: extract_dir = caps_directory / "tensor_extraction" extract_dir.mkdir(parents=True, exist_ok=True) - json_path = extract_dir / preprocessing_dict["extract_json"] + json_path = extract_dir / extraction.extract_json if json_path.is_file(): raise FileExistsError( @@ -87,12 +93,14 @@ def write_preprocessing( f"Please choose another name for your preprocessing file." ) + dict_ = preprocessing.model_dump() + dict_.update(extraction.model_dump()) with json_path.open(mode="w") as json_file: - json.dump(preprocessing_dict, json_file, default=path_encoder) + json.dump(dict_, json_file, default=path_encoder) return json_path -def read_preprocessing(json_path: Path) -> Dict[str, Any]: +def read_json(json_path: Path) -> Dict[str, Any]: if json_path.suffix != ".json": json_path = json_path.with_suffix(".json") @@ -101,7 +109,7 @@ def read_preprocessing(json_path: Path) -> Dict[str, Any]: try: with json_path.open(mode="r") as f: - preprocessing_dict = json.load(f) + dict_ = json.load(f) except IOError as e: raise IOError(f"Error reading json preprocessing file {json_path}: {e}") - return preprocessing_dict + return dict_
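
Note for reviewers — a minimal usage sketch of the refactored API, appended here and not part of the patch. It reuses the Shepplogan values from the diff; the paths, the default-constructed `TransformsConfig()`, and the `label_presence=False` call are illustrative assumptions, not behavior confirmed beyond this diff.

```python
# Reviewer sketch: round-trip with typed config objects replacing the old
# preprocessing_dict. Assumes TransformsConfig() is default-constructable and
# that the (hypothetical) CAPS directory and TSV below exist, with the
# mandatory participant_id / session_id / cohort columns.
from pathlib import Path

import pandas as pd

from clinicadl.caps_dataset.data import return_dataset
from clinicadl.caps_dataset.extraction.config import ExtractionSliceConfig
from clinicadl.caps_dataset.preprocessing.config import CustomPreprocessingConfig
from clinicadl.caps_dataset.utils import get_preprocessing_and_mode_from_json
from clinicadl.transforms.config import TransformsConfig
from clinicadl.utils.enum import SliceDirection, SliceMode
from clinicadl.utils.iotools.clinica_utils import FileType
from clinicadl.utils.iotools.utils import write_preprocessing

caps_directory = Path("/path/to/caps")  # hypothetical
data_df = pd.read_csv(caps_directory / "data.tsv", sep="\t")

preprocessing = CustomPreprocessingConfig(
    file_type=FileType(
        pattern="*_space-SheppLogan_phantom.nii.gz",
        description="Custom suffix",
        needed_pipeline="shepplogan",
    ),
    use_uncropped_image=False,
)
extraction = ExtractionSliceConfig(
    extract_json="extract_slice.json",  # now a required field of ExtractionConfig
    slice_direction=SliceDirection.AXIAL,
    slice_mode=SliceMode.SINGLE,
    discarded_slices=0,
    num_slices=1,
)

# Both configs are dumped into a single JSON file under tensor_extraction/...
json_path = write_preprocessing(
    preprocessing=preprocessing,
    extraction=extraction,
    caps_directory=caps_directory,
)
# ...and read back as typed objects instead of a raw dict.
preprocessing, extraction = get_preprocessing_and_mode_from_json(json_path)

dataset = return_dataset(
    input_dir=caps_directory,
    data_df=data_df,
    preprocessing=preprocessing,
    extraction=extraction,
    transforms=TransformsConfig(),
    label_presence=False,
)
```

`CapsDataset.from_extract_json` in data_2.py builds the same pair of objects from the JSON name, so the two entry points stay consistent.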