From c0b424c803968639ee455b0076cbce5ddc8b2841 Mon Sep 17 00:00:00 2001 From: camillebrianceau Date: Wed, 9 Oct 2024 10:31:30 +0200 Subject: [PATCH] tests ok ? --- clinicadl/API_test.py | 14 +- .../commandline/modules_options/validation.py | 2 +- .../commandline/pipelines/interpret/cli.py | 4 +- .../pipelines/interpret/options.py | 14 +- .../commandline/pipelines/predict/cli.py | 6 +- .../commandline/pipelines/predict/options.py | 4 +- clinicadl/interpret/config.py | 20 +- clinicadl/maps_manager/maps_manager.py | 6 +- .../{validator => predictor}/__init__.py | 0 clinicadl/predictor/config.py | 100 ++++++++ .../validator.py => predictor/predictor.py} | 218 ++++++++---------- clinicadl/{validator => predictor}/utils.py | 0 .../{validator => predictor}/validation.py | 0 clinicadl/splitter/config.py | 2 +- clinicadl/trainer/config/classification.py | 2 +- clinicadl/trainer/config/reconstruction.py | 2 +- clinicadl/trainer/config/regression.py | 2 +- clinicadl/trainer/config/train.py | 2 +- clinicadl/trainer/trainer.py | 6 +- clinicadl/validator/config.py | 40 ---- tests/test_interpret.py | 4 +- tests/test_predict.py | 8 +- .../train/trainer/test_training_config.py | 2 +- 23 files changed, 248 insertions(+), 210 deletions(-) rename clinicadl/{validator => predictor}/__init__.py (100%) create mode 100644 clinicadl/predictor/config.py rename clinicadl/{validator/validator.py => predictor/predictor.py} (89%) rename clinicadl/{validator => predictor}/utils.py (100%) rename clinicadl/{validator => predictor}/validation.py (100%) delete mode 100644 clinicadl/validator/config.py diff --git a/clinicadl/API_test.py b/clinicadl/API_test.py index 0581b879a..d144c1597 100644 --- a/clinicadl/API_test.py +++ b/clinicadl/API_test.py @@ -2,7 +2,11 @@ from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.caps_dataset.data import return_dataset +from clinicadl.predictor.config import PredictConfig +from clinicadl.predictor.predictor import Predictor from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData +from clinicadl.splitter.config import SplitterConfig +from clinicadl.splitter.splitter import Splitter from clinicadl.trainer.config.classification import ClassificationConfig from clinicadl.trainer.trainer import Trainer from clinicadl.utils.enum import ExtractionMethod, Preprocessing, Task @@ -27,11 +31,11 @@ multi_cohort, ) -split_config = SplitConfig() +split_config = SplitterConfig() splitter = Splitter(split_config) -validator_config = ValidatorConfig() -validator = Validator(validator_config) +validator_config = PredictConfig() +validator = Predictor(validator_config) train_config = ClassificationConfig() trainer = Trainer(train_config, validator) @@ -78,6 +82,6 @@ test_loader = trainer.get_dataloader(dataset, split, network, "test", config) validator.predict(test_loader) -interpret_config = InterpretConfig(**kwargs) -predict_manager = PredictManager(interpret_config) +interpret_config = PredictConfig(**kwargs) +predict_manager = Predictor(interpret_config) predict_manager.interpret() diff --git a/clinicadl/commandline/modules_options/validation.py b/clinicadl/commandline/modules_options/validation.py index 9d26fc09f..089357866 100644 --- a/clinicadl/commandline/modules_options/validation.py +++ b/clinicadl/commandline/modules_options/validation.py @@ -2,7 +2,7 @@ from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.validator.validation import ValidationConfig +from clinicadl.predictor.validation import ValidationConfig # Validation valid_longitudinal = click.option( diff --git a/clinicadl/commandline/pipelines/interpret/cli.py b/clinicadl/commandline/pipelines/interpret/cli.py index eaa7d5846..160839442 100644 --- a/clinicadl/commandline/pipelines/interpret/cli.py +++ b/clinicadl/commandline/pipelines/interpret/cli.py @@ -10,7 +10,7 @@ ) from clinicadl.commandline.pipelines.interpret import options from clinicadl.interpret.config import InterpretConfig -from clinicadl.validator.validator import Validator +from clinicadl.predictor.predictor import Predictor @click.command("interpret", no_args_is_help=True) @@ -42,7 +42,7 @@ def cli(**kwargs): """ interpret_config = InterpretConfig(**kwargs) - predict_manager = Validator(interpret_config) + predict_manager = Predictor(interpret_config) predict_manager.interpret() diff --git a/clinicadl/commandline/pipelines/interpret/options.py b/clinicadl/commandline/pipelines/interpret/options.py index 5313b4a90..43cada4c4 100644 --- a/clinicadl/commandline/pipelines/interpret/options.py +++ b/clinicadl/commandline/pipelines/interpret/options.py @@ -2,28 +2,28 @@ from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.interpret.config import InterpretConfig +from clinicadl.interpret.config import InterpretBaseConfig # interpret specific name = click.argument( "name", - type=get_type("name", InterpretConfig), + type=get_type("name", InterpretBaseConfig), ) method = click.argument( "method", - type=get_type("method", InterpretConfig), # ["gradients", "grad-cam"] + type=get_type("method", InterpretBaseConfig), # ["gradients", "grad-cam"] ) level = click.option( "--level_grad_cam", - type=get_type("level", InterpretConfig), - default=get_default("level", InterpretConfig), + type=get_type("level", InterpretBaseConfig), + default=get_default("level", InterpretBaseConfig), help="level of the feature map (after the layer corresponding to the number) chosen for grad-cam.", show_default=True, ) target_node = click.option( "--target_node", - type=get_type("target_node", InterpretConfig), - default=get_default("target_node", InterpretConfig), + type=get_type("target_node", InterpretBaseConfig), + default=get_default("target_node", InterpretBaseConfig), help="Which target node the gradients explain. Default takes the first output node.", show_default=True, ) diff --git a/clinicadl/commandline/pipelines/predict/cli.py b/clinicadl/commandline/pipelines/predict/cli.py index 101222209..184f46ad7 100644 --- a/clinicadl/commandline/pipelines/predict/cli.py +++ b/clinicadl/commandline/pipelines/predict/cli.py @@ -10,8 +10,8 @@ validation, ) from clinicadl.commandline.pipelines.predict import options -from clinicadl.validator.config import PredictConfig -from clinicadl.validator.validator import Validator +from clinicadl.predictor.config import PredictConfig +from clinicadl.predictor.predictor import Predictor @click.command(name="predict", no_args_is_help=True) @@ -61,7 +61,7 @@ def cli(input_maps_directory, data_group, **kwargs): """ predict_config = PredictConfig(**kwargs) - predict_manager = Validator(predict_config) + predict_manager = Predictor(predict_config) predict_manager.predict() diff --git a/clinicadl/commandline/pipelines/predict/options.py b/clinicadl/commandline/pipelines/predict/options.py index 8ae3cf073..cbb8980ca 100644 --- a/clinicadl/commandline/pipelines/predict/options.py +++ b/clinicadl/commandline/pipelines/predict/options.py @@ -1,13 +1,11 @@ import click from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.validator.config import PredictConfig +from clinicadl.predictor.config import PredictConfig # predict specific use_labels = click.option( "--use_labels/--no_labels", - show_default=True, - default=get_default("use_labels", PredictConfig), help="Set this option to --no_labels if your dataset does not contain ground truth labels.", ) save_tensor = click.option( diff --git a/clinicadl/interpret/config.py b/clinicadl/interpret/config.py index 7f171df14..cbcf8f4ab 100644 --- a/clinicadl/interpret/config.py +++ b/clinicadl/interpret/config.py @@ -8,10 +8,10 @@ from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp from clinicadl.maps_manager.config import MapsManagerConfig +from clinicadl.predictor.validation import ValidationConfig from clinicadl.splitter.config import SplitConfig from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.enum import InterpretationMethod -from clinicadl.validator.validation import ValidationConfig logger = getLogger("clinicadl.interpret_config") @@ -44,13 +44,13 @@ def get_method(self) -> Gradients: raise ValueError(f"The method {self.method.value} is not implemented") -class InterpretConfig( - MapsManagerConfig, - InterpretBaseConfig, - DataConfig, - ValidationConfig, - ComputationalConfig, - DataLoaderConfig, - SplitConfig, -): +class InterpretConfig(BaseModel): """Config class to perform Transfer Learning.""" + + maps_manager: MapsManagerConfig + data: DataConfig + validation: ValidationConfig + computational: ComputationalConfig + dataloader: DataLoaderConfig + split: SplitConfig + interpret: InterpretBaseConfig diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py index ae794bea7..2d2db0b17 100644 --- a/clinicadl/maps_manager/maps_manager.py +++ b/clinicadl/maps_manager/maps_manager.py @@ -17,6 +17,7 @@ from clinicadl.metrics.utils import ( check_selection_metric, ) +from clinicadl.predictor.utils import get_prediction from clinicadl.splitter.config import SplitterConfig from clinicadl.splitter.splitter import Splitter from clinicadl.trainer.tasks_utils import ( @@ -37,7 +38,6 @@ add_default_values, ) from clinicadl.utils.iotools.utils import path_encoder -from clinicadl.validator.utils import get_prediction logger = getLogger("clinicadl.maps_manager") level_list: List[str] = ["warning", "info", "debug"] @@ -139,10 +139,6 @@ def __getattr__(self, name): else: raise AttributeError(f"'MapsManager' object has no attribute '{name}'") - ################################### - # High-level functions templates # - ################################### - ############################### # Checks # ############################### diff --git a/clinicadl/validator/__init__.py b/clinicadl/predictor/__init__.py similarity index 100% rename from clinicadl/validator/__init__.py rename to clinicadl/predictor/__init__.py diff --git a/clinicadl/predictor/config.py b/clinicadl/predictor/config.py new file mode 100644 index 000000000..34fdd7a79 --- /dev/null +++ b/clinicadl/predictor/config.py @@ -0,0 +1,100 @@ +from logging import getLogger +from typing import Any, Dict + +from pydantic import BaseModel, ConfigDict, computed_field + +from clinicadl.caps_dataset.data_config import DataConfig as DataBaseConfig +from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig +from clinicadl.maps_manager.config import ( + MapsManagerConfig as MapsManagerBaseConfig, +) +from clinicadl.maps_manager.maps_manager import MapsManager +from clinicadl.predictor.validation import ValidationConfig +from clinicadl.splitter.config import SplitConfig +from clinicadl.transforms.config import TransformsConfig +from clinicadl.utils.computational.computational import ComputationalConfig +from clinicadl.utils.enum import Task +from clinicadl.utils.exceptions import ClinicaDLArgumentError # type: ignore + +logger = getLogger("clinicadl.predict_config") + + +class MapsManagerConfig(MapsManagerBaseConfig): + save_tensor: bool = False + save_latent_tensor: bool = False + + def check_output_saving_tensor(self, network_task: str) -> None: + # Check if task is reconstruction for "save_tensor" and "save_nifti" + if self.save_tensor and network_task != "reconstruction": + raise ClinicaDLArgumentError( + "Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option." + ) + + +class DataConfig(DataBaseConfig): + use_labels: bool = True + + +class PredictConfig(BaseModel): + """Config class to perform Transfer Learning.""" + + maps_manager: MapsManagerConfig + data: DataConfig + validation: ValidationConfig + computational: ComputationalConfig + dataloader: DataLoaderConfig + split: SplitConfig + transforms: TransformsConfig + + model_config = ConfigDict(validate_assignment=True) + + def __init__(self, **kwargs): + super().__init__( + maps_manager=kwargs, + computational=kwargs, + dataloader=kwargs, + data=kwargs, + split=kwargs, + validation=kwargs, + transforms=kwargs, + ) + + def _update(self, config_dict: Dict[str, Any]) -> None: + """Updates the configs with a dict given by the user.""" + self.data.__dict__.update(config_dict) + self.split.__dict__.update(config_dict) + self.validation.__dict__.update(config_dict) + self.maps_manager.__dict__.update(config_dict) + self.split.__dict__.update(config_dict) + self.computational.__dict__.update(config_dict) + self.dataloader.__dict__.update(config_dict) + self.transforms.__dict__.update(config_dict) + + def adapt_with_maps_manager_info(self, maps_manager: MapsManager): + self.maps_manager.check_output_saving_nifti(maps_manager.network_task) + self.data.diagnoses = ( + maps_manager.diagnoses + if self.data.diagnoses is None or len(self.data.diagnoses) == 0 + else self.data.diagnoses + ) + + self.dataloader.batch_size = ( + maps_manager.batch_size + if not self.dataloader.batch_size + else self.dataloader.batch_size + ) + self.dataloader.n_proc = ( + maps_manager.n_proc + if not self.dataloader.n_proc + else self.dataloader.n_proc + ) + + self.split.adapt_cross_val_with_maps_manager_info(maps_manager) + self.maps_manager.check_output_saving_tensor(maps_manager.network_task) + + self.transforms = TransformsConfig( + normalize=maps_manager.normalize, + data_augmentation=maps_manager.data_augmentation, + size_reduction=maps_manager.size_reduction, + size_reduction_factor=maps_manager.size_reduction_factor, + ) diff --git a/clinicadl/validator/validator.py b/clinicadl/predictor/predictor.py similarity index 89% rename from clinicadl/validator/validator.py rename to clinicadl/predictor/predictor.py index c43fafc29..025c91abc 100644 --- a/clinicadl/validator/validator.py +++ b/clinicadl/predictor/predictor.py @@ -23,6 +23,7 @@ find_selection_metrics, ) from clinicadl.network.network import Network +from clinicadl.predictor.config import PredictConfig from clinicadl.trainer.tasks_utils import ( columns, compute_metrics, @@ -38,17 +39,18 @@ ClinicaDLDataLeakageError, MAPSError, ) -from clinicadl.validator.config import PredictConfig logger = getLogger("clinicadl.predict_manager") level_list: List[str] = ["warning", "info", "debug"] -class Validator: +class Predictor: def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None: - self.maps_manager = MapsManager(_config.maps_dir) self._config = _config + self.maps_manager = MapsManager(_config.maps_manager.maps_dir) + self._config.adapt_with_maps_manager_info(self.maps_manager) + def predict( self, label_code: Union[str, dict[str, int]] = "default", @@ -105,78 +107,49 @@ def predict( _output_ """ - assert isinstance(self._config, PredictConfig) - - self._config.check_output_saving_nifti(self.maps_manager.network_task) - self._config.diagnoses = ( - self.maps_manager.diagnoses - if self._config.diagnoses is None or len(self._config.diagnoses) == 0 - else self._config.diagnoses - ) - - self._config.batch_size = ( - self.maps_manager.batch_size - if not self._config.batch_size - else self._config.batch_size - ) - self._config.n_proc = ( - self.maps_manager.n_proc if not self._config.n_proc else self._config.n_proc - ) - - self._config.adapt_cross_val_with_maps_manager_info(self.maps_manager) - self._config.check_output_saving_tensor(self.maps_manager.network_task) - - transforms = TransformsConfig( - normalize=self.maps_manager.normalize, - data_augmentation=self.maps_manager.data_augmentation, - size_reduction=self.maps_manager.size_reduction, - size_reduction_factor=self.maps_manager.size_reduction_factor, - ) - group_df = self._config.create_groupe_df() + group_df = self._config.data.create_groupe_df() self._check_data_group(group_df) criterion = get_criterion( self.maps_manager.network_task, self.maps_manager.loss ) - self._check_data_group(df=group_df) - - assert self._config.split # don't know if needed ? try to raise an exception ? - # assert self._config.label - for split in self._config.split: + for split in self._config.split.split: logger.info(f"Prediction of split {split}") group_df, group_parameters = self.get_group_info( - self._config.data_group, split + self._config.maps_manager.data_group, split ) # Find label code if not given - if self._config.is_given_label_code(self.maps_manager.label, label_code): + if self._config.data.is_given_label_code( + self.maps_manager.label, label_code + ): generate_label_code( - self.maps_manager.network_task, group_df, self._config.label + self.maps_manager.network_task, group_df, self._config.data.label ) # Erase previous TSV files on master process - if not self._config.selection_metrics: + if not self._config.validation.selection_metrics: split_selection_metrics = find_selection_metrics( self.maps_manager.maps_path, split, ) else: - split_selection_metrics = self._config.selection_metrics + split_selection_metrics = self._config.validation.selection_metrics for selection in split_selection_metrics: tsv_dir = ( self.maps_manager.maps_path / f"split-{split}" / f"best-{selection}" - / self._config.data_group + / self._config.maps_manager.data_group ) - tsv_pattern = f"{self._config.data_group}*.tsv" + tsv_pattern = f"{self._config.maps_manager.data_group}*.tsv" for tsv_file in tsv_dir.glob(tsv_pattern): tsv_file.unlink() - self._config.check_label(self.maps_manager.label) + self._config.data.check_label(self.maps_manager.label) if self.maps_manager.multi_network: for network in range(self.maps_manager.num_networks): self._predict_single( group_parameters, group_df, - transforms, + self._config.transforms, label_code, criterion, split, @@ -187,7 +160,7 @@ def predict( self._predict_single( group_parameters, group_df, - transforms, + self._config.transforms, label_code, criterion, split, @@ -196,11 +169,11 @@ def predict( if cluster.master: self._ensemble_prediction( self.maps_manager, - self._config.data_group, + self._config.maps_manager.data_group, split, - self._config.selection_metrics, - self._config.use_labels, - self._config.skip_leak_check, + self._config.validation.selection_metrics, + self._config.data.use_labels, + self._config.validation.skip_leak_check, ) def _predict_single( @@ -217,16 +190,16 @@ def _predict_single( """_summary_""" assert isinstance(self._config, PredictConfig) - # assert self._config.label + # assert self._config.data.label data_test = return_dataset( group_parameters["caps_directory"], group_df, self.maps_manager.preprocessing_dict, - transforms_config=transforms, + transforms_config=self._config.transforms, multi_cohort=group_parameters["multi_cohort"], - label_presence=self._config.use_labels, - label=self._config.label, + label_presence=self._config.data.use_labels, + label=self._config.data.label, label_code=( self.maps_manager.label_code if label_code == "default" else label_code ), @@ -235,8 +208,8 @@ def _predict_single( test_loader = DataLoader( data_test, batch_size=( - self._config.batch_size - if self._config.batch_size is not None + self._config.dataloader.batch_size + if self._config.dataloader.batch_size is not None else self.maps_manager.batch_size ), shuffle=False, @@ -246,40 +219,40 @@ def _predict_single( rank=cluster.rank, shuffle=False, ), - num_workers=self._config.n_proc - if self._config.n_proc is not None + num_workers=self._config.dataloader.n_proc + if self._config.dataloader.n_proc is not None else self.maps_manager.n_proc, ) self._test_loader( maps_manager=self.maps_manager, dataloader=test_loader, criterion=criterion, - data_group=self._config.data_group, + data_group=self._config.maps_manager.data_group, split=split, selection_metrics=split_selection_metrics, - use_labels=self._config.use_labels, - gpu=self._config.gpu, - amp=self._config.amp, + use_labels=self._config.data.use_labels, + gpu=self._config.computational.gpu, + amp=self._config.computational.amp, network=network, ) - if self._config.save_tensor: + if self._config.maps_manager.save_tensor: logger.debug("Saving tensors") self._compute_output_tensors( maps_manager=self.maps_manager, dataset=data_test, - data_group=self._config.data_group, + data_group=self._config.maps_manager.data_group, split=split, - selection_metrics=self._config.selection_metrics, - gpu=self._config.gpu, + selection_metrics=self._config.validation.selection_metrics, + gpu=self._config.computational.gpu, network=network, ) - if self._config.save_nifti: + if self._config.maps_manager.save_nifti: self._compute_output_nifti( dataset=data_test, split=split, network=network, ) - if self._config.save_latent_tensor: + if self._config.maps_manager.save_latent_tensor: self._compute_latent_tensors( dataset=data_test, split=split, @@ -312,13 +285,13 @@ def _compute_latent_tensors( network : _type_ (optional, default=None) Index of the network tested (only used in multi-network setting). """ - for selection_metric in self._config.selection_metrics: + for selection_metric in self._config.validation.selection_metrics: # load the best trained model during the training model, _ = self.maps_manager._init_model( transfer_path=self.maps_manager.maps_path, split=split, transfer_selection=selection_metric, - gpu=self._config.gpu, + gpu=self._config.computational.gpu, network=network, nb_unfrozen_layer=self.maps_manager.nb_unfrozen_layer, ) @@ -332,7 +305,7 @@ def _compute_latent_tensors( self.maps_manager.maps_path / f"split-{split}" / f"best-{selection_metric}" - / self._config.data_group + / self._config.maps_manager.data_group / "latent_tensors" ) if cluster.master: @@ -389,13 +362,13 @@ def _compute_output_nifti( import nibabel as nib from numpy import eye - for selection_metric in self._config.selection_metrics: + for selection_metric in self._config.validation.selection_metrics: # load the best trained model during the training model, _ = self.maps_manager._init_model( transfer_path=self.maps_manager.maps_path, split=split, transfer_selection=selection_metric, - gpu=self._config.gpu, + gpu=self._config.computational.gpu, network=network, nb_unfrozen_layer=self.maps_manager.nb_unfrozen_layer, ) @@ -409,7 +382,7 @@ def _compute_output_nifti( self.maps_manager.maps_path / f"split-{split}" / f"best-{selection_metric}" - / self._config.data_group + / self._config.maps_manager.data_group / "nifti_images" ) if cluster.master: @@ -498,21 +471,24 @@ def interpret(self): """ assert isinstance(self._config, InterpretConfig) - self._config.diagnoses = ( + self._config.data.diagnoses = ( self.maps_manager.diagnoses - if self._config.diagnoses is None or len(self._config.diagnoses) == 0 - else self._config.diagnoses + if self._config.data.diagnoses is None + or len(self._config.data.diagnoses) == 0 + else self._config.data.diagnoses ) - self._config.batch_size = ( + self._config.dataloader.batch_size = ( self.maps_manager.batch_size - if not self._config.batch_size - else self._config.batch_size + if not self._config.dataloader.batch_size + else self._config.dataloader.batch_size ) - self._config.n_proc = ( - self.maps_manager.n_proc if not self._config.n_proc else self._config.n_proc + self._config.dataloader.n_proc = ( + self.maps_manager.n_proc + if not self._config.dataloader.n_proc + else self._config.dataloader.n_proc ) - self._config.adapt_cross_val_with_maps_manager_info(self.maps_manager) + self._config.split.adapt_cross_val_with_maps_manager_info(self.maps_manager) if self.maps_manager.multi_network: raise NotImplementedError( @@ -524,14 +500,14 @@ def interpret(self): size_reduction=self.maps_manager.size_reduction, size_reduction_factor=self.maps_manager.size_reduction_factor, ) - group_df = self._config.create_groupe_df() + group_df = self._config.data.create_groupe_df() self._check_data_group(group_df) assert self._config.split - for split in self._config.split: + for split in self._config.split.split: logger.info(f"Interpretation of split {split}") df_group, parameters_group = self.get_group_info( - self._config.data_group, split + self._config.maps_manager.data_group, split ) data_test = return_dataset( parameters_group["caps_directory"], @@ -545,22 +521,22 @@ def interpret(self): ) test_loader = DataLoader( data_test, - batch_size=self._config.batch_size, + batch_size=self._config.dataloader.batch_size, shuffle=False, - num_workers=self._config.n_proc, + num_workers=self._config.dataloader.n_proc, ) - if not self._config.selection_metrics: - self._config.selection_metrics = find_selection_metrics( + if not self._config.validation.selection_metrics: + self._config.validation.selection_metrics = find_selection_metrics( self.maps_manager.maps_path, split, ) - for selection_metric in self._config.selection_metrics: + for selection_metric in self._config.validation.selection_metrics: logger.info(f"Interpretation of metric {selection_metric}") results_path = ( self.maps_manager.maps_path / f"split-{split}" / f"best-{selection_metric}" - / self._config.data_group + / self._config.maps_manager.data_group / f"interpret-{self._config.name}" ) if (results_path).is_dir(): @@ -576,28 +552,28 @@ def interpret(self): transfer_path=self.maps_manager.maps_path, split=split, transfer_selection=selection_metric, - gpu=self._config.gpu, + gpu=self._config.computational.gpu, ) - interpreter = self._config.get_method()(model) + interpreter = self._config.interpret.get_method()(model) cum_maps = [0] * data_test.elem_per_image for data in test_loader: images = data["image"].to(model.device) map_pt = interpreter.generate_gradients( images, - self._config.target_node, - level=self._config.level, - amp=self._config.amp, + self._config.interpret.target_node, + level=self._config.interpret.level, + amp=self._config.computational.amp, ) for i in range(len(data["participant_id"])): mode_id = data[f"{self.maps_manager.mode}_id"][i] cum_maps[mode_id] += map_pt[i] - if self._config.save_individual: + if self._config.interpret.save_individual: single_path = ( results_path / f"{data['participant_id'][i]}_{data['session_id'][i]}_{self.maps_manager.mode}-{data[f'{self.maps_manager.mode}_id'][i]}_map.pt" ) torch.save(map_pt[i], single_path) - if self._config.save_nifti: + if self._config.maps_manager.save_nifti: import nibabel as nib from numpy import eye @@ -615,7 +591,7 @@ def interpret(self): mode_map, results_path / f"mean_{self.maps_manager.mode}-{i}_map.pt", ) - if self._config.save_nifti: + if self._config.maps_manager.save_nifti: import nibabel as nib from numpy import eye @@ -662,17 +638,21 @@ def _check_data_group( when caps_directory or df are not given and data group does not exist """ - group_dir = self.maps_manager.maps_path / "groups" / self._config.data_group + group_dir = ( + self.maps_manager.maps_path + / "groups" + / self._config.maps_manager.data_group + ) logger.debug(f"Group path {group_dir}") if group_dir.is_dir(): # Data group already exists - if self._config.overwrite: - if self._config.data_group in ["train", "validation"]: + if self._config.maps_manager.overwrite: + if self._config.maps_manager.data_group in ["train", "validation"]: raise MAPSError("Cannot overwrite train or validation data group.") else: # if not split_list: # split_list = self.maps_manager.find_splits() assert self._config.split - for split in self._config.split: + for split in self._config.split.split: selection_metrics = find_selection_metrics( self.maps_manager.maps_path, split, @@ -682,40 +662,40 @@ def _check_data_group( self.maps_manager.maps_path / f"split-{split}" / f"best-{selection}" - / self._config.data_group + / self._config.maps_manager.data_group ) if results_path.is_dir(): shutil.rmtree(results_path) elif df is not None or ( - self._config.caps_directory is not None - and self._config.caps_directory != Path("") + self._config.data.caps_directory is not None + and self._config.data.caps_directory != Path("") ): raise ClinicaDLArgumentError( - f"Data group {self._config.data_group} is already defined. " + f"Data group {self._config.maps_manager.data_group} is already defined. " f"Please do not give any caps_directory, tsv_path or multi_cohort to use it. " - f"To erase {self._config.data_group} please set overwrite to True." + f"To erase {self._config.maps_manager.data_group} please set overwrite to True." ) elif not group_dir.is_dir() and ( - self._config.caps_directory is None or df is None + self._config.data.caps_directory is None or df is None ): # Data group does not exist yet / was overwritten + missing data raise ClinicaDLArgumentError( - f"The data group {self._config.data_group} does not already exist. " + f"The data group {self._config.maps_manager.data_group} does not already exist. " f"Please specify a caps_directory and a tsv_path to create this data group." ) elif ( not group_dir.is_dir() ): # Data group does not exist yet / was overwritten + all data is provided - if self._config.skip_leak_check: + if self._config.validation.skip_leak_check: logger.info("Skipping data leakage check") else: - self._check_leakage(self._config.data_group, df) + self._check_leakage(self._config.maps_manager.data_group, df) self._write_data_group( - self._config.data_group, + self._config.maps_manager.data_group, df, - self._config.caps_directory, - self._config.multi_cohort, - label=self._config.label, + self._config.data.caps_directory, + self._config.data.multi_cohort, + label=self._config.data.label, ) def get_group_info( @@ -831,8 +811,8 @@ def _write_data_group( group_path.mkdir(parents=True) columns = ["participant_id", "session_id", "cohort"] - if self._config.label in df.columns.values: - columns += [self._config.label] + if self._config.data.label in df.columns.values: + columns += [self._config.data.label] if label is not None and label in df.columns.values: columns += [label] diff --git a/clinicadl/validator/utils.py b/clinicadl/predictor/utils.py similarity index 100% rename from clinicadl/validator/utils.py rename to clinicadl/predictor/utils.py diff --git a/clinicadl/validator/validation.py b/clinicadl/predictor/validation.py similarity index 100% rename from clinicadl/validator/validation.py rename to clinicadl/predictor/validation.py diff --git a/clinicadl/splitter/config.py b/clinicadl/splitter/config.py index fa8834d06..59fdbaad8 100644 --- a/clinicadl/splitter/config.py +++ b/clinicadl/splitter/config.py @@ -7,8 +7,8 @@ from pydantic.types import NonNegativeInt from clinicadl.caps_dataset.data_config import DataConfig +from clinicadl.predictor.validation import ValidationConfig from clinicadl.splitter.split_utils import find_splits -from clinicadl.validator.validation import ValidationConfig logger = getLogger("clinicadl.split_config") diff --git a/clinicadl/trainer/config/classification.py b/clinicadl/trainer/config/classification.py index 6d466e4fc..f09021559 100644 --- a/clinicadl/trainer/config/classification.py +++ b/clinicadl/trainer/config/classification.py @@ -5,9 +5,9 @@ from clinicadl.caps_dataset.data_config import DataConfig as BaseDataConfig from clinicadl.network.config import NetworkConfig as BaseNetworkConfig +from clinicadl.predictor.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import ClassificationLoss, ClassificationMetric, Task -from clinicadl.validator.validation import ValidationConfig as BaseValidationConfig logger = getLogger("clinicadl.classification_config") diff --git a/clinicadl/trainer/config/reconstruction.py b/clinicadl/trainer/config/reconstruction.py index 7e9b151db..d4b90ee2d 100644 --- a/clinicadl/trainer/config/reconstruction.py +++ b/clinicadl/trainer/config/reconstruction.py @@ -4,6 +4,7 @@ from pydantic import PositiveFloat, PositiveInt, computed_field, field_validator from clinicadl.network.config import NetworkConfig as BaseNetworkConfig +from clinicadl.predictor.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import ( Normalization, @@ -11,7 +12,6 @@ ReconstructionMetric, Task, ) -from clinicadl.validator.validation import ValidationConfig as BaseValidationConfig logger = getLogger("clinicadl.reconstruction_config") diff --git a/clinicadl/trainer/config/regression.py b/clinicadl/trainer/config/regression.py index ee1ab4d30..f094d5552 100644 --- a/clinicadl/trainer/config/regression.py +++ b/clinicadl/trainer/config/regression.py @@ -5,9 +5,9 @@ from clinicadl.caps_dataset.data_config import DataConfig as BaseDataConfig from clinicadl.network.config import NetworkConfig as BaseNetworkConfig +from clinicadl.predictor.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import RegressionLoss, RegressionMetric, Task -from clinicadl.validator.validation import ValidationConfig as BaseValidationConfig logger = getLogger("clinicadl.reconstruction_config") logger = getLogger("clinicadl.regression_config") diff --git a/clinicadl/trainer/config/train.py b/clinicadl/trainer/config/train.py index 193cfa613..a1e949997 100644 --- a/clinicadl/trainer/config/train.py +++ b/clinicadl/trainer/config/train.py @@ -19,13 +19,13 @@ from clinicadl.network.config import NetworkConfig from clinicadl.optimizer.optimization import OptimizationConfig from clinicadl.optimizer.optimizer import OptimizerConfig +from clinicadl.predictor.validation import ValidationConfig from clinicadl.splitter.config import SplitConfig from clinicadl.trainer.transfer_learning import TransferLearningConfig from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.early_stopping.config import EarlyStoppingConfig from clinicadl.utils.enum import Task -from clinicadl.validator.validation import ValidationConfig logger = getLogger("clinicadl.training_config") diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index e8970b0b3..c1b9ed396 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -33,8 +33,8 @@ patch_to_read_json, ) from clinicadl.trainer.tasks_utils import create_training_config -from clinicadl.validator.validator import Validator -from clinicadl.validator.config import PredictConfig +from clinicadl.predictor.predictor import Predictor +from clinicadl.predictor.config import PredictConfig from clinicadl.splitter.splitter import Splitter from clinicadl.splitter.config import SplitterConfig from clinicadl.transforms.config import TransformsConfig @@ -69,7 +69,7 @@ def __init__( self.maps_manager = self._init_maps_manager(config) predict_config = PredictConfig(**config.get_dict()) - self.validator = Validator(predict_config) + self.validator = Predictor(predict_config) self._check_args() def _init_maps_manager(self, config) -> MapsManager: diff --git a/clinicadl/validator/config.py b/clinicadl/validator/config.py deleted file mode 100644 index 510af3fcf..000000000 --- a/clinicadl/validator/config.py +++ /dev/null @@ -1,40 +0,0 @@ -from logging import getLogger - -from clinicadl.caps_dataset.data_config import DataConfig as DataBaseConfig -from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig -from clinicadl.maps_manager.config import ( - MapsManagerConfig as MapsManagerBaseConfig, -) -from clinicadl.splitter.config import SplitConfig -from clinicadl.utils.computational.computational import ComputationalConfig -from clinicadl.utils.exceptions import ClinicaDLArgumentError # type: ignore -from clinicadl.validator.validation import ValidationConfig - -logger = getLogger("clinicadl.predict_config") - - -class MapsManagerConfig(MapsManagerBaseConfig): - save_tensor: bool = False - save_latent_tensor: bool = False - - def check_output_saving_tensor(self, network_task: str) -> None: - # Check if task is reconstruction for "save_tensor" and "save_nifti" - if self.save_tensor and network_task != "reconstruction": - raise ClinicaDLArgumentError( - "Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option." - ) - - -class DataConfig(DataBaseConfig): - use_labels: bool = True - - -class PredictConfig( - MapsManagerConfig, - DataConfig, - ValidationConfig, - ComputationalConfig, - DataLoaderConfig, - SplitConfig, -): - """Config class to perform Transfer Learning.""" diff --git a/tests/test_interpret.py b/tests/test_interpret.py index 64e106db2..eea9e29c1 100644 --- a/tests/test_interpret.py +++ b/tests/test_interpret.py @@ -7,7 +7,7 @@ import pytest from clinicadl.interpret.config import InterpretConfig -from clinicadl.validator.validator import Validator +from clinicadl.predictor.predictor import Predictor @pytest.fixture(params=["classification", "regression"]) @@ -83,7 +83,7 @@ def run_interpret(cnn_input, tmp_out_dir, ref_dir): name=f"test-{method}", method_cls=method, ) - interpret_manager = Validator(interpret_config) + interpret_manager = Predictor(interpret_config) interpret_manager.interpret() interpret_map = interpret_manager.get_interpretation( "train", f"test-{interpret_config.method}" diff --git a/tests/test_predict.py b/tests/test_predict.py index 3ca6ea15e..e515ef41c 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -7,8 +7,8 @@ import pytest from clinicadl.metrics.utils import get_metrics -from clinicadl.validator.utils import get_prediction -from clinicadl.validator.validator import Validator +from clinicadl.predictor.predictor import Predictor +from clinicadl.predictor.utils import get_prediction from .testing_tools import compare_folders, modify_maps @@ -101,7 +101,7 @@ def test_predict(cmdopt, tmp_path, test_name): # with open(json_path, "w") as f: # f.write(json_data) - from clinicadl.validator.config import PredictConfig + from clinicadl.predictor.config import PredictConfig predict_config = PredictConfig( maps_dir=model_folder, @@ -113,7 +113,7 @@ def test_predict(cmdopt, tmp_path, test_name): overwrite=True, diagnoses=["CN"], ) - predict_manager = Validator(predict_config) + predict_manager = Predictor(predict_config) predict_manager.predict() for mode in modes: diff --git a/tests/unittests/train/trainer/test_training_config.py b/tests/unittests/train/trainer/test_training_config.py index 7280af7ad..b25dc20bb 100644 --- a/tests/unittests/train/trainer/test_training_config.py +++ b/tests/unittests/train/trainer/test_training_config.py @@ -7,9 +7,9 @@ from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig from clinicadl.config.config.ssda import SSDAConfig from clinicadl.network.config import NetworkConfig +from clinicadl.predictor.validation import ValidationConfig from clinicadl.trainer.transfer_learning import TransferLearningConfig from clinicadl.transforms.config import TransformsConfig -from clinicadl.validator.validation import ValidationConfig # Tests for customed validators #