From 574f7a0cc380ba9fd1b54f329c4dfbf13a416026 Mon Sep 17 00:00:00 2001 From: camillebrianceau <57992134+camillebrianceau@users.noreply.github.com> Date: Wed, 11 Sep 2024 17:23:34 +0200 Subject: [PATCH] Remove TaskManager (#648) --- clinicadl/caps_dataset/data.py | 20 +- .../commandline/pipelines/train/resume/cli.py | 2 +- clinicadl/interpret/gradients.py | 6 +- clinicadl/maps_manager/maps_manager.py | 163 +++- clinicadl/metrics/metric_module.py | 1 + clinicadl/predict/predict_manager.py | 21 +- .../quality_check/t1_linear/quality_check.py | 6 +- clinicadl/quality_check/t1_linear/utils.py | 2 +- clinicadl/trainer/__init__.py | 1 - clinicadl/trainer/config/classification.py | 2 +- clinicadl/trainer/config/reconstruction.py | 2 +- clinicadl/trainer/config/regression.py | 2 +- clinicadl/trainer/tasks_utils.py | 869 +++++++++++++++++- clinicadl/trainer/trainer.py | 219 +++-- clinicadl/utils/task_manager/__init__.py | 3 - .../utils/task_manager/classification.py | 264 ------ .../utils/task_manager/reconstruction.py | 189 ---- clinicadl/utils/task_manager/regression.py | 190 ---- clinicadl/utils/task_manager/task_manager.py | 321 ------- clinicadl/validation/cross_validation.py | 6 +- 20 files changed, 1175 insertions(+), 1114 deletions(-) delete mode 100644 clinicadl/utils/task_manager/__init__.py delete mode 100644 clinicadl/utils/task_manager/classification.py delete mode 100644 clinicadl/utils/task_manager/reconstruction.py delete mode 100644 clinicadl/utils/task_manager/regression.py delete mode 100644 clinicadl/utils/task_manager/task_manager.py diff --git a/clinicadl/caps_dataset/data.py b/clinicadl/caps_dataset/data.py index 48d1a5480..638f49e9d 100644 --- a/clinicadl/caps_dataset/data.py +++ b/clinicadl/caps_dataset/data.py @@ -226,7 +226,7 @@ def _get_full_image(self) -> torch.Tensor: try: image_path = self._get_image_path(participant_id, session_id, cohort) - image = torch.load(image_path) + image = torch.load(image_path, weights_only=True) except IndexError: file_type = self.config.extraction.file_type results = clinicadl_file_reader( @@ -316,7 +316,7 @@ def __getitem__(self, idx): participant, session, cohort, _, label, domain = self._get_meta_data(idx) image_path = self._get_image_path(participant, session, cohort) - image = torch.load(image_path) + image = torch.load(image_path, weights_only=True) train_trf, trf = self.config.transforms.get_transforms() @@ -385,10 +385,12 @@ def __getitem__(self, idx): self.config.extraction.stride_size, patch_idx, ) - patch_tensor = torch.load(Path(patch_dir).resolve() / patch_filename) + patch_tensor = torch.load( + Path(patch_dir).resolve() / patch_filename, weights_only=True + ) else: - image = torch.load(image_path) + image = torch.load(image_path, weights_only=True) patch_tensor = extract_patch_tensor( image, self.config.extraction.patch_size, @@ -504,10 +506,10 @@ def __getitem__(self, idx): roi_filename = extract_roi_path( image_path, mask_path, self.config.extraction.roi_uncrop_output ) - roi_tensor = torch.load(Path(roi_dir) / roi_filename) + roi_tensor = torch.load(Path(roi_dir) / roi_filename, weights_only=True) else: - image = torch.load(image_path) + image = torch.load(image_path, weights_only=True) mask_array = self.mask_arrays[roi_idx] roi_tensor = extract_roi_tensor( image, mask_array, self.config.extraction.uncropped_roi @@ -653,11 +655,13 @@ def __getitem__(self, idx): self.config.extraction.slice_mode, slice_idx, ) - slice_tensor = torch.load(Path(slice_dir) / slice_filename) + slice_tensor = torch.load( + Path(slice_dir) / 
slice_filename, weights_only=True + ) else: image_path = self._get_image_path(participant, session, cohort) - image = torch.load(image_path) + image = torch.load(image_path, weights_only=True) slice_tensor = extract_slice_tensor( image, self.config.extraction.slice_direction, diff --git a/clinicadl/commandline/pipelines/train/resume/cli.py b/clinicadl/commandline/pipelines/train/resume/cli.py index 88c4f6bc0..8734bf95d 100644 --- a/clinicadl/commandline/pipelines/train/resume/cli.py +++ b/clinicadl/commandline/pipelines/train/resume/cli.py @@ -4,7 +4,7 @@ from clinicadl.commandline.modules_options import ( cross_validation, ) -from clinicadl.trainer import Trainer +from clinicadl.trainer.trainer import Trainer @click.command(name="resume", no_args_is_help=True) diff --git a/clinicadl/interpret/gradients.py b/clinicadl/interpret/gradients.py index 393e64488..b62308f38 100644 --- a/clinicadl/interpret/gradients.py +++ b/clinicadl/interpret/gradients.py @@ -1,7 +1,7 @@ import abc import torch -from torch.cuda.amp import autocast +from torch.amp import autocast from clinicadl.utils.exceptions import ClinicaDLArgumentError @@ -28,7 +28,7 @@ def generate_gradients( # Forward input_batch = input_batch.to(self.device) input_batch.requires_grad = True - with autocast(enabled=amp): + with autocast("cuda", enabled=amp): if hasattr(self.model, "variational") and self.model.variational: _, _, _, model_output = self.model(input_batch) else: @@ -94,7 +94,7 @@ def generate_gradients( # Get last conv feature map feature_maps = conv_part(input_batch).detach() feature_maps.requires_grad = True - with autocast(enabled=amp): + with autocast("cuda", enabled=amp): model_output = fc_part(pre_fc_part(feature_maps)) # Target for backprop one_hot_output = torch.zeros_like(model_output) diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py index a2c69a035..73b6430eb 100644 --- a/clinicadl/maps_manager/maps_manager.py +++ b/clinicadl/maps_manager/maps_manager.py @@ -8,17 +8,26 @@ import pandas as pd import torch import torch.distributed as dist -from torch.cuda.amp import autocast +from torch.amp import autocast from clinicadl.caps_dataset.caps_dataset_utils import read_json from clinicadl.caps_dataset.data import ( return_dataset, ) +from clinicadl.metrics.metric_module import MetricModule from clinicadl.metrics.utils import ( check_selection_metric, find_selection_metrics, ) from clinicadl.predict.utils import get_prediction +from clinicadl.trainer.tasks_utils import ( + ensemble_prediction, + evaluation_metrics, + generate_label_code, + output_size, + test, + test_da, +) from clinicadl.transforms.config import TransformsConfig from clinicadl.utils import cluster from clinicadl.utils.computational.ddp import DDP, init_ddp @@ -43,7 +52,7 @@ class MapsManager: def __init__( self, maps_path: Path, - parameters: Dict[str, Any] = None, + parameters: Optional[Dict[str, Any]] = None, verbose: str = "info", ): """ @@ -68,8 +77,38 @@ def __init__( ) test_parameters = self.get_parameters() # test_parameters = path_decoder(test_parameters) + # from clinicadl.trainer.task_manager import TaskConfig + self.parameters = add_default_values(test_parameters) - self.task_manager = self._init_task_manager(n_classes=self.output_size) + + ## to initialize the task parameters + + # self.task_manager = self._init_task_manager() + + self.n_classes = self.output_size + if self.network_task == "classification": + if self.n_classes is None: + self.n_classes = output_size( + self.network_task, None, 
self.df, self.label + ) + self.metrics_module = MetricModule( + evaluation_metrics(self.network_task), n_classes=self.n_classes + ) + + elif ( + self.network_task == "regression" + or self.network_task == "reconstruction" + ): + self.metrics_module = MetricModule( + evaluation_metrics(self.network_task), n_classes=None + ) + + else: + raise NotImplementedError( + f"Task {self.network_task} is not implemented in ClinicaDL. " + f"Please choose between classification, regression and reconstruction." + ) + self.split_name = ( self._check_split_wording() ) # Used only for retro-compatibility @@ -162,10 +201,14 @@ def _test_loader( ) model = DDP(model, fsdp=self.fully_sharded_data_parallel, amp=self.amp) - prediction_df, metrics = self.task_manager.test( - model, - dataloader, - criterion, + prediction_df, metrics = test( + mode=self.mode, + metrics_module=self.metrics_module, + n_classes=self.n_classes, + network_task=self.network_task, + model=model, + dataloader=dataloader, + criterion=criterion, use_labels=use_labels, amp=amp, report_ci=report_ci, @@ -241,8 +284,13 @@ def _test_loader_ssda( gpu=gpu, network=network, ) - prediction_df, metrics = self.task_manager.test_da( - model, dataloader, criterion, target=target, report_ci=report_ci + prediction_df, metrics = test_da( + self.network_task, + model, + dataloader, + criterion, + target=target, + report_ci=report_ci, ) if use_labels: if network is not None: @@ -321,7 +369,7 @@ def _compute_output_tensors( data = dataset[i] image = data["image"] x = image.unsqueeze(0).to(model.device) - with autocast(enabled=self.std_amp): + with autocast("cuda", enabled=self.std_amp): output = model(x) output = output.squeeze(0).cpu().float() participant_id = data["participant_id"] @@ -404,10 +452,33 @@ def _check_args(self, parameters): if "label" not in self.parameters: self.parameters["label"] = None - self.task_manager = self._init_task_manager(df=train_df) + from clinicadl.trainer.tasks_utils import ( + get_default_network, + ) + from clinicadl.utils.enum import Task + self.network_task = Task(self.parameters["network_task"]) + # self.task_config = TaskConfig(self.network_task, self.mode, df=train_df) + # self.task_manager = self._init_task_manager(df=train_df) + if self.network_task == "classification": + self.n_classes = output_size(self.network_task, None, train_df, self.label) + self.metrics_module = MetricModule( + evaluation_metrics(self.network_task), n_classes=self.n_classes + ) + + elif self.network_task == "regression" or self.network_task == "reconstruction": + self.n_classes = None + self.metrics_module = MetricModule( + evaluation_metrics(self.network_task), n_classes=None + ) + + else: + raise NotImplementedError( + f"Task {self.network_task} is not implemented in ClinicaDL. " + f"Please choose between classification, regression and reconstruction." 
+ ) if self.parameters["architecture"] == "default": - self.parameters["architecture"] = self.task_manager.get_default_network() + self.parameters["architecture"] = get_default_network(self.network_task) if "selection_threshold" not in self.parameters: self.parameters["selection_threshold"] = None if ( @@ -415,8 +486,8 @@ def _check_args(self, parameters): or len(self.parameters["label_code"]) == 0 or self.parameters["label_code"] is None ): # Allows to set custom label code in TOML - self.parameters["label_code"] = self.task_manager.generate_label_code( - train_df, self.label + self.parameters["label_code"] = generate_label_code( + self.network_task, train_df, self.label ) full_dataset = return_dataset( @@ -431,8 +502,8 @@ def _check_args(self, parameters): self.parameters.update( { "num_networks": full_dataset.elem_per_image, - "output_size": self.task_manager.output_size( - full_dataset.size, full_dataset.df, self.label + "output_size": output_size( + self.network_task, full_dataset.size, full_dataset.df, self.label ), "input_size": full_dataset.size, } @@ -444,7 +515,7 @@ def _check_args(self, parameters): f"framework with only {self.parameters['num_networks']} element " f"per image." ) - possible_selection_metrics_set = set(self.task_manager.evaluation_metrics) | { + possible_selection_metrics_set = set(evaluation_metrics(self.network_task)) | { "loss" } if not set(self.parameters["selection_metrics"]).issubset( @@ -708,7 +779,11 @@ def _ensemble_to_tsv( performance_dir.mkdir(parents=True, exist_ok=True) - df_final, metrics = self.task_manager.ensemble_prediction( + df_final, metrics = ensemble_prediction( + self.mode, + self.metrics_module, + self.n_classes, + self.network_task, test_df, validation_df, selection_threshold=self.selection_threshold, @@ -839,7 +914,9 @@ def _init_model( / "tmp" / "checkpoint.pth.tar" ) - checkpoint_state = torch.load(checkpoint_path, map_location=device) + checkpoint_state = torch.load( + checkpoint_path, map_location=device, weights_only=True + ) model.load_state_dict(checkpoint_state["model"]) current_epoch = checkpoint_state["epoch"] elif transfer_path: @@ -912,29 +989,29 @@ def _init_split_manager_ssda(self, caps_dir, tsv_dir, split_list=None): return split_class(**kwargs) - def _init_task_manager( - self, df: Optional[pd.DataFrame] = None, n_classes: Optional[int] = None - ): - from clinicadl.utils.task_manager import ( - ClassificationManager, - ReconstructionManager, - RegressionManager, - ) - - if self.network_task == "classification": - if n_classes is not None: - return ClassificationManager(self.mode, n_classes=n_classes) - else: - return ClassificationManager(self.mode, df=df, label=self.label) - elif self.network_task == "regression": - return RegressionManager(self.mode) - elif self.network_task == "reconstruction": - return ReconstructionManager(self.mode) - else: - raise NotImplementedError( - f"Task {self.network_task} is not implemented in ClinicaDL. " - f"Please choose between classification, regression and reconstruction." 
- ) + # def _init_task_manager( + # self, df: Optional[pd.DataFrame] = None, n_classes: Optional[int] = None + # ): + # from clinicadl.utils.task_manager import ( + # ClassificationManager, + # ReconstructionManager, + # RegressionManager, + # ) + + # if self.network_task == "classification": + # if n_classes is not None: + # return ClassificationManager(self.mode, n_classes=n_classes) + # else: + # return ClassificationManager(self.mode, df=df, label=self.label) + # elif self.network_task == "regression": + # return RegressionManager(self.mode) + # elif self.network_task == "reconstruction": + # return ReconstructionManager(self.mode) + # else: + # raise NotImplementedError( + # f"Task {self.network_task} is not implemented in ClinicaDL. " + # f"Please choose between classification, regression and reconstruction." + # ) ############################### # Getters # @@ -1054,7 +1131,7 @@ def get_state_dict( f"selected according to best validation {selection_metric} " f"at path {model_path}." ) - return torch.load(model_path, map_location=map_location) + return torch.load(model_path, map_location=map_location, weights_only=True) @property def std_amp(self) -> bool: diff --git a/clinicadl/metrics/metric_module.py b/clinicadl/metrics/metric_module.py index 171a06792..319b4a639 100644 --- a/clinicadl/metrics/metric_module.py +++ b/clinicadl/metrics/metric_module.py @@ -2,6 +2,7 @@ from typing import Dict, List import numpy as np +from sklearn.utils import resample metric_optimum = { "MAE": "min", diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predict/predict_manager.py index aabd83202..879ef0e54 100644 --- a/clinicadl/predict/predict_manager.py +++ b/clinicadl/predict/predict_manager.py @@ -7,7 +7,7 @@ import pandas as pd import torch import torch.distributed as dist -from torch.cuda.amp import autocast +from torch.amp import autocast from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -21,6 +21,7 @@ find_selection_metrics, ) from clinicadl.predict.config import PredictConfig +from clinicadl.trainer.tasks_utils import generate_label_code, get_criterion from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.computational.ddp import DDP, cluster from clinicadl.utils.exceptions import ( @@ -123,7 +124,9 @@ def predict( ) group_df = self._config.create_groupe_df() self._check_data_group(group_df) - criterion = self.maps_manager.task_manager.get_criterion(self.maps_manager.loss) + criterion = get_criterion( + self.maps_manager.network_task, self.maps_manager.loss + ) self._check_data_group(df=group_df) assert self._config.split # don't know if needed ? try to raise an exception ? 
@@ -136,8 +139,8 @@ def predict( ) # Find label code if not given if self._config.is_given_label_code(self.maps_manager.label, label_code): - self.maps_manager.task_manager.generate_label_code( - group_df, self._config.label + generate_label_code( + self.maps_manager.network_task, group_df, self._config.label ) # Erase previous TSV files on master process if not self._config.selection_metrics: @@ -506,7 +509,7 @@ def _compute_latent_tensors( data = dataset[i] image = data["image"] logger.debug(f"Image for latent representation {image}") - with autocast(enabled=self.maps_manager.std_amp): + with autocast("cuda", enabled=self.maps_manager.std_amp): _, latent, _ = model.module._forward( image.unsqueeze(0).to(model.device) ) @@ -580,7 +583,7 @@ def _compute_output_nifti( data = dataset[i] image = data["image"] x = image.unsqueeze(0).to(model.device) - with autocast(enabled=self.maps_manager.std_amp): + with autocast("cuda", enabled=self.maps_manager.std_amp): output = model(x) output = output.squeeze(0).detach().cpu().float() # Convert tensor to nifti image with appropriate affine @@ -1066,7 +1069,8 @@ def get_interpretation( ) if participant_id is None and session_id is None: map_pt = torch.load( - map_dir / f"mean_{self.maps_manager.mode}-{mode_id}_map.pt" + map_dir / f"mean_{self.maps_manager.mode}-{mode_id}_map.pt", + weights_only=True, ) elif participant_id is None or session_id is None: raise ValueError( @@ -1077,6 +1081,7 @@ def get_interpretation( else: map_pt = torch.load( map_dir - / f"{participant_id}_{session_id}_{self.maps_manager.mode}-{mode_id}_map.pt" + / f"{participant_id}_{session_id}_{self.maps_manager.mode}-{mode_id}_map.pt", + weights_only=True, ) return map_pt diff --git a/clinicadl/quality_check/t1_linear/quality_check.py b/clinicadl/quality_check/t1_linear/quality_check.py index f840a4583..7063c0c68 100755 --- a/clinicadl/quality_check/t1_linear/quality_check.py +++ b/clinicadl/quality_check/t1_linear/quality_check.py @@ -8,7 +8,7 @@ import pandas as pd import torch -from torch.cuda.amp import autocast +from torch.amp import autocast from torch.utils.data import DataLoader from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig @@ -111,7 +111,7 @@ def quality_check( # Load QC model logger.debug("Loading quality check model.") - model.load_state_dict(torch.load(model_file)) + model.load_state_dict(torch.load(model_file, weights_only=True)) model.eval() if computational_config.gpu: logger.debug("Working on GPU.") @@ -153,7 +153,7 @@ def quality_check( inputs = data["image"] if computational_config.gpu: inputs = inputs.cuda() - with autocast(enabled=computational_config.amp): + with autocast("cuda", enabled=computational_config.amp): outputs = softmax(model(inputs)) # We cast back to 32bits. It should be a no-op as softmax is not eligible # to fp16 and autocast is forbidden on CPU (output would be bf16 otherwise). 
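The two mechanical changes applied throughout this patch are the switch to `torch.load(..., weights_only=True)` and the move from `torch.cuda.amp` to the device-aware `torch.amp` API, where the device type becomes the first positional argument of `autocast` and `GradScaler`. A minimal, self-contained sketch of the new pattern (the file name and toy model below are illustrative, not part of ClinicaDL):

```python
import torch
from torch.amp import GradScaler, autocast

# weights_only=True restricts torch.load to tensors and plain containers,
# avoiding arbitrary pickle execution when reading extracted .pt files.
torch.save(torch.rand(1, 8), "example_image.pt")
image = torch.load("example_image.pt", weights_only=True)

# torch.cuda.amp.autocast(enabled=...) becomes torch.amp.autocast("cuda", enabled=...),
# and GradScaler likewise takes the device type as its first argument.
use_amp = torch.cuda.is_available()
model = torch.nn.Linear(8, 2)
scaler = GradScaler("cuda", enabled=use_amp)
with autocast("cuda", enabled=use_amp):
    output = model(image)
print(output.shape)
```
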
diff --git a/clinicadl/quality_check/t1_linear/utils.py b/clinicadl/quality_check/t1_linear/utils.py index b9f03ba3e..20d4d5462 100755 --- a/clinicadl/quality_check/t1_linear/utils.py +++ b/clinicadl/quality_check/t1_linear/utils.py @@ -83,7 +83,7 @@ def __getitem__(self, idx): ) image_path = image_dir / image_filename - image = torch.load(image_path) + image = torch.load(image_path, weights_only=True) image = self.pt_transform(image) else: image_path = clinicadl_file_reader( diff --git a/clinicadl/trainer/__init__.py b/clinicadl/trainer/__init__.py index 260e4c8d6..e69de29bb 100644 --- a/clinicadl/trainer/__init__.py +++ b/clinicadl/trainer/__init__.py @@ -1 +0,0 @@ -from .trainer import Trainer diff --git a/clinicadl/trainer/config/classification.py b/clinicadl/trainer/config/classification.py index 6472316f1..9ac3a4aae 100644 --- a/clinicadl/trainer/config/classification.py +++ b/clinicadl/trainer/config/classification.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Tuple +from typing import Tuple, Union from pydantic import computed_field, field_validator diff --git a/clinicadl/trainer/config/reconstruction.py b/clinicadl/trainer/config/reconstruction.py index 4ad9d5927..08728885b 100644 --- a/clinicadl/trainer/config/reconstruction.py +++ b/clinicadl/trainer/config/reconstruction.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Tuple +from typing import Tuple, Union from pydantic import PositiveFloat, PositiveInt, computed_field, field_validator diff --git a/clinicadl/trainer/config/regression.py b/clinicadl/trainer/config/regression.py index d68a873f8..b19a3ba5c 100644 --- a/clinicadl/trainer/config/regression.py +++ b/clinicadl/trainer/config/regression.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Tuple +from typing import Tuple, Union from pydantic import computed_field, field_validator diff --git a/clinicadl/trainer/tasks_utils.py b/clinicadl/trainer/tasks_utils.py index e6202f81f..93a652aa8 100644 --- a/clinicadl/trainer/tasks_utils.py +++ b/clinicadl/trainer/tasks_utils.py @@ -1,10 +1,45 @@ -# TODO: to put in trainer ? trainer_utils.py ? -from typing import Type, Union +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union + +import numpy as np +import pandas as pd +import torch +import torch.distributed as dist +from pydantic import ( + BaseModel, + ConfigDict, + computed_field, + model_validator, +) +from torch import Tensor, nn +from torch.amp import autocast +from torch.nn.functional import softmax +from torch.nn.modules.loss import _Loss +from torch.utils.data import DataLoader, Sampler, sampler +from torch.utils.data.distributed import DistributedSampler +from clinicadl.caps_dataset.data import CapsDataset +from clinicadl.metrics.metric_module import MetricModule +from clinicadl.network.network import Network from clinicadl.trainer.config.train import TrainConfig -from clinicadl.utils.enum import Task +from clinicadl.utils import cluster +from clinicadl.utils.enum import ( + ClassificationLoss, + ClassificationMetric, + ReconstructionLoss, + ReconstructionMetric, + RegressionLoss, + RegressionMetric, + Task, +) +from clinicadl.utils.exceptions import ClinicaDLArgumentError + +# if network_task == Task.CLASSIFICATION: +# elif network_task == Task.REGRESSION: +# elif network_task == Task.RECONSTRUCTION: +# TODO: to put in trainer ? trainer_utils.py ? 
def create_training_config(task: Union[str, Task]) -> Type[TrainConfig]: """ A factory function to create a Training Config class suited for the task. @@ -28,3 +63,831 @@ def create_training_config(task: Union[str, Task]) -> Type[TrainConfig]: ReconstructionConfig as Config, ) return Config + + +# This function is not useful anymore since we introduced config class +# default network will automatically be initialized when running the task +def get_default_network(network_task: Task) -> str: + """Returns the default network to use when no architecture is specified.""" + task_network_map = { + Task.CLASSIFICATION: "Conv5_FC3", + Task.REGRESSION: "Conv5_FC3", + Task.RECONSTRUCTION: "AE_Conv5_FC3", + } + return task_network_map.get(network_task, "Unknown Task") + + +def get_criterion( + network_task: Union[str, Task], criterion: Optional[str] = None +) -> nn.Module: + """ + Gives the optimization criterion. + Must check that it is compatible with the task. + + Args: + network_task: Task type as a string or Task enum + criterion: name of the loss as written in PyTorch. + Raises: + ClinicaDLArgumentError: if the criterion is not compatible with the task. + """ + + network_task = Task(network_task) + + def validate_criterion(criterion_name: str, compatible_losses: List[str]): + if criterion_name not in compatible_losses: + raise ClinicaDLArgumentError( + f"Loss must be chosen from {compatible_losses}." + ) + return getattr(nn, criterion_name)() + + if network_task == Task.CLASSIFICATION: + compatible_losses = [e.value for e in ClassificationLoss] + return ( + nn.CrossEntropyLoss() + if criterion is None + else validate_criterion(criterion, compatible_losses) + ) + + if network_task == Task.REGRESSION: + compatible_losses = [e.value for e in RegressionLoss] + return ( + nn.MSELoss() + if criterion is None + else validate_criterion(criterion, compatible_losses) + ) + + if network_task == Task.RECONSTRUCTION: + compatible_losses = [e.value for e in ReconstructionLoss] + reconstruction_losses = { + "VAEGaussianLoss": "VAEGaussianLoss", + "VAEBernoulliLoss": "VAEBernoulliLoss", + "VAEContinuousBernoulliLoss": "VAEContinuousBernoulliLoss", + } + + if criterion in reconstruction_losses: + from clinicadl.network.vae.vae_utils import ( + VAEBernoulliLoss, + VAEContinuousBernoulliLoss, + VAEGaussianLoss, + ) + + return eval(reconstruction_losses[criterion]) + + return ( + nn.MSELoss() + if criterion is None + else validate_criterion(criterion, compatible_losses) + ) + + +def output_size( + network_task: Union[str, Task], + input_size: Optional[Sequence[int]] = None, + df: Optional[pd.DataFrame] = None, + label: Optional[str] = None, +) -> Union[int, Sequence[int]]: + """ + Computes the output_size needed to perform the task. + + Args: + input_size: size of the input. + df: meta-data of the training set. + label: name of the column containing the labels. + Returns: + output_size + """ + network_task = Task(network_task) + if network_task == Task.CLASSIFICATION: + return len(generate_label_code(network_task, df, label)) + elif network_task == Task.REGRESSION: + return 1 + elif network_task == Task.RECONSTRUCTION: + return input_size + + +def generate_label_code( + network_task: Union[str, Task], df: pd.DataFrame, label: str +) -> Optional[Dict[str, int]]: + """ + Generates a label code that links the output node number to label value. + + Args: + df: meta-data of the training set. + label: name of the column containing the labels. 
+ Returns: + label_code + """ + + network_task = Task(network_task) + if network_task == Task.CLASSIFICATION: + unique_labels = sorted(set(df[label])) + return {str(key): value for value, key in enumerate(unique_labels)} + + return None + + +def evaluation_metrics(network_task: Union[str, Task]): + """ + Evaluation metrics which can be used to evaluate the task. + """ + network_task = Task(network_task) + if network_task == Task.CLASSIFICATION: + x = [e.value for e in ClassificationMetric] + x.remove("loss") + return x + elif network_task == Task.REGRESSION: + x = [e.value for e in RegressionMetric] + x.remove("loss") + return x + elif network_task == Task.RECONSTRUCTION: + x = [e.value for e in ReconstructionMetric] + x.remove("loss") + return x + else: + raise ValueError("Unknown network task") + + +def test( + mode: str, + metrics_module: MetricModule, + n_classes: int, + network_task, + model: Network, + dataloader: DataLoader, + criterion: _Loss, + use_labels: bool = True, + amp: bool = False, + report_ci=False, +) -> Tuple[pd.DataFrame, Dict[str, float]]: + """ + Computes the predictions and evaluation metrics. + + Parameters + ---------- + model: Network + The model trained. + dataloader: DataLoader + Wrapper of a CapsDataset. + criterion: _Loss + Function to calculate the loss. + use_labels: bool + If True the true_label will be written in output DataFrame + and metrics dict will be created. + amp: bool + If True, enables Pytorch's automatic mixed precision. + + Returns + ------- + the results and metrics on the image level. + """ + model.eval() + dataloader.dataset.eval() + + results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) + total_loss = {} + with torch.no_grad(): + for i, data in enumerate(dataloader): + # initialize the loss list to save the loss components + with autocast("cuda", enabled=amp): + outputs, loss_dict = model(data, criterion, use_labels=use_labels) + + if i == 0: + for loss_component in loss_dict.keys(): + total_loss[loss_component] = 0 + for loss_component in total_loss.keys(): + total_loss[loss_component] += loss_dict[loss_component].float() + + # Generate detailed DataFrame + for idx in range(len(data["participant_id"])): + row = generate_test_row( + network_task, + mode, + metrics_module, + n_classes, + idx, + data, + outputs.float(), + ) + row_df = pd.DataFrame( + row, columns=columns(network_task, mode, n_classes) + ) + results_df = pd.concat([results_df, row_df]) + + del outputs, loss_dict + dataframes = [None] * dist.get_world_size() + dist.gather_object(results_df, dataframes if dist.get_rank() == 0 else None, dst=0) + if dist.get_rank() == 0: + results_df = pd.concat(dataframes) + del dataframes + results_df.reset_index(inplace=True, drop=True) + + if not use_labels: + metrics_dict = None + else: + metrics_dict = compute_metrics( + network_task, results_df, metrics_module, report_ci=report_ci + ) + for loss_component in total_loss.keys(): + dist.reduce(total_loss[loss_component], dst=0) + loss_value = total_loss[loss_component].item() / cluster.world_size + + if report_ci: + metrics_dict["Metric_names"].append(loss_component) + metrics_dict["Metric_values"].append(loss_value) + metrics_dict["Lower_CI"].append("N/A") + metrics_dict["Upper_CI"].append("N/A") + metrics_dict["SE"].append("N/A") + + else: + metrics_dict[loss_component] = loss_value + + torch.cuda.empty_cache() + + return results_df, metrics_dict + + +def test_da( + mode: str, + metrics_module: MetricModule, + n_classes: int, + network_task: Union[str, Task], + model: 
Network, + dataloader: DataLoader, + criterion: _Loss, + alpha: float = 0, + use_labels: bool = True, + target: bool = True, + report_ci=False, +) -> Tuple[pd.DataFrame, Dict[str, float]]: + """ + Computes the predictions and evaluation metrics. + + Args: + model: the model trained. + dataloader: wrapper of a CapsDataset. + criterion: function to calculate the loss. + use_labels: If True the true_label will be written in output DataFrame + and metrics dict will be created. + Returns: + the results and metrics on the image level. + """ + model.eval() + dataloader.dataset.eval() + results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) + total_loss = 0 + with torch.no_grad(): + for i, data in enumerate(dataloader): + outputs, loss_dict = model.compute_outputs_and_loss_test( + data, criterion, alpha, target + ) + total_loss += loss_dict["loss"].item() + + # Generate detailed DataFrame + for idx in range(len(data["participant_id"])): + row = generate_test_row( + network_task, mode, metrics_module, n_classes, idx, data, outputs + ) + row_df = pd.DataFrame( + row, columns=columns(network_task, mode, n_classes) + ) + results_df = pd.concat([results_df, row_df]) + + del outputs, loss_dict + results_df.reset_index(inplace=True, drop=True) + + if not use_labels: + metrics_dict = None + else: + metrics_dict = compute_metrics( + network_task, results_df, metrics_module, report_ci=report_ci + ) + if report_ci: + metrics_dict["Metric_names"].append("loss") + metrics_dict["Metric_values"].append(total_loss) + metrics_dict["Lower_CI"].append("N/A") + metrics_dict["Upper_CI"].append("N/A") + metrics_dict["SE"].append("N/A") + + else: + metrics_dict["loss"] = total_loss + + torch.cuda.empty_cache() + + return results_df, metrics_dict + + +def columns(network_task: Union[str, Task], mode: str, n_classes: Optional[int] = None): + """ + List of the columns' names in the TSV file containing the predictions. + """ + network_task = Task(network_task) + if network_task == Task.CLASSIFICATION: + return [ + "participant_id", + "session_id", + f"{mode}_id", + "true_label", + "predicted_label", + ] + [f"proba{i}" for i in range(n_classes)] + elif network_task == Task.REGRESSION: + return [ + "participant_id", + "session_id", + f"{mode}_id", + "true_label", + "predicted_label", + ] + elif network_task == Task.RECONSTRUCTION: + columns = ["participant_id", "session_id", f"{mode}_id"] + for metric in evaluation_metrics(network_task): + columns.append(metric) + return columns + + +def save_outputs(network_task: Union[str, Task]): + """ + Boolean value indicating if the output values should be saved as tensor for this task. + """ + + network_task = Task(network_task) + if network_task == Task.CLASSIFICATION or network_task == Task.REGRESSION: + return False + elif network_task == Task.RECONSTRUCTION: + return True + + +def generate_test_row( + network_task: Union[str, Task], + mode: str, + metrics_module, + n_classes: int, + idx: int, + data: Dict[str, Any], + outputs: Tensor, +) -> List[List[Any]]: + """ + Computes an individual row of the prediction TSV file. + + Args: + idx: index of the individual input and output in the batch. + data: input batch generated by a DataLoader on a CapsDataset. + outputs: output batch generated by a forward pass in the model. + Returns: + list of items to be contained in a row of the prediction TSV file. 
+ """ + network_task = Task(network_task) + if network_task == Task.CLASSIFICATION: + prediction = torch.argmax(outputs[idx].data).item() + normalized_output = softmax(outputs[idx], dim=0) + return [ + [ + data["participant_id"][idx], + data["session_id"][idx], + data[f"{mode}_id"][idx].item(), + data["label"][idx].item(), + prediction, + ] + + [normalized_output[i].item() for i in range(n_classes)] + ] + + elif network_task == Task.REGRESSION: + return [ + [ + data["participant_id"][idx], + data["session_id"][idx], + data[f"{mode}_id"][idx].item(), + data["label"][idx].item(), + outputs[idx].item(), + ] + ] + elif network_task == Task.RECONSTRUCTION: + y = data["image"][idx] + y_pred = outputs[idx].cpu() + metrics = metrics_module.apply(y, y_pred, report_ci=False) + row = [ + data["participant_id"][idx], + data["session_id"][idx], + data[f"{mode}_id"][idx].item(), + ] + + for metric in evaluation_metrics(Task.RECONSTRUCTION): + row.append(metrics[metric]) + return [row] + + +def compute_metrics( + network_task: Union[str, Task], + results_df: pd.DataFrame, + metrics_module: Optional[MetricModule] = None, + report_ci: bool = False, +) -> Dict[str, float]: + """ + Compute the metrics based on the result of generate_test_row + + Args: + results_df: results generated based on _results_test_row + Returns: + dictionary of metrics + """ + + network_task = Task(network_task) + if network_task == Task.CLASSIFICATION or network_task == Task.REGRESSION: + if metrics_module is not None: + return metrics_module.apply( + results_df.true_label.values, + results_df.predicted_label.values, + report_ci=report_ci, + ) + + elif network_task == Task.RECONSTRUCTION: + if not report_ci: + return { + metric: results_df[metric].mean() + for metric in evaluation_metrics(Task.RECONSTRUCTION) + } + + from numpy import mean as np_mean + from scipy.stats import bootstrap + + metrics = dict() + metric_names = ["Metrics"] + metric_values = ["Values"] + lower_ci_values = ["Lower bound CI"] + upper_ci_values = ["Upper bound CI"] + se_values = ["SE"] + + for metric in evaluation_metrics(Task.RECONSTRUCTION): + metric_vals = results_df[metric] + + metric_result = str(metric_vals.mean()) + + metric_vals = (metric_vals,) + # Compute confidence intervals only if there are at least two samples in the data. + if len(results_df) >= 2: + res = bootstrap( + metric_vals, + np_mean, + n_resamples=3000, + confidence_level=0.95, + method="percentile", + ) + lower_ci, upper_ci = res.confidence_interval + standard_error = res.standard_error + else: + lower_ci, upper_ci, standard_error = "N/A" + + metric_names.append(metric) + metric_values.append(metric_result) + lower_ci_values.append(lower_ci) + upper_ci_values.append(upper_ci) + se_values.append(standard_error) + + metrics["Metric_names"] = metric_names + metrics["Metric_values"] = metric_values + metrics["Lower_CI"] = lower_ci_values + metrics["Upper_CI"] = upper_ci_values + metrics["SE"] = se_values + + return metrics + + +# TODO: add function to check that the output size of the network corresponds to what is expected to +# perform the task + + +def ensemble_prediction( + mode: str, + metrics_module: MetricModule, + n_classes: int, + network_task: Union[str, Task], + performance_df: pd.DataFrame, + validation_df: pd.DataFrame, + selection_threshold: Optional[float] = None, + use_labels: bool = True, + method: Optional[str] = None, +) -> Tuple[pd.DataFrame, Dict[str, float]]: + """ + Compute the results at the image-level by assembling the results on parts of the image. 
+ + Args: + performance_df: results that need to be assembled. + validation_df: results on the validation set used to compute the performance + of each separate part of the image. + selection_threshold: with soft-voting method, allows to exclude some parts of the image + if their associated performance is too low. + use_labels: If True, metrics are computed and the label column values must be different + from None. + method: method to assemble the results. Current implementation proposes soft or hard-voting. + + Returns: + the results and metrics on the image level + """ + + if network_task == Task.CLASSIFICATION: + if method is None: + method = "soft" + return ensemble_prediction_classification( + network_task, + mode, + metrics_module, + n_classes, + performance_df, + validation_df, + selection_threshold, + use_labels, + method, + ) + elif network_task == Task.REGRESSION: + if method is None: + method = "hard" + return ensemble_prediction_regression( + network_task, + mode, + metrics_module, + n_classes, + performance_df, + validation_df, + selection_threshold, + use_labels, + method, + ) + elif network_task == Task.RECONSTRUCTION: + return None, None + + +def ensemble_prediction_regression( + network_task, + mode: str, + metrics_module: MetricModule, + n_classes: int, + performance_df, + validation_df, + selection_threshold=None, + use_labels=True, + method="hard", +): + """ + Compute the results at the image-level by assembling the results on parts of the image. + + Args: + performance_df (pd.DataFrame): results that need to be assembled. + validation_df (pd.DataFrame): results on the validation set used to compute the performance + of each separate part of the image. + selection_threshold (float): with soft-voting method, allows to exclude some parts of the image + if their associated performance is too low. + use_labels (bool): If True, metrics are computed and the label column values must be different + from None. + method (str): method to assemble the results. Current implementation proposes only hard-voting. + + Returns: + df_final (pd.DataFrame) the results on the image level + results (Dict[str, float]) the metrics on the image level + """ + + if method != "hard": + raise NotImplementedError( + f"You asked for {method} ensemble method. " + f"The only method implemented for regression is hard-voting." 
+ ) + + n_modes = validation_df[f"{mode}_id"].nunique() + weight_series = np.ones(n_modes) + + # Sort to allow weighted average computation + performance_df.sort_values( + ["participant_id", "session_id", f"{mode}_id"], inplace=True + ) + + # Soft majority vote + df_final = pd.DataFrame(columns=columns(network_task, mode, n_classes)) + for (subject, session), subject_df in performance_df.groupby( + ["participant_id", "session_id"] + ): + label = subject_df["true_label"].unique().item() + prediction = np.average(subject_df["predicted_label"], weights=weight_series) + row = [[subject, session, 0, label, prediction]] + row_df = pd.DataFrame(row, columns=columns(network_task, mode, n_classes)) + df_final = pd.concat([df_final, row_df]) + + if use_labels: + results = compute_metrics( + network_task, df_final, metrics_module, report_ci=False + ) + else: + results = None + + return df_final, results + + +def ensemble_prediction_classification( + network_task, + mode: str, + metrics_module: MetricModule, + n_classes: int, + performance_df, + validation_df, + selection_threshold=None, + use_labels=True, + method="soft", +): + """ + Computes hard or soft voting based on the probabilities in performance_df. Weights are computed based + on the balanced accuracies of validation_df. + + ref: S. Raschka. Python Machine Learning., 2015 + + Args: + performance_df (pd.DataFrame): Results that need to be assembled. + validation_df (pd.DataFrame): Results on the validation set used to compute the performance + of each separate part of the image. + selection_threshold (float): with soft-voting method, allows to exclude some parts of the image + if their associated performance is too low. + use_labels (bool): If True, metrics are computed and the label column values must be different + from None. + method (str): method to assemble the results. Current implementation proposes soft or hard-voting. + + Returns: + df_final (pd.DataFrame) the results on the image level + results (Dict[str, float]) the metrics on the image level + """ + + def check_prediction(row): + if row["true_label"] == row["predicted_label"]: + return 1 + else: + return 0 + + if method == "soft": + # Compute the sub-level accuracies on the validation set: + validation_df["accurate_prediction"] = validation_df.apply( + lambda x: check_prediction(x), axis=1 + ) + sub_level_accuracies = validation_df.groupby(f"{mode}_id")[ + "accurate_prediction" + ].mean() + if selection_threshold is not None: + sub_level_accuracies[sub_level_accuracies < selection_threshold] = 0 + weight_series = sub_level_accuracies / sub_level_accuracies.sum() + elif method == "hard": + n_modes = validation_df[f"{mode}_id"].nunique() + weight_series = pd.DataFrame(np.ones((n_modes, 1))) + else: + raise NotImplementedError( + f"Ensemble method {method} was not implemented. " + f"Please choose in ['hard', 'soft']." 
+ ) + + # Sort to allow weighted average computation + performance_df.sort_values( + ["participant_id", "session_id", f"{mode}_id"], inplace=True + ) + weight_series.sort_index(inplace=True) + + # Soft majority vote + df_final = pd.DataFrame(columns=columns(network_task, mode, n_classes)) + for (subject, session), subject_df in performance_df.groupby( + ["participant_id", "session_id"] + ): + label = subject_df["true_label"].unique().item() + proba_list = [ + np.average(subject_df[f"proba{i}"], weights=weight_series) + for i in range(n_classes) + ] + prediction = proba_list.index(max(proba_list)) + row = [[subject, session, 0, label, prediction] + proba_list] + row_df = pd.DataFrame(row, columns=columns(network_task, mode, n_classes)) + df_final = pd.concat([df_final, row_df]) + + if use_labels: + results = compute_metrics( + network_task, df_final, metrics_module, report_ci=False + ) + else: + results = None + + return df_final, results + + +def generate_sampler( + network_task: Union[str, Task], + dataset: CapsDataset, + sampler_option: str = "random", + n_bins: int = 5, + dp_degree: Optional[int] = None, + rank: Optional[int] = None, +) -> Sampler: + """ + Returns sampler according to the wanted options. + + Args: + dataset: the dataset to sample from. + sampler_option: choice of sampler. + n_bins: number of bins to use for a continuous variable (regression task). + dp_degree: the degree of data parallelism. + rank: process id within the data parallelism communicator. + Returns: + callable given to the training data loader. + """ + + def calculate_weights_classification(df): + labels = df[dataset.config.data.label].unique() + codes = {dataset.config.data.label_code[label] for label in labels} + count = np.zeros(len(codes)) + + for idx in df.index: + label = df.loc[idx, dataset.config.data.label] + key = dataset.label_fn(label) + count[key] += 1 + + weight_per_class = 1 / np.array(count) + weights = [ + weight_per_class[dataset.label_fn(label)] * dataset.elem_per_image + for label in df[dataset.config.data.label].values + ] + return weights + + def calculate_weights_regression(df): + count = np.zeros(n_bins) + values = df[dataset.config.data.label].values.astype(float) + thresholds = np.linspace(min(values), max(values), n_bins, endpoint=False) + + for idx in df.index: + label = df.loc[idx, dataset.config.data.label] + key = max(np.where(label >= thresholds)[0]) + count[key] += 1 + + weight_per_class = 1 / count + weights = [ + weight_per_class[max(np.where(label >= thresholds)[0])] + * dataset.elem_per_image + for label in df[dataset.config.data.label].values + ] + return weights + + def get_sampler(weights): + if sampler_option == "random": + if dp_degree is not None and rank is not None: + return DistributedSampler( + weights, num_replicas=dp_degree, rank=rank, shuffle=True + ) + else: + return sampler.RandomSampler(weights) + elif sampler_option == "weighted": + length = ( + len(weights) // dp_degree + int(rank < len(weights) % dp_degree) + if dp_degree and rank is not None + else len(weights) + ) + return sampler.WeightedRandomSampler(weights, length) + else: + raise NotImplementedError( + f"The option {sampler_option} for sampler is not implemented" + ) + + network_task = Task(network_task) + df = dataset.df + + if network_task == Task.CLASSIFICATION: + weights = calculate_weights_classification(df) + elif network_task == Task.REGRESSION: + weights = calculate_weights_regression(df) + elif network_task == Task.RECONSTRUCTION: + weights = [1] * len(df) * dataset.elem_per_image + 
else: + raise ValueError(f"Unknown network task: {network_task}") + + return get_sampler(weights) + + +# class TaskConfig(BaseModel): +# mode: str +# network_task: Task + +# # pydantic config +# model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) + + +# class RegressionConfig(TaskConfig): +# network_task = Task.REGRESSION + + +# class ReconstructionConfig(TaskConfig): +# network_task = Task.RECONSTRUCTION + + +# class ClassificationConfig(TaskConfig): +# network_task = Task.CLASSIFICATION + +# n_classe: Optional[int] = None +# df: Optional[pd.DataFrame] = None +# label: Optional[str] = None + +# @model_validator(mode="after") +# def model_validator(self): +# if n_classes is None: +# n_classes = output_size(Task.CLASSIFICATION, None, df, label) +# n_classes = n_classes + +# metrics_module = MetricModule( +# evaluation_metrics(network_task), n_classes=n_classes +# ) diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 37da60096..16d2d88d6 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -10,7 +10,8 @@ import pandas as pd import torch import torch.distributed as dist -from torch.cuda.amp import GradScaler, autocast +from torch.amp import GradScaler +from torch.amp import autocast from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -37,6 +38,14 @@ from clinicadl.callbacks.callbacks import Callback from clinicadl.trainer.config.train import TrainConfig +from clinicadl.trainer.tasks_utils import ( + evaluation_metrics, + generate_sampler, + get_criterion, + save_outputs, + test, + test_da, +) logger = getLogger("clinicadl.trainer") @@ -319,7 +328,8 @@ def _train_single( label=self.config.data.label, label_code=self.maps_manager.label_code, ) - train_sampler = self.maps_manager.task_manager.generate_sampler( + train_sampler = generate_sampler( + self.maps_manager.network_task, data_train, self.config.dataloader.sampler, dp_degree=cluster.world_size, @@ -444,7 +454,8 @@ def _train_multi( cnn_index=network, ) - train_sampler = self.maps_manager.task_manager.generate_sampler( + train_sampler = generate_sampler( + self.maps_manager.network_task, data_train, self.config.dataloader.sampler, dp_degree=cluster.world_size, @@ -589,8 +600,10 @@ def _train_ssda( label=self.config.data.label, label_code=self.maps_manager.label_code, ) - train_source_sampler = self.maps_manager.task_manager.generate_sampler( - data_train_source, self.config.dataloader.sampler + train_source_sampler = generate_sampler( + self.maps_manager.network_task, + data_train_source, + self.config.dataloader.sampler, ) logger.info( @@ -711,7 +724,7 @@ def _train( train_loader: DataLoader, valid_loader: DataLoader, split: int, - network: int = None, + network: Optional[int] = None, resume: bool = False, callbacks: List[Callback] = [], ): @@ -751,7 +764,9 @@ def _train( fsdp=self.config.computational.fully_sharded_data_parallel, amp=self.config.computational.amp, ) - criterion = self.maps_manager.task_manager.get_criterion(self.config.model.loss) + criterion = get_criterion( + self.maps_manager.network_task, self.config.model.loss + ) optimizer = self._init_optimizer(model, split=split, resume=resume) self.callback_handler.on_train_begin( @@ -775,7 +790,7 @@ def _train( if cluster.master: log_writer = LogWriter( self.maps_manager.maps_path, - self.maps_manager.task_manager.evaluation_metrics + ["loss"], + evaluation_metrics(self.maps_manager.network_task) + ["loss"], split, resume=resume, 
beginning_epoch=beginning_epoch, @@ -790,7 +805,7 @@ def _train( selection_metrics=list(self.config.validation.selection_metrics) ) - scaler = GradScaler(enabled=self.maps_manager.std_amp) + scaler = GradScaler("cuda", enabled=self.config.computational.amp) profiler = self._init_profiler() if self.config.callbacks.track_exp == "wandb": @@ -804,9 +819,6 @@ def _train( optimizer, mode="min", factor=0.1, verbose=True ) - scaler = GradScaler(enabled=self.config.computational.amp) - profiler = self._init_profiler() - while epoch < self.config.optimization.epochs and not early_stopping.step( metrics_valid["loss"] ): @@ -828,7 +840,7 @@ def _train( ) % self.config.optimization.accumulation_steps == 0 sync = nullcontext() if update else model.no_sync() with sync: - with autocast(enabled=self.maps_manager.std_amp): + with autocast("cuda", enabled=self.maps_manager.std_amp): _, loss_dict = model(data, criterion) logger.debug(f"Train loss dictionary {loss_dict}") loss = loss_dict["loss"] @@ -849,16 +861,24 @@ def _train( ): evaluation_flag = False - _, metrics_train = self.maps_manager.task_manager.test( - model, - train_loader, - criterion, + _, metrics_train = test( + mode=self.maps_manager.mode, + metrics_module=self.maps_manager.metrics_module, + n_classes=self.maps_manager.n_classes, + network_task=self.maps_manager.network_task, + model=model, + dataloader=train_loader, + criterion=criterion, amp=self.maps_manager.std_amp, ) - _, metrics_valid = self.maps_manager.task_manager.test( - model, - valid_loader, - criterion, + _, metrics_valid = test( + mode=self.maps_manager.mode, + metrics_module=self.maps_manager.metrics_module, + n_classes=self.maps_manager.n_classes, + network_task=self.maps_manager.network_task, + model=model, + dataloader=valid_loader, + criterion=criterion, amp=self.maps_manager.std_amp, ) @@ -908,11 +928,25 @@ def _train( model.zero_grad(set_to_none=True) logger.debug(f"Last checkpoint at the end of the epoch {epoch}") - _, metrics_train = self.maps_manager.task_manager.test( - model, train_loader, criterion, amp=self.maps_manager.std_amp + _, metrics_train = test( + mode=self.maps_manager.mode, + metrics_module=self.maps_manager.metrics_module, + n_classes=self.maps_manager.n_classes, + network_task=self.maps_manager.network_task, + model=model, + dataloader=train_loader, + criterion=criterion, + amp=self.maps_manager.std_amp, ) - _, metrics_valid = self.maps_manager.task_manager.test( - model, valid_loader, criterion, amp=self.maps_manager.std_amp + _, metrics_valid = test( + mode=self.maps_manager.mode, + metrics_module=self.maps_manager.metrics_module, + n_classes=self.maps_manager.n_classes, + network_task=self.maps_manager.network_task, + model=model, + dataloader=valid_loader, + criterion=criterion, + amp=self.maps_manager.std_amp, ) model.train() @@ -983,7 +1017,7 @@ def _train( network=network, ) - if self.maps_manager.task_manager.save_outputs: + if save_outputs(self.maps_manager.network_task): self.maps_manager._compute_output_tensors( train_loader.dataset, "train", @@ -1051,7 +1085,9 @@ def _train_ssdann( transfer_selection=self.config.transfer_learning.transfer_selection_metric, ) - criterion = self.maps_manager.task_manager.get_criterion(self.config.model.loss) + criterion = get_criterion( + self.maps_manager.network_task, self.config.model.loss + ) logger.debug(f"Criterion for {self.config.network_task} is {criterion}") optimizer = self._init_optimizer(model, split=split, resume=resume) @@ -1073,7 +1109,7 @@ def _train_ssdann( log_writer = LogWriter( 
self.maps_manager.maps_path, - self.maps_manager.task_manager.evaluation_metrics + ["loss"], + evaluation_metrics(self.maps_manager.network_task) + ["loss"], split, resume=resume, beginning_epoch=beginning_epoch, @@ -1129,22 +1165,30 @@ def _train_ssdann( ( _, metrics_train_target, - ) = self.maps_manager.task_manager.test_da( - model, - train_target_loader, - criterion, - alpha, + ) = test_da( + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=train_target_loader, + criterion=criterion, + alpha=alpha, target=True, ) # TO CHECK ( _, metrics_valid_target, - ) = self.maps_manager.task_manager.test_da( - model, - valid_loader, - criterion, - alpha, + ) = test_da( + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=valid_loader, + criterion=criterion, + alpha=alpha, target=True, ) @@ -1173,14 +1217,28 @@ def _train_ssdann( ( _, metrics_train_source, - ) = self.maps_manager.task_manager.test_da( - model, train_source_loader, criterion, alpha + ) = test_da( + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=train_source_loader, + criterion=criterion, + alpha=alpha, ) ( _, metrics_valid_source, - ) = self.maps_manager.task_manager.test_da( - model, valid_source_loader, criterion, alpha + ) = test_da( + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=valid_source_loader, + criterion=criterion, + alpha=alpha, ) model.train() @@ -1228,21 +1286,29 @@ def _train_ssdann( logger.info( f"Evaluate source data at the end of the epoch {epoch} with alpha: {alpha}." 
) - _, metrics_train_source = self.maps_manager.task_manager.test_da( - model, - train_source_loader, - criterion, - alpha, - True, - False, + _, metrics_train_source = test_da( + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=train_source_loader, + criterion=criterion, + alpha=alpha, + target=True, + report_ci=False, ) - _, metrics_valid_source = self.maps_manager.task_manager.test_da( - model, - valid_source_loader, - criterion, - alpha, - True, - False, + _, metrics_valid_source = test_da( + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=valid_source_loader, + criterion=criterion, + alpha=alpha, + target=True, + report_ci=False, ) log_writer.step( @@ -1262,18 +1328,26 @@ def _train_ssdann( f"at the end of iteration {i}" ) - _, metrics_train_target = self.maps_manager.task_manager.test_da( - model, - train_target_loader, - criterion, - alpha, + _, metrics_train_target = test_da( + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=train_target_loader, + criterion=criterion, + alpha=alpha, target=True, ) - _, metrics_valid_target = self.maps_manager.task_manager.test_da( - model, - valid_loader, - criterion, - alpha, + _, metrics_valid_target = test_da( + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=valid_loader, + criterion=criterion, + alpha=alpha, target=True, ) @@ -1347,7 +1421,7 @@ def _train_ssdann( alpha=0, ) - if self.maps_manager.task_manager.save_outputs: + if save_outputs(self.maps_manager.network_task): self.maps_manager._compute_output_tensors( train_target_loader.dataset, "train", @@ -1392,7 +1466,7 @@ def _init_callbacks(self) -> None: def _init_optimizer( self, model: DDP, - split: int = None, + split: Optional[int] = None, resume: bool = False, ) -> torch.optim.Optimizer: """ @@ -1430,7 +1504,9 @@ def _init_optimizer( / "tmp" / "optimizer.pth.tar" ) - checkpoint_state = torch.load(checkpoint_path, map_location=model.device) + checkpoint_state = torch.load( + checkpoint_path, map_location=model.device, weights_only=True + ) model.load_optim_state_dict(optimizer, checkpoint_state["optimizer"]) return optimizer @@ -1445,7 +1521,8 @@ def _init_profiler(self) -> torch.profiler.profile: Profiler context manager. """ if self.config.optimization.profiler: - from clinicadl.utils.maps_manager.cluster.profiler import ( + # TODO: no more profiler ???? 
+ from clinicadl.utils.cluster.profiler import ( ProfilerActivity, profile, schedule, @@ -1491,7 +1568,7 @@ def _write_weights( state: Dict[str, Any], metrics_dict: Optional[Dict[str, bool]], split: int, - network: int = None, + network: Optional[int] = None, filename: str = "checkpoint.pth.tar", save_all_models: bool = False, ) -> None: diff --git a/clinicadl/utils/task_manager/__init__.py b/clinicadl/utils/task_manager/__init__.py deleted file mode 100644 index 6264e4825..000000000 --- a/clinicadl/utils/task_manager/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .classification import ClassificationManager -from .reconstruction import ReconstructionManager -from .regression import RegressionManager diff --git a/clinicadl/utils/task_manager/classification.py b/clinicadl/utils/task_manager/classification.py deleted file mode 100644 index abaeb4985..000000000 --- a/clinicadl/utils/task_manager/classification.py +++ /dev/null @@ -1,264 +0,0 @@ -from logging import getLogger - -import numpy as np -import pandas as pd -import torch -from torch import nn -from torch.nn.functional import softmax -from torch.utils.data import sampler -from torch.utils.data.distributed import DistributedSampler - -from clinicadl.caps_dataset.data import CapsDataset -from clinicadl.utils.exceptions import ClinicaDLArgumentError -from clinicadl.utils.task_manager.task_manager import TaskManager - -logger = getLogger("clinicadl.task_manager") - - -class ClassificationManager(TaskManager): - def __init__( - self, - mode, - n_classes=None, - df=None, - label=None, - ): - if n_classes is None: - n_classes = self.output_size(None, df, label) - self.n_classes = n_classes - super().__init__(mode, n_classes) - - @property - def columns(self): - return [ - "participant_id", - "session_id", - f"{self.mode}_id", - "true_label", - "predicted_label", - ] + [f"proba{i}" for i in range(self.n_classes)] - - @property - def evaluation_metrics(self): - return [ - "BA", - "accuracy", - "F1_score", - "sensitivity", - "specificity", - "PPV", - "NPV", - "MCC", - "MK", - "LR_plus", - "LR_minus", - ] - - @property - def save_outputs(self): - return False - - def generate_test_row(self, idx, data, outputs): - prediction = torch.argmax(outputs[idx].data).item() - normalized_output = softmax(outputs[idx], dim=0) - return [ - [ - data["participant_id"][idx], - data["session_id"][idx], - data[f"{self.mode}_id"][idx].item(), - data["label"][idx].item(), - prediction, - ] - + [normalized_output[i].item() for i in range(self.n_classes)] - ] - - def compute_metrics(self, results_df, report_ci: bool = False): - return self.metrics_module.apply( - results_df.true_label.values, - results_df.predicted_label.values, - report_ci=report_ci, - ) - - @staticmethod - def generate_label_code(df, label): - unique_labels = list(set(getattr(df, label))) - unique_labels.sort() - return {str(key): value for value, key in enumerate(unique_labels)} - - @staticmethod - def output_size(input_size, df, label): - label_code = ClassificationManager.generate_label_code(df, label) - return len(label_code) - - @staticmethod - def generate_sampler( - dataset: CapsDataset, - sampler_option="random", - n_bins=5, - dp_degree=None, - rank=None, - ): - df = dataset.df - labels = df[dataset.config.data.label].unique() - codes = set() - for label in labels: - codes.add(dataset.config.data.label_code[label]) - count = np.zeros(len(codes)) - - for idx in df.index: - label = df.loc[idx, dataset.config.data.label] - key = dataset.label_fn(label) - count[key] += 1 - - weight_per_class 
= 1 / np.array(count) - weights = [] - - for idx, label in enumerate(df[dataset.config.data.label].values): - key = dataset.label_fn(label) - weights += [weight_per_class[key]] * dataset.elem_per_image - - if sampler_option == "random": - if dp_degree is not None and rank is not None: - return DistributedSampler( - weights, num_replicas=dp_degree, rank=rank, shuffle=True - ) - else: - return sampler.RandomSampler(weights) - elif sampler_option == "weighted": - if dp_degree is not None and rank is not None: - length = len(weights) // dp_degree + int( - rank < len(weights) % dp_degree - ) - else: - length = len(weights) - return sampler.WeightedRandomSampler(weights, length) - else: - raise NotImplementedError( - f"The option {sampler_option} for sampler on classification task is not implemented" - ) - - @staticmethod - def generate_sampler_ssda( - dataset: CapsDataset, df, sampler_option="random", n_bins=5 - ): - n_labels = df["diagnosis_train"].nunique() - count = np.zeros(n_labels) - - for idx in df.index: - label = df.loc[idx, "diagnosis_train"] - key = dataset.label_fn(label) - count[key] += 1 - - weight_per_class = 1 / np.array(count) - weights = [] - - for idx, label in enumerate(df["diagnosis_train"].values): - key = dataset.label_fn(label) - weights += [weight_per_class[key]] * dataset.elem_per_image - - if sampler_option == "random": - return sampler.RandomSampler(weights) - elif sampler_option == "weighted": - return sampler.WeightedRandomSampler(weights, len(weights)) - else: - raise NotImplementedError( - f"The option {sampler_option} for sampler on classification task is not implemented" - ) - - def ensemble_prediction( - self, - performance_df, - validation_df, - selection_threshold=None, - use_labels=True, - method="soft", - ): - """ - Computes hard or soft voting based on the probabilities in performance_df. Weights are computed based - on the balanced accuracies of validation_df. - - ref: S. Raschka. Python Machine Learning., 2015 - - Args: - performance_df (pd.DataFrame): Results that need to be assembled. - validation_df (pd.DataFrame): Results on the validation set used to compute the performance - of each separate part of the image. - selection_threshold (float): with soft-voting method, allows to exclude some parts of the image - if their associated performance is too low. - use_labels (bool): If True, metrics are computed and the label column values must be different - from None. - method (str): method to assemble the results. Current implementation proposes soft or hard-voting. - - Returns: - df_final (pd.DataFrame) the results on the image level - results (Dict[str, float]) the metrics on the image level - """ - - def check_prediction(row): - if row["true_label"] == row["predicted_label"]: - return 1 - else: - return 0 - - if method == "soft": - # Compute the sub-level accuracies on the validation set: - validation_df["accurate_prediction"] = validation_df.apply( - lambda x: check_prediction(x), axis=1 - ) - sub_level_accuracies = validation_df.groupby(f"{self.mode}_id")[ - "accurate_prediction" - ].mean() - if selection_threshold is not None: - sub_level_accuracies[sub_level_accuracies < selection_threshold] = 0 - weight_series = sub_level_accuracies / sub_level_accuracies.sum() - elif method == "hard": - n_modes = validation_df[f"{self.mode}_id"].nunique() - weight_series = pd.DataFrame(np.ones((n_modes, 1))) - else: - raise NotImplementedError( - f"Ensemble method {method} was not implemented. " - f"Please choose in ['hard', 'soft']." 
- ) - - # Sort to allow weighted average computation - performance_df.sort_values( - ["participant_id", "session_id", f"{self.mode}_id"], inplace=True - ) - weight_series.sort_index(inplace=True) - - # Soft majority vote - df_final = pd.DataFrame(columns=self.columns) - for (subject, session), subject_df in performance_df.groupby( - ["participant_id", "session_id"] - ): - label = subject_df["true_label"].unique().item() - proba_list = [ - np.average(subject_df[f"proba{i}"], weights=weight_series) - for i in range(self.n_classes) - ] - prediction = proba_list.index(max(proba_list)) - row = [[subject, session, 0, label, prediction] + proba_list] - row_df = pd.DataFrame(row, columns=self.columns) - df_final = pd.concat([df_final, row_df]) - - if use_labels: - results = self.compute_metrics(df_final, report_ci=False) - else: - results = None - - return df_final, results - - @staticmethod - def get_criterion(criterion=None): - compatible_losses = ["CrossEntropyLoss", "MultiMarginLoss"] - if criterion is None: - return nn.CrossEntropyLoss() - if criterion not in compatible_losses: - raise ClinicaDLArgumentError( - f"Classification loss must be chosen in {compatible_losses}." - ) - return getattr(nn, criterion)() - - @staticmethod - def get_default_network(): - return "Conv5_FC3" diff --git a/clinicadl/utils/task_manager/reconstruction.py b/clinicadl/utils/task_manager/reconstruction.py deleted file mode 100644 index deff564cf..000000000 --- a/clinicadl/utils/task_manager/reconstruction.py +++ /dev/null @@ -1,189 +0,0 @@ -from torch import nn -from torch.utils.data import sampler -from torch.utils.data.distributed import DistributedSampler - -from clinicadl.caps_dataset.data import CapsDataset -from clinicadl.utils.exceptions import ClinicaDLArgumentError -from clinicadl.utils.task_manager.task_manager import TaskManager - - -class ReconstructionManager(TaskManager): - def __init__( - self, - mode, - ): - super().__init__(mode) - - @property - def columns(self): - columns = ["participant_id", "session_id", f"{self.mode}_id"] - for metric in self.evaluation_metrics: - columns.append(metric) - return columns - - @property - def evaluation_metrics(self): - return ["MAE", "RMSE", "PSNR", "SSIM"] - - @property - def save_outputs(self): - return True - - def generate_test_row(self, idx, data, outputs): - y = data["image"][idx] - y_pred = outputs[idx].cpu() - metrics = self.metrics_module.apply(y, y_pred, report_ci=False) - row = [ - data["participant_id"][idx], - data["session_id"][idx], - data[f"{self.mode}_id"][idx].item(), - ] - - for metric in self.evaluation_metrics: - row.append(metrics[metric]) - return [row] - - def compute_metrics(self, results_df, report_ci=False): - if not report_ci: - return { - metric: results_df[metric].mean() for metric in self.evaluation_metrics - } - - from numpy import mean as np_mean - from scipy.stats import bootstrap - - metrics = dict() - metric_names = ["Metrics"] - metric_values = ["Values"] - lower_ci_values = ["Lower bound CI"] - upper_ci_values = ["Upper bound CI"] - se_values = ["SE"] - - for metric in self.evaluation_metrics: - metric_vals = results_df[metric] - - metric_result = metric_vals.mean() - - metric_vals = (metric_vals,) - # Compute confidence intervals only if there are at least two samples in the data. 
- if len(results_df) >= 2: - res = bootstrap( - metric_vals, - np_mean, - n_resamples=3000, - confidence_level=0.95, - method="percentile", - ) - lower_ci, upper_ci = res.confidence_interval - standard_error = res.standard_error - else: - lower_ci, upper_ci, standard_error = "N/A" - - metric_names.append(metric) - metric_values.append(metric_result) - lower_ci_values.append(lower_ci) - upper_ci_values.append(upper_ci) - se_values.append(standard_error) - - metrics["Metric_names"] = metric_names - metrics["Metric_values"] = metric_values - metrics["Lower_CI"] = lower_ci_values - metrics["Upper_CI"] = upper_ci_values - metrics["SE"] = se_values - - return metrics - - @staticmethod - def output_size(input_size, df, label): - return input_size - - @staticmethod - def generate_label_code(df, label): - return None - - @staticmethod - def generate_sampler( - dataset: CapsDataset, - sampler_option="random", - n_bins=5, - dp_degree=None, - rank=None, - ): - df = dataset.df - - weights = [1] * len(df) * dataset.elem_per_image - - if sampler_option == "random": - if dp_degree is not None and rank is not None: - return DistributedSampler( - weights, num_replicas=dp_degree, rank=rank, shuffle=True - ) - else: - return sampler.RandomSampler(weights) - else: - raise NotImplementedError( - f"The option {sampler_option} for sampler on reconstruction task is not implemented" - ) - - def ensemble_prediction( - self, - performance_df, - validation_df, - selection_threshold=None, - use_labels=True, - method="soft", - ): - """ - Do not perform any ensemble prediction as it is not possible for reconstruction. - - Args: - performance_df (pd.DataFrame): results that need to be assembled. - validation_df (pd.DataFrame): results on the validation set used to compute the performance - of each separate part of the image. - selection_threshold (float): with soft-voting method, allows to exclude some parts of the image - if their associated performance is too low. - use_labels (bool): If True, metrics are computed and the label column values must be different - from None. - method (str): method to assemble the results. Current implementation proposes soft or hard-voting. - - Returns: - None - """ - return None, None - - @staticmethod - def get_criterion(criterion=None): - compatible_losses = [ - "L1Loss", - "MSELoss", - "KLDivLoss", - "BCEWithLogitsLoss", - "HuberLoss", - "SmoothL1Loss", - "VAEGaussianLoss", - "VAEBernoulliLoss", - "VAEContinuousBernoulliLoss", - ] - if criterion is None: - return nn.MSELoss() - if criterion not in compatible_losses: - raise ClinicaDLArgumentError( - f"Reconstruction loss must be chosen in {compatible_losses}." 
- ) - if criterion == "VAEGaussianLoss": - from clinicadl.network.vae.vae_utils import VAEGaussianLoss - - return VAEGaussianLoss - elif criterion == "VAEBernoulliLoss": - from clinicadl.network.vae.vae_utils import VAEBernoulliLoss - - return VAEBernoulliLoss - elif criterion == "VAEContinuousBernoulliLoss": - from clinicadl.network.vae.vae_utils import VAEContinuousBernoulliLoss - - return VAEContinuousBernoulliLoss - return getattr(nn, criterion)() - - @staticmethod - def get_default_network(): - return "AE_Conv5_FC3" diff --git a/clinicadl/utils/task_manager/regression.py b/clinicadl/utils/task_manager/regression.py deleted file mode 100644 index 43ffafdf8..000000000 --- a/clinicadl/utils/task_manager/regression.py +++ /dev/null @@ -1,190 +0,0 @@ -import numpy as np -import pandas as pd -from torch import nn -from torch.utils.data import sampler -from torch.utils.data.distributed import DistributedSampler - -from clinicadl.caps_dataset.data import CapsDataset -from clinicadl.utils.exceptions import ClinicaDLArgumentError -from clinicadl.utils.task_manager.task_manager import TaskManager - - -class RegressionManager(TaskManager): - def __init__( - self, - mode, - ): - super().__init__(mode) - - @property - def columns(self): - return [ - "participant_id", - "session_id", - f"{self.mode}_id", - "true_label", - "predicted_label", - ] - - @property - def evaluation_metrics(self): - return ["R2_score", "MAE", "RMSE"] - - @property - def save_outputs(self): - return False - - def generate_test_row(self, idx, data, outputs): - return [ - [ - data["participant_id"][idx], - data["session_id"][idx], - data[f"{self.mode}_id"][idx].item(), - data["label"][idx].item(), - outputs[idx].item(), - ] - ] - - def compute_metrics(self, results_df, report_ci): - return self.metrics_module.apply( - results_df.true_label.values, - results_df.predicted_label.values, - report_ci=report_ci, - ) - - @staticmethod - def generate_label_code(df, label): - return None - - @staticmethod - def output_size(input_size, df, label): - return 1 - - @staticmethod - def generate_sampler( - dataset: CapsDataset, - sampler_option="random", - n_bins=5, - dp_degree=None, - rank=None, - ): - df = dataset.df - - count = np.zeros(n_bins) - values = df[dataset.config.data.label].values.astype(float) - thresholds = [ - min(values) + i * (max(values) - min(values)) / n_bins - for i in range(n_bins) - ] - for idx in df.index: - label = df.loc[idx, dataset.config.data.label] - key = max(np.where((label >= np.array(thresholds))[0])) - count[[key]] += 1 - weight_per_class = 1 / np.array(count) - weights = [] - - for idx, label in enumerate(df[dataset.config.data.label].values): - key = max(np.where((label >= np.array(thresholds)))[0]) - weights += [weight_per_class[key]] * dataset.elem_per_image - - if sampler_option == "random": - if dp_degree is not None and rank is not None: - return DistributedSampler( - weights, num_replicas=dp_degree, rank=rank, shuffle=True - ) - else: - return sampler.RandomSampler(weights) - elif sampler_option == "weighted": - if dp_degree is not None and rank is not None: - length = len(weights) // dp_degree + int( - rank < len(weights) % dp_degree - ) - else: - length = len(weights) - return sampler.WeightedRandomSampler(weights, length) - else: - raise NotImplementedError( - f"The option {sampler_option} for sampler on regression task is not implemented" - ) - - def ensemble_prediction( - self, - performance_df, - validation_df, - selection_threshold=None, - use_labels=True, - method="hard", - ): - """ - 
Compute the results at the image-level by assembling the results on parts of the image. - - Args: - performance_df (pd.DataFrame): results that need to be assembled. - validation_df (pd.DataFrame): results on the validation set used to compute the performance - of each separate part of the image. - selection_threshold (float): with soft-voting method, allows to exclude some parts of the image - if their associated performance is too low. - use_labels (bool): If True, metrics are computed and the label column values must be different - from None. - method (str): method to assemble the results. Current implementation proposes only hard-voting. - - Returns: - df_final (pd.DataFrame) the results on the image level - results (Dict[str, float]) the metrics on the image level - """ - - if method != "hard": - raise NotImplementedError( - f"You asked for {method} ensemble method. " - f"The only method implemented for regression is hard-voting." - ) - - n_modes = validation_df[f"{self.mode}_id"].nunique() - weight_series = np.ones(n_modes) - - # Sort to allow weighted average computation - performance_df.sort_values( - ["participant_id", "session_id", f"{self.mode}_id"], inplace=True - ) - - # Soft majority vote - df_final = pd.DataFrame(columns=self.columns) - for (subject, session), subject_df in performance_df.groupby( - ["participant_id", "session_id"] - ): - label = subject_df["true_label"].unique().item() - prediction = np.average( - subject_df["predicted_label"], weights=weight_series - ) - row = [[subject, session, 0, label, prediction]] - row_df = pd.DataFrame(row, columns=self.columns) - df_final = pd.concat([df_final, row_df]) - - if use_labels: - results = self.compute_metrics(df_final, report_ci=False) - else: - results = None - - return df_final, results - - @staticmethod - def get_criterion(criterion=None): - compatible_losses = [ - "L1Loss", - "MSELoss", - "KLDivLoss", - "BCEWithLogitsLoss", - "HuberLoss", - "SmoothL1Loss", - ] - if criterion is None: - return nn.MSELoss() - if criterion not in compatible_losses: - raise ClinicaDLArgumentError( - f"Regression loss must be chosen in {compatible_losses}." - ) - return getattr(nn, criterion)() - - @staticmethod - def get_default_network(): - return "Conv5_FC3" diff --git a/clinicadl/utils/task_manager/task_manager.py b/clinicadl/utils/task_manager/task_manager.py deleted file mode 100644 index 460c20c5f..000000000 --- a/clinicadl/utils/task_manager/task_manager.py +++ /dev/null @@ -1,321 +0,0 @@ -from abc import abstractmethod -from typing import Any, Dict, List, Optional, Sequence, Tuple - -import pandas as pd -import torch -import torch.distributed as dist -from torch import Tensor -from torch.cuda.amp import autocast -from torch.nn.modules.loss import _Loss -from torch.utils.data import DataLoader, Sampler - -from clinicadl.caps_dataset.data import CapsDataset -from clinicadl.metrics.metric_module import MetricModule -from clinicadl.network.network import Network -from clinicadl.utils import cluster - - -# TODO: add function to check that the output size of the network corresponds to what is expected to -# perform the task -class TaskManager: - def __init__(self, mode: str, n_classes: int = None): - self.mode = mode - self.metrics_module = MetricModule(self.evaluation_metrics, n_classes=n_classes) - - @property - @abstractmethod - def columns(self): - """ - List of the columns' names in the TSV file containing the predictions. 
- """ - pass - - @property - @abstractmethod - def evaluation_metrics(self): - """ - Evaluation metrics which can be used to evaluate the task. - """ - pass - - @property - @abstractmethod - def save_outputs(self): - """ - Boolean value indicating if the output values should be saved as tensor for this task. - """ - pass - - @abstractmethod - def generate_test_row( - self, idx: int, data: Dict[str, Any], outputs: Tensor - ) -> List[List[Any]]: - """ - Computes an individual row of the prediction TSV file. - - Args: - idx: index of the individual input and output in the batch. - data: input batch generated by a DataLoader on a CapsDataset. - outputs: output batch generated by a forward pass in the model. - Returns: - list of items to be contained in a row of the prediction TSV file. - """ - pass - - @abstractmethod - def compute_metrics(self, results_df: pd.DataFrame) -> Dict[str, float]: - """ - Compute the metrics based on the result of generate_test_row - - Args: - results_df: results generated based on _results_test_row - Returns: - dictionary of metrics - """ - pass - - @abstractmethod - def ensemble_prediction( - self, - performance_df: pd.DataFrame, - validation_df: pd.DataFrame, - selection_threshold: float = None, - use_labels: bool = True, - method: str = "soft", - ) -> Tuple[pd.DataFrame, Dict[str, float]]: - """ - Compute the results at the image-level by assembling the results on parts of the image. - - Args: - performance_df: results that need to be assembled. - validation_df: results on the validation set used to compute the performance - of each separate part of the image. - selection_threshold: with soft-voting method, allows to exclude some parts of the image - if their associated performance is too low. - use_labels: If True, metrics are computed and the label column values must be different - from None. - method: method to assemble the results. Current implementation proposes soft or hard-voting. - - Returns: - the results and metrics on the image level - """ - pass - - @staticmethod - @abstractmethod - def generate_label_code(df: pd.DataFrame, label: str) -> Optional[Dict[str, int]]: - """ - Generates a label code that links the output node number to label value. - - Args: - df: meta-data of the training set. - label: name of the column containing the labels. - Returns: - label_code - """ - pass - - @staticmethod - @abstractmethod - def output_size( - input_size: Sequence[int], df: pd.DataFrame, label: str - ) -> Sequence[int]: - """ - Computes the output_size needed to perform the task. - - Args: - input_size: size of the input. - df: meta-data of the training set. - label: name of the column containing the labels. - Returns: - output_size - """ - pass - - @staticmethod - @abstractmethod - def generate_sampler( - dataset: CapsDataset, - sampler_option: str = "random", - n_bins: int = 5, - dp_degree: Optional[int] = None, - rank: Optional[int] = None, - ) -> Sampler: - """ - Returns sampler according to the wanted options. - - Args: - dataset: the dataset to sample from. - sampler_option: choice of sampler. - n_bins: number of bins to used for a continuous variable (regression task). - dp_degree: the degree of data parallelism. - rank: process id within the data parallelism communicator. - Returns: - callable given to the training data loader. - """ - pass - - @staticmethod - @abstractmethod - def get_criterion(criterion: str = None) -> _Loss: - """ - Gives the optimization criterion. - Must check that it is compatible with the task. 
- - Args: - criterion: name of the loss as written in Pytorch. - Raises: - ClinicaDLArgumentError: if the criterion is not compatible with the task. - """ - pass - - @staticmethod - @abstractmethod - def get_default_network() -> Network: - """Returns the default network to use when no architecture is specified.""" - pass - - def test( - self, - model: Network, - dataloader: DataLoader, - criterion: _Loss, - use_labels: bool = True, - amp: bool = False, - report_ci=False, - ) -> Tuple[pd.DataFrame, Dict[str, float]]: - """ - Computes the predictions and evaluation metrics. - - Parameters - ---------- - model: Network - The model trained. - dataloader: DataLoader - Wrapper of a CapsDataset. - criterion: _Loss - Function to calculate the loss. - use_labels: bool - If True the true_label will be written in output DataFrame - and metrics dict will be created. - amp: bool - If True, enables Pytorch's automatic mixed precision. - - Returns - ------- - the results and metrics on the image level. - """ - model.eval() - dataloader.dataset.eval() - - results_df = pd.DataFrame(columns=self.columns) - total_loss = {} - with torch.no_grad(): - for i, data in enumerate(dataloader): - # initialize the loss list to save the loss components - with autocast(enabled=amp): - outputs, loss_dict = model(data, criterion, use_labels=use_labels) - - if i == 0: - for loss_component in loss_dict.keys(): - total_loss[loss_component] = 0 - for loss_component in total_loss.keys(): - total_loss[loss_component] += loss_dict[loss_component].float() - - # Generate detailed DataFrame - for idx in range(len(data["participant_id"])): - row = self.generate_test_row(idx, data, outputs.float()) - row_df = pd.DataFrame(row, columns=self.columns) - results_df = pd.concat([results_df, row_df]) - - del outputs, loss_dict - dataframes = [None] * dist.get_world_size() - dist.gather_object( - results_df, dataframes if dist.get_rank() == 0 else None, dst=0 - ) - if dist.get_rank() == 0: - results_df = pd.concat(dataframes) - del dataframes - results_df.reset_index(inplace=True, drop=True) - - if not use_labels: - metrics_dict = None - else: - metrics_dict = self.compute_metrics(results_df, report_ci=report_ci) - for loss_component in total_loss.keys(): - dist.reduce(total_loss[loss_component], dst=0) - loss_value = total_loss[loss_component].item() / cluster.world_size - - if report_ci: - metrics_dict["Metric_names"].append(loss_component) - metrics_dict["Metric_values"].append(loss_value) - metrics_dict["Lower_CI"].append("N/A") - metrics_dict["Upper_CI"].append("N/A") - metrics_dict["SE"].append("N/A") - - else: - metrics_dict[loss_component] = loss_value - - torch.cuda.empty_cache() - - return results_df, metrics_dict - - def test_da( - self, - model: Network, - dataloader: DataLoader, - criterion: _Loss, - alpha: float = 0, - use_labels: bool = True, - target: bool = True, - report_ci=False, - ) -> Tuple[pd.DataFrame, Dict[str, float]]: - """ - Computes the predictions and evaluation metrics. - - Args: - model: the model trained. - dataloader: wrapper of a CapsDataset. - criterion: function to calculate the loss. - use_labels: If True the true_label will be written in output DataFrame - and metrics dict will be created. - Returns: - the results and metrics on the image level. 
- """ - model.eval() - dataloader.dataset.eval() - results_df = pd.DataFrame(columns=self.columns) - total_loss = 0 - with torch.no_grad(): - for i, data in enumerate(dataloader): - outputs, loss_dict = model.compute_outputs_and_loss_test( - data, criterion, alpha, target - ) - total_loss += loss_dict["loss"].item() - - # Generate detailed DataFrame - for idx in range(len(data["participant_id"])): - row = self.generate_test_row(idx, data, outputs) - row_df = pd.DataFrame(row, columns=self.columns) - results_df = pd.concat([results_df, row_df]) - - del outputs, loss_dict - results_df.reset_index(inplace=True, drop=True) - - if not use_labels: - metrics_dict = None - else: - metrics_dict = self.compute_metrics(results_df, report_ci=report_ci) - if report_ci: - metrics_dict["Metric_names"].append("loss") - metrics_dict["Metric_values"].append(total_loss) - metrics_dict["Lower_CI"].append("N/A") - metrics_dict["Upper_CI"].append("N/A") - metrics_dict["SE"].append("N/A") - - else: - metrics_dict["loss"] = total_loss - - torch.cuda.empty_cache() - - return results_df, metrics_dict diff --git a/clinicadl/validation/cross_validation.py b/clinicadl/validation/cross_validation.py index ff3fb489e..96fdfe1b9 100644 --- a/clinicadl/validation/cross_validation.py +++ b/clinicadl/validation/cross_validation.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, ConfigDict, field_validator from pydantic.types import NonNegativeInt -from clinicadl.maps_manager.maps_manager import MapsManager +# from clinicadl.maps_manager.maps_manager import MapsManager from clinicadl.splitter.split_utils import find_splits logger = getLogger("clinicadl.cross_validation_config") @@ -32,7 +32,9 @@ def validator_split(cls, v): return tuple(v) return v # TODO : check that split exists (and check coherence with n_splits) - def adapt_cross_val_with_maps_manager_info(self, maps_manager: MapsManager): + def adapt_cross_val_with_maps_manager_info( + self, maps_manager + ): # maps_manager is of type MapsManager but need to be in a MapsConfig type in the future # TEMPORARY if not self.split: self.split = find_splits(maps_manager.maps_path, maps_manager.split_name)