diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predict/predict_manager.py index c197a96de..4d6adf2fc 100644 --- a/clinicadl/predict/predict_manager.py +++ b/clinicadl/predict/predict_manager.py @@ -297,10 +297,10 @@ def _predict_multi( criterion=criterion, data_group=self._config.data_group, split=split, - selection_metrics=split_selection_metrics, - use_labels=self._config.use_labels, - gpu=self._config.gpu, - amp=self._config.amp, + # selection_metrics=split_selection_metrics, + # use_labels=self._config.use_labels, + # gpu=self._config.gpu, + # amp=self._config.amp, network=network, ) if self._config.save_tensor: @@ -310,8 +310,6 @@ def _predict_multi( data_test, self._config.data_group, split, - self._config.selection_metrics, - gpu=self._config.gpu, network=network, ) if self._config.save_nifti: @@ -427,10 +425,10 @@ def _predict_single( criterion, self._config.data_group, split, - split_selection_metrics, - use_labels=self._config.use_labels, - gpu=self._config.gpu, - amp=self._config.amp, + # split_selection_metrics, + # use_labels=self._config.use_labels, + # gpu=self._config.gpu, + # amp=self._config.amp, ) if self._config.save_tensor: logger.debug("Saving tensors") diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 37ded6ef9..0bcb2e3e4 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -228,6 +228,7 @@ def train( amp=self.maps_manager.std_amp, use_labels=use_labels, report_ci=report_ci, + selection_metrics=self.config.validation.selection_metrics, ) validator = Validator(validator_config=validator_config) @@ -526,13 +527,11 @@ def _train_single( self.maps_manager, "train", split, - self.config.validation.selection_metrics, ) validator._ensemble_prediction( self.maps_manager, "validation", split, - self.config.validation.selection_metrics, ) self._erase_tmp(split) @@ -740,13 +739,11 @@ def _train_ssda( self.maps_manager, "train", split, - self.config.validation.selection_metrics, ) validator._ensemble_prediction( self.maps_manager, "validation", split, - self.config.validation.selection_metrics, ) self._erase_tmp(split) @@ -1021,8 +1018,6 @@ def _train( criterion, "train", split, - self.config.validation.selection_metrics, - amp=self.maps_manager.std_amp, network=network, ) validator._test_loader( @@ -1031,8 +1026,6 @@ def _train( criterion, "validation", split, - self.config.validation.selection_metrics, - amp=self.maps_manager.std_amp, network=network, ) @@ -1042,7 +1035,6 @@ def _train( train_loader.dataset, "train", split, - self.config.validation.selection_metrics, nb_images=1, network=network, ) @@ -1051,7 +1043,6 @@ def _train( valid_loader.dataset, "validation", split, - self.config.validation.selection_metrics, nb_images=1, network=network, ) diff --git a/clinicadl/validator/config.py b/clinicadl/validator/config.py index da1d909b2..2e6cfe92a 100644 --- a/clinicadl/validator/config.py +++ b/clinicadl/validator/config.py @@ -17,21 +17,21 @@ class ValidatorConfig(BaseModel): # maps_path: Path mode: str + metrics_module: Optional = None + report_ci: bool = False + selection_metrics: list + n_classes: int = 1 network_task: str - # model: Network - # dataloader: DataLoader - # criterion: _Loss + num_networks: int = 1 use_labels: bool = True + + gpu: Optional[bool] = None amp: bool = False fsdp: bool = False - report_ci = False - gpu: Optional[bool] = None - selection_metrics: list split_name: Optional[str] = None - num_networks: Optional[int] = None nb_unfrozen_layers: Optional[int] = None std_amp: Optional[bool] = None diff --git a/clinicadl/validator/validator.py b/clinicadl/validator/validator.py index 04d0a4103..94a73c8da 100644 --- a/clinicadl/validator/validator.py +++ b/clinicadl/validator/validator.py @@ -50,12 +50,6 @@ def test( dataloader: DataLoader, criterion: _Loss, metrics_module: MetricModule, - # mode: str, - # n_classes: int, - # network_task, - # use_labels: bool = True, - # amp: bool = False, - # report_ci=False, ) -> Tuple[pd.DataFrame, Dict[str, float]]: """ Computes the predictions and evaluation metrics. @@ -248,10 +242,8 @@ def _compute_output_tensors( dataset, data_group, split, - selection_metrics, nb_images=None, - gpu=None, - network=None, + network: Optional[int] = None, ): """ Compute the output tensors and saves them in the MAPS. @@ -265,20 +257,20 @@ def _compute_output_tensors( gpu (bool): If given, a new value for the device of the model will be computed. network (int): Index of the network tested (only used in multi-network setting). """ - for selection_metric in selection_metrics: + for selection_metric in self.config.selection_metrics: # load the best trained model during the training model, _ = maps_manager._init_model( transfer_path=maps_manager.maps_path, split=split, transfer_selection=selection_metric, - gpu=gpu, + gpu=self.config.gpu, network=network, nb_unfrozen_layer=maps_manager.nb_unfrozen_layer, ) model = DDP( model, - fsdp=maps_manager.fully_sharded_data_parallel, - amp=maps_manager.amp, + fsdp=self.config.fsdp, + amp=self.config.amp, ) model.eval() @@ -305,14 +297,14 @@ def _compute_output_tensors( data = dataset[i] image = data["image"] x = image.unsqueeze(0).to(model.device) - with autocast("cuda", enabled=maps_manager.std_amp): + with autocast("cuda", enabled=self.config.std_amp): output = model(x) output = output.squeeze(0).cpu().float() participant_id = data["participant_id"] session_id = data["session_id"] - mode_id = data[f"{maps_manager.mode}_id"] - input_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_input.pt" - output_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_output.pt" + mode_id = data[f"{self.config.mode}_id"] + input_filename = f"{participant_id}_{session_id}_{self.config.mode}-{mode_id}_input.pt" + output_filename = f"{participant_id}_{session_id}_{self.config.mode}-{mode_id}_output.pt" torch.save(image, tensor_path / input_filename) torch.save(output, tensor_path / output_filename) logger.debug(f"File saved at {[input_filename, output_filename]}") @@ -320,15 +312,13 @@ def _compute_output_tensors( def _ensemble_prediction( self, maps_manager: MapsManager, - data_group, - split, - selection_metrics, - use_labels=True, + data_group: str, + split: int, skip_leak_check=False, ): """Computes the results on the image-level.""" - if not selection_metrics: + if not self.config.selection_metrics: selection_metrics = find_selection_metrics( maps_manager.maps_path, maps_manager.split_name, split ) @@ -336,19 +326,19 @@ def _ensemble_prediction( for selection_metric in selection_metrics: ##################### # Soft voting - if maps_manager.num_networks > 1 and not skip_leak_check: + if self.config.num_networks > 1 and not skip_leak_check: maps_manager._ensemble_to_tsv( split, selection=selection_metric, data_group=data_group, - use_labels=use_labels, + use_labels=self.config.use_labels, ) - elif maps_manager.mode != "image" and not skip_leak_check: + elif self.config.mode != "image" and not skip_leak_check: maps_manager._mode_to_image_tsv( split, selection=selection_metric, data_group=data_group, - use_labels=use_labels, + use_labels=self.config.use_labels, ) def test_da(