From feff9ec489e03e243a455f0207e754ae0211c918 Mon Sep 17 00:00:00 2001 From: camillebrianceau Date: Thu, 10 Oct 2024 13:29:00 +0200 Subject: [PATCH] test for some change in splitter --- clinicadl/predictor/config.py | 5 ++ clinicadl/predictor/predictor.py | 145 ++++--------------------------- clinicadl/predictor/utils.py | 1 + clinicadl/splitter/splitter.py | 15 ++-- clinicadl/tmp_config.py | 2 +- clinicadl/trainer/trainer.py | 2 +- 6 files changed, 35 insertions(+), 135 deletions(-) diff --git a/clinicadl/predictor/config.py b/clinicadl/predictor/config.py index 34fdd7a79..ead42d1c6 100644 --- a/clinicadl/predictor/config.py +++ b/clinicadl/predictor/config.py @@ -98,3 +98,8 @@ def adapt_with_maps_manager_info(self, maps_manager: MapsManager): size_reduction=maps_manager.size_reduction, size_reduction_factor=maps_manager.size_reduction_factor, ) + + if self.split.split is None and self.split.n_splits == 0: + from clinicadl.splitter.split_utils import find_splits + + self.split.split = find_splits(self.maps_manager.maps_dir) diff --git a/clinicadl/predictor/predictor.py b/clinicadl/predictor/predictor.py index a761bcc99..f3543edd3 100644 --- a/clinicadl/predictor/predictor.py +++ b/clinicadl/predictor/predictor.py @@ -51,70 +51,21 @@ def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None: from clinicadl.splitter.config import SplitterConfig from clinicadl.splitter.splitter import Splitter - tmp = _config.data.model_dump( + self.maps_manager = MapsManager(_config.maps_manager.maps_dir) + self._config.adapt_with_maps_manager_info(self.maps_manager) + + tmp = self._config.data.model_dump( exclude=set(["preprocessing_dict", "mode", "caps_dict"]) ) - tmp.update(_config.split.model_dump()) - tmp.update(_config.validation.model_dump()) + tmp.update(self._config.split.model_dump()) + tmp.update(self._config.validation.model_dump()) self.splitter = Splitter(SplitterConfig(**tmp)) - self.maps_manager = MapsManager(_config.maps_manager.maps_dir) - self._config.adapt_with_maps_manager_info(self.maps_manager) def predict( self, label_code: Union[str, dict[str, int]] = "default", ): - """Performs the prediction task on a subset of caps_directory defined in a TSV file. - Parameters - ---------- - data_group : str - name of the data group tested. - caps_directory : Path (optional, default=None) - path to the CAPS folder. For more information please refer to - [clinica documentation](https://aramislab.paris.inria.fr/clinica/docs/public/latest/CAPS/Introduction/). - Default will load the value of an existing data group - tsv_path : Path (optional, default=None) - path to a TSV file containing the list of participants and sessions to test. - Default will load the DataFrame of an existing data group - split_list : List[int] (optional, default=None) - list of splits to test. Default perform prediction on all splits available. - selection_metrics : List[str] (optional, default=None) - list of selection metrics to test. - Default performs the prediction on all selection metrics available. - multi_cohort : bool (optional, default=False) - If True considers that tsv_path is the path to a multi-cohort TSV. - diagnoses : List[str] (optional, default=()) - List of diagnoses to load if tsv_path is a split_directory. - Default uses the same as in training step. - use_labels : bool (optional, default=True) - If True, the labels must exist in test meta-data and metrics are computed. - batch_size : int (optional, default=None) - If given, sets the value of batch_size, else use the same as in training step. - n_proc : int (optional, default=None) - If given, sets the value of num_workers, else use the same as in training step. - gpu : bool (optional, default=None) - If given, a new value for the device of the model will be computed. - amp : bool (optional, default=False) - If enabled, uses Automatic Mixed Precision (requires GPU usage). - overwrite : bool (optional, default=False) - If True erase the occurrences of data_group. - label : str (optional, default=None) - Target label used for training (if network_task in [`regression`, `classification`]). - label_code : Optional[Dict[str, int]] (optional, default="default") - dictionary linking the target values to a node number. - save_tensor : bool (optional, default=False) - If true, save the tensor predicted for reconstruction task - save_nifti : bool (optional, default=False) - If true, save the nifti associated to the prediction for reconstruction task. - save_latent_tensor : bool (optional, default=False) - If true, save the tensor from the latent space for reconstruction task. - skip_leak_check : bool (optional, default=False) - If true, skip the leak check (not recommended). - Examples - -------- - >>> _input_ - _output_ - """ + """Performs the prediction task on a subset of caps_directory defined in a TSV file.""" group_df = self._config.data.create_groupe_df() self._check_data_group(group_df) @@ -122,6 +73,8 @@ def predict( self.maps_manager.network_task, self.maps_manager.loss ) + print(f"enter split iterato: {self.splitter.split_iterator()}") + for split in self.splitter.split_iterator(): logger.info(f"Prediction of split {split}") group_df, group_parameters = self.get_group_info( @@ -142,6 +95,7 @@ def predict( ) else: split_selection_metrics = self._config.validation.selection_metrics + print(f" split selection metrics : {split_selection_metrics}") for selection in split_selection_metrics: tsv_dir = ( self.maps_manager.maps_path @@ -149,9 +103,11 @@ def predict( / f"best-{selection}" / self._config.maps_manager.data_group ) + print(f"tsv_dir: {tsv_dir}") tsv_pattern = f"{self._config.maps_manager.data_group}*.tsv" for tsv_file in tsv_dir.glob(tsv_pattern): tsv_file.unlink() + print("out boucle") self._config.data.check_label(self.maps_manager.label) if self.maps_manager.multi_network: for network in range(self.maps_manager.num_networks): @@ -246,6 +202,7 @@ def _predict_single( ) if self._config.maps_manager.save_tensor: logger.debug("Saving tensors") + print("save_tensor") self._compute_output_tensors( maps_manager=self.maps_manager, dataset=data_test, @@ -424,59 +381,6 @@ def _compute_output_nifti( def interpret(self): """Performs the interpretation task on a subset of caps_directory defined in a TSV file. The mean interpretation is always saved, to save the individual interpretations set save_individual to True. - Parameters - ---------- - data_group : str - Name of the data group interpreted. - name : str - Name of the interpretation procedure. - method : str - Method used for extraction (ex: gradients, grad-cam...). - caps_directory : Path (optional, default=None) - Path to the CAPS folder. For more information please refer to - [clinica documentation](https://aramislab.paris.inria.fr/clinica/docs/public/latest/CAPS/Introduction/). - Default will load the value of an existing data group. - tsv_path : Path (optional, default=None) - Path to a TSV file containing the list of participants and sessions to test. - Default will load the DataFrame of an existing data group. - split_list : list[int] (optional, default=None) - List of splits to interpret. Default perform interpretation on all splits available. - selection_metrics : list[str] (optional, default=None) - List of selection metrics to interpret. - Default performs the interpretation on all selection metrics available. - multi_cohort : bool (optional, default=False) - If True considers that tsv_path is the path to a multi-cohort TSV. - diagnoses : list[str] (optional, default=()) - List of diagnoses to load if tsv_path is a split_directory. - Default uses the same as in training step. - target_node : int (optional, default=0) - Node from which the interpretation is computed. - save_individual : bool (optional, default=False) - If True saves the individual map of each participant / session couple. - batch_size : int (optional, default=None) - If given, sets the value of batch_size, else use the same as in training step. - n_proc : int (optional, default=None) - If given, sets the value of num_workers, else use the same as in training step. - gpu : bool (optional, default=None) - If given, a new value for the device of the model will be computed. - amp : bool (optional, default=False) - If enabled, uses Automatic Mixed Precision (requires GPU usage). - overwrite : bool (optional, default=False) - If True erase the occurrences of data_group. - overwrite_name : bool (optional, default=False) - If True erase the occurrences of name. - level : int (optional, default=None) - Layer number in the convolutional part after which the feature map is chosen. - save_nifti : bool (optional, default=False) - If True, save the interpretation map in nifti format. - Raises - ------ - NotImplementedError - If the method is not implemented - NotImplementedError - If the interpretaion of multi network is asked - MAPSError - If the interpretation has already been determined. """ assert isinstance(self._config, InterpretConfig) @@ -603,22 +507,6 @@ def _check_data_group( Parameters ---------- - data_group : str - name of the data group - caps_directory : str (optional, default=None) - input CAPS directory - df : pd.DataFrame (optional, default=None) - Table of participant_id / session_id of the data group - multi_cohort : bool (optional, default=False) - indicates if the input data comes from several CAPS - overwrite : bool (optional, default=False) - If True former definition of data group is erased - label : str (optional, default=None) - label name if applicable - split_list : list[int] (optional, default=None) - _description_ - skip_leak_check : bool (optional, default=False) - _description_ Raises ------ @@ -636,13 +524,15 @@ def _check_data_group( / self._config.maps_manager.data_group ) logger.debug(f"Group path {group_dir}") + print(f"group_dir: {group_dir}") if group_dir.is_dir(): # Data group already exists + print("is dir") if self._config.maps_manager.overwrite: if self._config.maps_manager.data_group in ["train", "validation"]: raise MAPSError("Cannot overwrite train or validation data group.") else: - # if not split_list: - # split_list = self.maps_manager.find_splits() + if not self._config.split.split: + self._config.split.split = self.maps_manager.find_splits() assert self._config.split for split in self._config.split.split: selection_metrics = find_selection_metrics( @@ -1155,6 +1045,7 @@ def _test_loader( if cluster.master: # Replace here + print("before saving") maps_manager._mode_level_to_tsv( prediction_df, metrics, diff --git a/clinicadl/predictor/utils.py b/clinicadl/predictor/utils.py index c66372764..0ee467956 100644 --- a/clinicadl/predictor/utils.py +++ b/clinicadl/predictor/utils.py @@ -36,6 +36,7 @@ def get_prediction( prediction_dir = ( maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group ) + print(prediction_dir) if not prediction_dir.is_dir(): raise MAPSError( f"No prediction corresponding to data group {data_group} was found." diff --git a/clinicadl/splitter/splitter.py b/clinicadl/splitter/splitter.py index 3bbdde461..a95881b41 100644 --- a/clinicadl/splitter/splitter.py +++ b/clinicadl/splitter/splitter.py @@ -14,7 +14,7 @@ class Splitter: def __init__( self, config: SplitterConfig, - split_list: Optional[List[int]] = None, + # split_list: Optional[List[int]] = None, ): """_summary_ @@ -29,7 +29,7 @@ def __init__( """ self.config = config - self.split_list = split_list + # self.config.split.split = split_list self.caps_dict = self.config.data.caps_dict # TODO : check if useful ? @@ -38,10 +38,10 @@ def max_length(self) -> int: return self.config.split.n_splits def __len__(self): - if not self.split_list: + if not self.config.split.split: return self.config.split.n_splits else: - return len(self.split_list) + return len(self.config.split.split) @property def allowed_splits_list(self): @@ -203,10 +203,13 @@ def _get_tsv_paths(self, cohort_path, *args) -> Tuple[Path, Path]: def split_iterator(self): """Returns an iterable to iterate on all splits wanted.""" - if not self.split_list: + print(self.config.split.split) + print(self.config.split.n_splits) + print(self.config.split.split) + if not self.config.split.split: return range(self.config.split.n_splits) else: - return self.split_list + return self.config.split.split def _check_item(self, item): if item not in self.allowed_splits_list: diff --git a/clinicadl/tmp_config.py b/clinicadl/tmp_config.py index 84ba18de7..620db133e 100644 --- a/clinicadl/tmp_config.py +++ b/clinicadl/tmp_config.py @@ -299,7 +299,7 @@ def adapt_cross_val_with_maps_manager_info( ): # maps_manager is of type MapsManager but need to be in a MapsConfig type in the future # TEMPORARY if not self.split: - self.split = find_splits(maps_manager.maps_path, maps_manager.split_name) + self.split = find_splits(maps_manager.maps_path) logger.debug(f"List of splits {self.split}") def create_groupe_df(self): diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index c1b9ed396..a949e1301 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -163,7 +163,7 @@ def resume(self, splits: List[int]) -> None: # TODO : check these two lines. Why do we need a split_manager? splitter_config = SplitterConfig(**self.config.get_dict()) - split_manager = Splitter(splitter_config, split_list=splits) + split_manager = Splitter(splitter_config) split_iterator = split_manager.split_iterator() ###