diff --git a/clinicadl/API_test.py b/clinicadl/API_test.py
index 0581b879a..d144c1597 100644
--- a/clinicadl/API_test.py
+++ b/clinicadl/API_test.py
@@ -2,7 +2,11 @@
 from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
 from clinicadl.caps_dataset.data import return_dataset
+from clinicadl.predictor.config import PredictConfig
+from clinicadl.predictor.predictor import Predictor
 from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData
+from clinicadl.splitter.config import SplitterConfig
+from clinicadl.splitter.splitter import Splitter
 from clinicadl.trainer.config.classification import ClassificationConfig
 from clinicadl.trainer.trainer import Trainer
 from clinicadl.utils.enum import ExtractionMethod, Preprocessing, Task
@@ -27,11 +31,11 @@
     multi_cohort,
 )
-split_config = SplitConfig()
+split_config = SplitterConfig()
 splitter = Splitter(split_config)
-validator_config = ValidatorConfig()
-validator = Validator(validator_config)
+validator_config = PredictConfig()
+validator = Predictor(validator_config)
 train_config = ClassificationConfig()
 trainer = Trainer(train_config, validator)
@@ -78,6 +82,6 @@
 test_loader = trainer.get_dataloader(dataset, split, network, "test", config)
 validator.predict(test_loader)
-interpret_config = InterpretConfig(**kwargs)
-predict_manager = PredictManager(interpret_config)
+interpret_config = PredictConfig(**kwargs)
+predict_manager = Predictor(interpret_config)
 predict_manager.interpret()
diff --git a/clinicadl/caps_dataset/data_config.py b/clinicadl/caps_dataset/data_config.py
index 35aed91b5..80694fcd0 100644
--- a/clinicadl/caps_dataset/data_config.py
+++ b/clinicadl/caps_dataset/data_config.py
@@ -24,7 +24,7 @@ class DataConfig(BaseModel):  # TODO : put in data module
     that must be passed by the user.
     """
-    caps_directory: Path
+    caps_directory: Optional[Path] = None
     baseline: bool = False
     diagnoses: Tuple[str, ...] = ("AD", "CN")
     data_df: Optional[pd.DataFrame] = None
@@ -147,15 +147,17 @@ def preprocessing_dict(self) -> Dict[str, Any]:
                     f"in {caps_dict}."
                )
 
-        preprocessing_dict = read_preprocessing(preprocessing_json)
+            preprocessing_dict = read_preprocessing(preprocessing_json)
 
-        if (
-            preprocessing_dict["mode"] == "roi"
-            and "roi_background_value" not in preprocessing_dict
-        ):
-            preprocessing_dict["roi_background_value"] = 0
+            if (
+                preprocessing_dict["mode"] == "roi"
+                and "roi_background_value" not in preprocessing_dict
+            ):
+                preprocessing_dict["roi_background_value"] = 0
 
-        return preprocessing_dict
+            return preprocessing_dict
+        else:
+            return None
 
     @computed_field
     @property
diff --git a/clinicadl/commandline/modules_options/ssda.py b/clinicadl/commandline/modules_options/ssda.py
deleted file mode 100644
index 8119726ef..000000000
--- a/clinicadl/commandline/modules_options/ssda.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import click
-
-from clinicadl.config.config.ssda import SSDAConfig
-from clinicadl.config.config_utils import get_default_from_config_class as get_default
-from clinicadl.config.config_utils import get_type_from_config_class as get_type
-
-# SSDA
-caps_target = click.option(
-    "--caps_target",
-    "-d",
-    type=get_type("caps_target", SSDAConfig),
-    default=get_default("caps_target", SSDAConfig),
-    help="CAPS of target data.",
-    show_default=True,
-)
-preprocessing_json_target = click.option(
-    "--preprocessing_json_target",
-    "-d",
-    type=get_type("preprocessing_json_target", SSDAConfig),
-    default=get_default("preprocessing_json_target", SSDAConfig),
-    help="Path to json target.",
-    show_default=True,
-)
-ssda_network = click.option(
-    "--ssda_network/--single_network",
-    default=get_default("ssda_network", SSDAConfig),
-    help="If provided uses a ssda-network framework.",
-    show_default=True,
-)
-tsv_target_lab = click.option(
-    "--tsv_target_lab",
-    "-d",
-    type=get_type("tsv_target_lab", SSDAConfig),
-    default=get_default("tsv_target_lab", SSDAConfig),
-    help="TSV of labeled target data.",
-    show_default=True,
-)
-tsv_target_unlab = click.option(
-    "--tsv_target_unlab",
-    "-d",
-    type=get_type("tsv_target_unlab", SSDAConfig),
-    default=get_default("tsv_target_unlab", SSDAConfig),
-    help="TSV of unllabeled target data.",
-    show_default=True,
-)
diff --git a/clinicadl/commandline/modules_options/validation.py b/clinicadl/commandline/modules_options/validation.py
index 858dd956e..089357866 100644
--- a/clinicadl/commandline/modules_options/validation.py
+++ b/clinicadl/commandline/modules_options/validation.py
@@ -2,7 +2,7 @@
 
 from clinicadl.config.config_utils import get_default_from_config_class as get_default
 from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.splitter.validation import ValidationConfig
+from clinicadl.predictor.validation import ValidationConfig
 
 # Validation
 valid_longitudinal = click.option(
diff --git a/clinicadl/commandline/pipelines/generate/trivial/cli.py b/clinicadl/commandline/pipelines/generate/trivial/cli.py
index b48651811..4798dc904 100644
--- a/clinicadl/commandline/pipelines/generate/trivial/cli.py
+++ b/clinicadl/commandline/pipelines/generate/trivial/cli.py
@@ -118,7 +118,6 @@ def create_trivial_image(subject_id: int) -> pd.DataFrame:
         if caps_config.data.mask_path is None:
             caps_config.data.mask_path = get_mask_path()
         path_to_mask = caps_config.data.mask_path / f"mask-{label + 1}.nii"
-        print(path_to_mask)
         if path_to_mask.is_file():
             atlas_to_mask = nib.loadsave.load(path_to_mask).get_fdata()
         else:
diff --git a/clinicadl/commandline/pipelines/interpret/cli.py b/clinicadl/commandline/pipelines/interpret/cli.py
index 3509eaf23..db92f06e2 100644
--- a/clinicadl/commandline/pipelines/interpret/cli.py
+++ b/clinicadl/commandline/pipelines/interpret/cli.py
@@ -1,3 +1,5 @@
+from pathlib import Path
+
 import click
 
 from clinicadl.commandline import arguments
@@ -10,7 +12,7 @@
 )
 from clinicadl.commandline.pipelines.interpret import options
 from clinicadl.interpret.config import InterpretConfig
-from clinicadl.predict.predict_manager import PredictManager
+from clinicadl.predictor.predictor import Predictor
 
 
 @click.command("interpret", no_args_is_help=True)
@@ -40,9 +42,13 @@ def cli(**kwargs):
     NAME is the name of the saliency map task.
     METHOD is the method used to extract an attribution map.
     """
+    from clinicadl.utils.iotools.train_utils import merge_cli_and_maps_json_options
 
-    interpret_config = InterpretConfig(**kwargs)
-    predict_manager = PredictManager(interpret_config)
+    dict_ = merge_cli_and_maps_json_options(
+        Path(kwargs["input_maps"]) / "maps.json", **kwargs
+    )
+    interpret_config = InterpretConfig(**dict_)
+    predict_manager = Predictor(interpret_config)
     predict_manager.interpret()
diff --git a/clinicadl/commandline/pipelines/interpret/options.py b/clinicadl/commandline/pipelines/interpret/options.py
index 5313b4a90..43cada4c4 100644
--- a/clinicadl/commandline/pipelines/interpret/options.py
+++ b/clinicadl/commandline/pipelines/interpret/options.py
@@ -2,28 +2,28 @@
 
 from clinicadl.config.config_utils import get_default_from_config_class as get_default
 from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.interpret.config import InterpretConfig
+from clinicadl.interpret.config import InterpretBaseConfig
 
 # interpret specific
 name = click.argument(
     "name",
-    type=get_type("name", InterpretConfig),
+    type=get_type("name", InterpretBaseConfig),
 )
 method = click.argument(
     "method",
-    type=get_type("method", InterpretConfig),  # ["gradients", "grad-cam"]
+    type=get_type("method", InterpretBaseConfig),  # ["gradients", "grad-cam"]
 )
 level = click.option(
     "--level_grad_cam",
-    type=get_type("level", InterpretConfig),
-    default=get_default("level", InterpretConfig),
+    type=get_type("level", InterpretBaseConfig),
+    default=get_default("level", InterpretBaseConfig),
     help="level of the feature map (after the layer corresponding to the number) chosen for grad-cam.",
     show_default=True,
 )
 target_node = click.option(
     "--target_node",
-    type=get_type("target_node", InterpretConfig),
-    default=get_default("target_node", InterpretConfig),
+    type=get_type("target_node", InterpretBaseConfig),
+    default=get_default("target_node", InterpretBaseConfig),
     help="Which target node the gradients explain. Default takes the first output node.",
     show_default=True,
 )
diff --git a/clinicadl/commandline/pipelines/predict/cli.py b/clinicadl/commandline/pipelines/predict/cli.py
index fa7303008..184f46ad7 100644
--- a/clinicadl/commandline/pipelines/predict/cli.py
+++ b/clinicadl/commandline/pipelines/predict/cli.py
@@ -10,8 +10,8 @@
     validation,
 )
 from clinicadl.commandline.pipelines.predict import options
-from clinicadl.predict.config import PredictConfig
-from clinicadl.predict.predict_manager import PredictManager
+from clinicadl.predictor.config import PredictConfig
+from clinicadl.predictor.predictor import Predictor
 
 
 @click.command(name="predict", no_args_is_help=True)
@@ -61,7 +61,7 @@ def cli(input_maps_directory, data_group, **kwargs):
     """
 
     predict_config = PredictConfig(**kwargs)
-    predict_manager = PredictManager(predict_config)
+    predict_manager = Predictor(predict_config)
     predict_manager.predict()
diff --git a/clinicadl/commandline/pipelines/predict/options.py b/clinicadl/commandline/pipelines/predict/options.py
index 003dfe275..cbb8980ca 100644
--- a/clinicadl/commandline/pipelines/predict/options.py
+++ b/clinicadl/commandline/pipelines/predict/options.py
@@ -1,13 +1,11 @@
 import click
 
 from clinicadl.config.config_utils import get_default_from_config_class as get_default
-from clinicadl.predict.config import PredictConfig
+from clinicadl.predictor.config import PredictConfig
 
 # predict specific
 use_labels = click.option(
     "--use_labels/--no_labels",
-    show_default=True,
-    default=get_default("use_labels", PredictConfig),
     help="Set this option to --no_labels if your dataset does not contain ground truth labels.",
 )
 save_tensor = click.option(
diff --git a/clinicadl/commandline/pipelines/train/classification/cli.py b/clinicadl/commandline/pipelines/train/classification/cli.py
index 539f6cd42..8ac287402 100644
--- a/clinicadl/commandline/pipelines/train/classification/cli.py
+++ b/clinicadl/commandline/pipelines/train/classification/cli.py
@@ -13,7 +13,6 @@
     optimizer,
     reproducibility,
     split,
-    ssda,
     transforms,
     validation,
 )
@@ -63,12 +62,6 @@
 @dataloader.batch_size
 @dataloader.sampler
 @dataloader.n_proc
-# ssda option
-@ssda.ssda_network
-@ssda.caps_target
-@ssda.tsv_target_lab
-@ssda.tsv_target_unlab
-@ssda.preprocessing_json_target
 # Cross validation
 @split.n_splits
 @split.split
@@ -115,4 +108,5 @@ def cli(**kwargs):
     options = merge_cli_and_config_file_options(Task.CLASSIFICATION, **kwargs)
     config = ClassificationConfig(**options)
     trainer = Trainer(config)
+
     trainer.train(split_list=config.split.split, overwrite=True)
diff --git a/clinicadl/commandline/pipelines/train/from_json/cli.py b/clinicadl/commandline/pipelines/train/from_json/cli.py
index c0130a9b9..517ec8fa5 100644
--- a/clinicadl/commandline/pipelines/train/from_json/cli.py
+++ b/clinicadl/commandline/pipelines/train/from_json/cli.py
@@ -27,6 +27,8 @@ def cli(**kwargs):
 
     logger.info(f"Reading JSON file at path {kwargs['config_file']}...")
     trainer = Trainer.from_json(
-        config_file=kwargs["config_file"], maps_path=kwargs["output_maps_directory"]
+        config_file=kwargs["config_file"],
+        maps_path=kwargs["output_maps_directory"],
+        split=kwargs["split"],
     )
     trainer.train(split_list=kwargs["split"], overwrite=True)
diff --git a/clinicadl/commandline/pipelines/train/reconstruction/cli.py b/clinicadl/commandline/pipelines/train/reconstruction/cli.py
index d63bf63f8..fc39ef54e 100644
--- a/clinicadl/commandline/pipelines/train/reconstruction/cli.py
+++ b/clinicadl/commandline/pipelines/train/reconstruction/cli.py
@@ -13,7 +13,6 @@
     optimizer,
     reproducibility,
     split,
-    ssda,
     transforms,
     validation,
 )
@@ -63,12 +62,6 @@
 @dataloader.batch_size
 @dataloader.sampler
 @dataloader.n_proc
-# ssda option
-@ssda.ssda_network
-@ssda.caps_target
-@ssda.tsv_target_lab
-@ssda.tsv_target_unlab
-@ssda.preprocessing_json_target
 # Cross validation
 @split.n_splits
 @split.split
diff --git a/clinicadl/commandline/pipelines/train/regression/cli.py b/clinicadl/commandline/pipelines/train/regression/cli.py
index ff6dd68ca..59e816192 100644
--- a/clinicadl/commandline/pipelines/train/regression/cli.py
+++ b/clinicadl/commandline/pipelines/train/regression/cli.py
@@ -13,7 +13,6 @@
     optimizer,
     reproducibility,
     split,
-    ssda,
     transforms,
     validation,
 )
@@ -61,12 +60,6 @@
 @dataloader.batch_size
 @dataloader.sampler
 @dataloader.n_proc
-# ssda o
-@ssda.ssda_network
-@ssda.caps_target
-@ssda.tsv_target_lab
-@ssda.tsv_target_unlab
-@ssda.preprocessing_json_target
 # Cross validation
 @split.n_splits
 @split.split
diff --git a/clinicadl/commandline/pipelines/train/resume/cli.py b/clinicadl/commandline/pipelines/train/resume/cli.py
index 1fc34a0f4..12451d18a 100644
--- a/clinicadl/commandline/pipelines/train/resume/cli.py
+++ b/clinicadl/commandline/pipelines/train/resume/cli.py
@@ -16,4 +16,4 @@ def cli(input_maps_directory, split):
     INPUT_MAPS_DIRECTORY is the path to the MAPS folder where training job has started.
     """
     trainer = Trainer.from_maps(input_maps_directory)
-    trainer.resume(split)
+    trainer.resume()
diff --git a/clinicadl/config/config/ssda.py b/clinicadl/config/config/ssda.py
deleted file mode 100644
index caf52634d..000000000
--- a/clinicadl/config/config/ssda.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from logging import getLogger
-from pathlib import Path
-from typing import Any, Dict
-
-from pydantic import BaseModel, ConfigDict, computed_field
-
-from clinicadl.utils.iotools.utils import read_preprocessing
-
-logger = getLogger("clinicadl.ssda_config")
-
-
-class SSDAConfig(BaseModel):
-    """Config class to perform SSDA."""
-
-    caps_target: Path = Path("")
-    preprocessing_json_target: Path = Path("")
-    ssda_network: bool = False
-    tsv_target_lab: Path = Path("")
-    tsv_target_unlab: Path = Path("")
-    # pydantic config
-    model_config = ConfigDict(validate_assignment=True)
-
-    @computed_field
-    @property
-    def preprocessing_dict_target(self) -> Dict[str, Any]:
-        """
-        Gets the preprocessing dictionary from a target preprocessing json file.
-
-        Returns
-        -------
-        Dict[str, Any]
-            The preprocessing dictionary.
- """ - if not self.ssda_network: - return {} - - preprocessing_json_target = ( - self.caps_target / "tensor_extraction" / self.preprocessing_json_target - ) - - return read_preprocessing(preprocessing_json_target) diff --git a/clinicadl/interpret/config.py b/clinicadl/interpret/config.py index 41c8dcea9..03036d4c9 100644 --- a/clinicadl/interpret/config.py +++ b/clinicadl/interpret/config.py @@ -1,23 +1,33 @@ from logging import getLogger from pathlib import Path -from typing import Optional +from typing import Any, Dict, Optional from pydantic import BaseModel, field_validator -from clinicadl.caps_dataset.data_config import DataConfig as DataBaseConfig +from clinicadl.caps_dataset.data_config import DataConfig from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp -from clinicadl.maps_manager.config import MapsManagerConfig +from clinicadl.maps_manager.config import MapsManagerConfig as MapsManagerConfigBase +from clinicadl.maps_manager.maps_manager import MapsManager +from clinicadl.predictor.validation import ValidationConfig from clinicadl.splitter.config import SplitConfig -from clinicadl.splitter.validation import ValidationConfig +from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.enum import InterpretationMethod +from clinicadl.utils.exceptions import ClinicaDLArgumentError logger = getLogger("clinicadl.interpret_config") -class DataConfig(DataBaseConfig): - caps_directory: Optional[Path] = None +class MapsManagerConfig(MapsManagerConfigBase): + save_tensor: bool = False + + def check_output_saving_tensor(self, network_task: str) -> None: + # Check if task is reconstruction for "save_tensor" and "save_nifti" + if self.save_tensor and network_task != "reconstruction": + raise ClinicaDLArgumentError( + "Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option." 
+            )
 
 
 class InterpretBaseConfig(BaseModel):
@@ -44,13 +54,57 @@ def get_method(self) -> Gradients:
         raise ValueError(f"The method {self.method.value} is not implemented")
 
 
-class InterpretConfig(
-    MapsManagerConfig,
-    InterpretBaseConfig,
-    DataConfig,
-    ValidationConfig,
-    ComputationalConfig,
-    DataLoaderConfig,
-    SplitConfig,
-):
+class InterpretConfig(BaseModel):
     """Config class to perform Transfer Learning."""
+
+    maps_manager: MapsManagerConfig
+    data: DataConfig
+    validation: ValidationConfig
+    computational: ComputationalConfig
+    dataloader: DataLoaderConfig
+    split: SplitConfig
+    interpret: InterpretBaseConfig
+
+    def __init__(self, **kwargs):
+        super().__init__(
+            maps_manager=kwargs,
+            computational=kwargs,
+            dataloader=kwargs,
+            data=kwargs,
+            split=kwargs,
+            validation=kwargs,
+            interpret=kwargs,
+        )
+
+    def _update(self, config_dict: Dict[str, Any]) -> None:
+        """Updates the configs with a dict given by the user."""
+        self.data.__dict__.update(config_dict)
+        self.split.__dict__.update(config_dict)
+        self.validation.__dict__.update(config_dict)
+        self.maps_manager.__dict__.update(config_dict)
+        self.split.__dict__.update(config_dict)
+        self.computational.__dict__.update(config_dict)
+        self.dataloader.__dict__.update(config_dict)
+        self.interpret.__dict__.update(config_dict)
+
+    def adapt_with_maps_manager_info(self, maps_manager: MapsManager):
+        self.maps_manager.check_output_saving_nifti(maps_manager.network_task)
+        self.data.diagnoses = (
+            maps_manager.diagnoses
+            if self.data.diagnoses is None or len(self.data.diagnoses) == 0
+            else self.data.diagnoses
+        )
+
+        self.dataloader.batch_size = (
+            maps_manager.batch_size
+            if not self.dataloader.batch_size
+            else self.dataloader.batch_size
+        )
+        self.dataloader.n_proc = (
+            maps_manager.n_proc
+            if not self.dataloader.n_proc
+            else self.dataloader.n_proc
+        )
+
+        self.split.adapt_cross_val_with_maps_manager_info(maps_manager)
+        self.maps_manager.check_output_saving_tensor(maps_manager.network_task)
diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py
index 76cb544fe..10550a021 100644
--- a/clinicadl/maps_manager/maps_manager.py
+++ b/clinicadl/maps_manager/maps_manager.py
@@ -17,7 +17,7 @@
 from clinicadl.metrics.utils import (
     check_selection_metric,
 )
-from clinicadl.predict.utils import get_prediction
+from clinicadl.predictor.utils import get_prediction
 from clinicadl.splitter.config import SplitterConfig
 from clinicadl.splitter.splitter import Splitter
 from clinicadl.trainer.tasks_utils import (
diff --git a/clinicadl/nn/networks/__init__.py b/clinicadl/nn/networks/__init__.py
index c77097e60..3b88830fb 100644
--- a/clinicadl/nn/networks/__init__.py
+++ b/clinicadl/nn/networks/__init__.py
@@ -8,7 +8,6 @@
     resnet18,
 )
 from .random import RandomArchitecture
-from .ssda import Conv5_FC3_SSDA
 from .unet import UNet
 from .vae import (
     CVAE_3D,
diff --git a/clinicadl/nn/networks/cnn.py b/clinicadl/nn/networks/cnn.py
index eb2104b1e..5fe596bcb 100644
--- a/clinicadl/nn/networks/cnn.py
+++ b/clinicadl/nn/networks/cnn.py
@@ -63,8 +63,6 @@ def __init__(self, convolution_layers: nn.Module, fc_layers: nn.Module) -> None:
 
     def forward(self, x):
         inter = self.convolutions(x)
-        print(self.convolutions)
-        print(inter.shape)
         return self.fc(inter)
diff --git a/clinicadl/nn/utils.py b/clinicadl/nn/utils.py
index dc3afd71c..263afc407 100644
--- a/clinicadl/nn/utils.py
+++ b/clinicadl/nn/utils.py
@@ -64,7 +64,6 @@ def compute_output_size(
         input_ = torch.randn(input_size).unsqueeze(0)
         if isinstance(layer, nn.MaxUnpool3d) or isinstance(layer, nn.MaxUnpool2d):
             indices = torch.zeros_like(input_, dtype=int)
-            print(indices)
             output = layer(input_, indices)
         else:
             output = layer(input_)
diff --git a/clinicadl/predict/config.py b/clinicadl/predict/config.py
deleted file mode 100644
index a96b4b104..000000000
--- a/clinicadl/predict/config.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from logging import getLogger
-
-from clinicadl.caps_dataset.data_config import DataConfig as DataBaseConfig
-from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig
-from clinicadl.maps_manager.config import (
-    MapsManagerConfig as MapsManagerBaseConfig,
-)
-from clinicadl.splitter.config import SplitConfig
-from clinicadl.splitter.validation import ValidationConfig
-from clinicadl.utils.computational.computational import ComputationalConfig
-from clinicadl.utils.exceptions import ClinicaDLArgumentError  # type: ignore
-
-logger = getLogger("clinicadl.predict_config")
-
-
-class MapsManagerConfig(MapsManagerBaseConfig):
-    save_tensor: bool = False
-    save_latent_tensor: bool = False
-
-    def check_output_saving_tensor(self, network_task: str) -> None:
-        # Check if task is reconstruction for "save_tensor" and "save_nifti"
-        if self.save_tensor and network_task != "reconstruction":
-            raise ClinicaDLArgumentError(
-                "Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option."
-            )
-
-
-class DataConfig(DataBaseConfig):
-    use_labels: bool = True
-
-
-class PredictConfig(
-    MapsManagerConfig,
-    DataConfig,
-    ValidationConfig,
-    ComputationalConfig,
-    DataLoaderConfig,
-    SplitConfig,
-):
-    """Config class to perform Transfer Learning."""
diff --git a/clinicadl/predict/__init__.py b/clinicadl/predictor/__init__.py
similarity index 100%
rename from clinicadl/predict/__init__.py
rename to clinicadl/predictor/__init__.py
diff --git a/clinicadl/predictor/config.py b/clinicadl/predictor/config.py
new file mode 100644
index 000000000..ead42d1c6
--- /dev/null
+++ b/clinicadl/predictor/config.py
@@ -0,0 +1,105 @@
+from logging import getLogger
+from typing import Any, Dict
+
+from pydantic import BaseModel, ConfigDict, computed_field
+
+from clinicadl.caps_dataset.data_config import DataConfig as DataBaseConfig
+from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig
+from clinicadl.maps_manager.config import (
+    MapsManagerConfig as MapsManagerBaseConfig,
+)
+from clinicadl.maps_manager.maps_manager import MapsManager
+from clinicadl.predictor.validation import ValidationConfig
+from clinicadl.splitter.config import SplitConfig
+from clinicadl.transforms.config import TransformsConfig
+from clinicadl.utils.computational.computational import ComputationalConfig
+from clinicadl.utils.enum import Task
+from clinicadl.utils.exceptions import ClinicaDLArgumentError  # type: ignore
+
+logger = getLogger("clinicadl.predict_config")
+
+
+class MapsManagerConfig(MapsManagerBaseConfig):
+    save_tensor: bool = False
+    save_latent_tensor: bool = False
+
+    def check_output_saving_tensor(self, network_task: str) -> None:
+        # Check if task is reconstruction for "save_tensor" and "save_nifti"
+        if self.save_tensor and network_task != "reconstruction":
+            raise ClinicaDLArgumentError(
+                "Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option."
+            )
+
+
+class DataConfig(DataBaseConfig):
+    use_labels: bool = True
+
+
+class PredictConfig(BaseModel):
+    """Config class to perform prediction."""
+
+    maps_manager: MapsManagerConfig
+    data: DataConfig
+    validation: ValidationConfig
+    computational: ComputationalConfig
+    dataloader: DataLoaderConfig
+    split: SplitConfig
+    transforms: TransformsConfig
+
+    model_config = ConfigDict(validate_assignment=True)
+
+    def __init__(self, **kwargs):
+        super().__init__(
+            maps_manager=kwargs,
+            computational=kwargs,
+            dataloader=kwargs,
+            data=kwargs,
+            split=kwargs,
+            validation=kwargs,
+            transforms=kwargs,
+        )
+
+    def _update(self, config_dict: Dict[str, Any]) -> None:
+        """Updates the configs with a dict given by the user."""
+        self.data.__dict__.update(config_dict)
+        self.split.__dict__.update(config_dict)
+        self.validation.__dict__.update(config_dict)
+        self.maps_manager.__dict__.update(config_dict)
+        self.split.__dict__.update(config_dict)
+        self.computational.__dict__.update(config_dict)
+        self.dataloader.__dict__.update(config_dict)
+        self.transforms.__dict__.update(config_dict)
+
+    def adapt_with_maps_manager_info(self, maps_manager: MapsManager):
+        self.maps_manager.check_output_saving_nifti(maps_manager.network_task)
+        self.data.diagnoses = (
+            maps_manager.diagnoses
+            if self.data.diagnoses is None or len(self.data.diagnoses) == 0
+            else self.data.diagnoses
+        )
+
+        self.dataloader.batch_size = (
+            maps_manager.batch_size
+            if not self.dataloader.batch_size
+            else self.dataloader.batch_size
+        )
+        self.dataloader.n_proc = (
+            maps_manager.n_proc
+            if not self.dataloader.n_proc
+            else self.dataloader.n_proc
+        )
+
+        self.split.adapt_cross_val_with_maps_manager_info(maps_manager)
+        self.maps_manager.check_output_saving_tensor(maps_manager.network_task)
+
+        self.transforms = TransformsConfig(
+            normalize=maps_manager.normalize,
+            data_augmentation=maps_manager.data_augmentation,
+            size_reduction=maps_manager.size_reduction,
+            size_reduction_factor=maps_manager.size_reduction_factor,
+        )
+
+        if self.split.split is None and self.split.n_splits == 0:
+            from clinicadl.splitter.split_utils import find_splits
+
+            self.split.split = find_splits(self.maps_manager.maps_dir)
diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predictor/predictor.py
similarity index 54%
rename from clinicadl/predict/predict_manager.py
rename to clinicadl/predictor/predictor.py
index 55515dc8e..30fbbe5b8 100644
--- a/clinicadl/predict/predict_manager.py
+++ b/clinicadl/predictor/predictor.py
@@ -2,12 +2,13 @@
 import shutil
 from logging import getLogger
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 import pandas as pd
 import torch
 import torch.distributed as dist
 from torch.amp import autocast
+from torch.nn.modules.loss import _Loss
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
@@ -16,314 +17,123 @@
 )
 from clinicadl.interpret.config import InterpretConfig
 from clinicadl.maps_manager.maps_manager import MapsManager
+from clinicadl.metrics.metric_module import MetricModule
 from clinicadl.metrics.utils import (
     check_selection_metric,
     find_selection_metrics,
 )
-from clinicadl.predict.config import PredictConfig
-from clinicadl.trainer.tasks_utils import generate_label_code, get_criterion
+from clinicadl.network.network import Network
+from clinicadl.predictor.config import PredictConfig
+from clinicadl.trainer.tasks_utils import (
+    columns,
+    compute_metrics,
+    generate_label_code,
+    generate_test_row,
+    get_criterion,
+)
 from clinicadl.transforms.config import TransformsConfig
 from clinicadl.utils.computational.ddp import DDP, cluster
+from clinicadl.utils.enum import Task
 from clinicadl.utils.exceptions import (
     ClinicaDLArgumentError,
     ClinicaDLDataLeakageError,
     MAPSError,
 )
-from clinicadl.validator.validator import Validator
 
 logger = getLogger("clinicadl.predict_manager")
 level_list: List[str] = ["warning", "info", "debug"]
 
 
-class PredictManager:
+class Predictor:
     def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None:
-        self.maps_manager = MapsManager(_config.maps_dir)
         self._config = _config
-        self.validator = Validator()
+
+        from clinicadl.splitter.config import SplitterConfig
+        from clinicadl.splitter.splitter import Splitter
+
+        self.maps_manager = MapsManager(_config.maps_manager.maps_dir)
+        self._config.adapt_with_maps_manager_info(self.maps_manager)
+        tmp = self._config.data.model_dump(
+            exclude=set(["preprocessing_dict", "mode", "caps_dict"])
+        )
+        tmp.update(self._config.split.model_dump())
+        tmp.update(self._config.validation.model_dump())
+        self.splitter = Splitter(SplitterConfig(**tmp))
 
     def predict(
         self,
         label_code: Union[str, dict[str, int]] = "default",
     ):
-        """Performs the prediction task on a subset of caps_directory defined in a TSV file.
-        Parameters
-        ----------
-        data_group : str
-            name of the data group tested.
-        caps_directory : Path (optional, default=None)
-            path to the CAPS folder. For more information please refer to
-            [clinica documentation](https://aramislab.paris.inria.fr/clinica/docs/public/latest/CAPS/Introduction/).
-            Default will load the value of an existing data group
-        tsv_path : Path (optional, default=None)
-            path to a TSV file containing the list of participants and sessions to test.
-            Default will load the DataFrame of an existing data group
-        split_list : List[int] (optional, default=None)
-            list of splits to test. Default perform prediction on all splits available.
-        selection_metrics : List[str] (optional, default=None)
-            list of selection metrics to test.
-            Default performs the prediction on all selection metrics available.
-        multi_cohort : bool (optional, default=False)
-            If True considers that tsv_path is the path to a multi-cohort TSV.
-        diagnoses : List[str] (optional, default=())
-            List of diagnoses to load if tsv_path is a split_directory.
-            Default uses the same as in training step.
-        use_labels : bool (optional, default=True)
-            If True, the labels must exist in test meta-data and metrics are computed.
-        batch_size : int (optional, default=None)
-            If given, sets the value of batch_size, else use the same as in training step.
-        n_proc : int (optional, default=None)
-            If given, sets the value of num_workers, else use the same as in training step.
-        gpu : bool (optional, default=None)
-            If given, a new value for the device of the model will be computed.
-        amp : bool (optional, default=False)
-            If enabled, uses Automatic Mixed Precision (requires GPU usage).
-        overwrite : bool (optional, default=False)
-            If True erase the occurrences of data_group.
-        label : str (optional, default=None)
-            Target label used for training (if network_task in [`regression`, `classification`]).
-        label_code : Optional[Dict[str, int]] (optional, default="default")
-            dictionary linking the target values to a node number.
-        save_tensor : bool (optional, default=False)
-            If true, save the tensor predicted for reconstruction task
-        save_nifti : bool (optional, default=False)
-            If true, save the nifti associated to the prediction for reconstruction task.
-        save_latent_tensor : bool (optional, default=False)
-            If true, save the tensor from the latent space for reconstruction task.
-        skip_leak_check : bool (optional, default=False)
-            If true, skip the leak check (not recommended).
-        Examples
-        --------
-        >>> _input_
-        _output_
-        """
-
-        assert isinstance(self._config, PredictConfig)
-
-        self._config.check_output_saving_nifti(self.maps_manager.network_task)
-        self._config.diagnoses = (
-            self.maps_manager.diagnoses
-            if self._config.diagnoses is None or len(self._config.diagnoses) == 0
-            else self._config.diagnoses
-        )
-
-        self._config.batch_size = (
-            self.maps_manager.batch_size
-            if not self._config.batch_size
-            else self._config.batch_size
-        )
-        self._config.n_proc = (
-            self.maps_manager.n_proc if not self._config.n_proc else self._config.n_proc
-        )
+        """Performs the prediction task on a subset of caps_directory defined in a TSV file."""
 
-        self._config.adapt_cross_val_with_maps_manager_info(self.maps_manager)
-        self._config.check_output_saving_tensor(self.maps_manager.network_task)
-
-        transforms = TransformsConfig(
-            normalize=self.maps_manager.normalize,
-            data_augmentation=self.maps_manager.data_augmentation,
-            size_reduction=self.maps_manager.size_reduction,
-            size_reduction_factor=self.maps_manager.size_reduction_factor,
-        )
-        group_df = self._config.create_groupe_df()
+        group_df = self._config.data.create_groupe_df()
         self._check_data_group(group_df)
         criterion = get_criterion(
             self.maps_manager.network_task, self.maps_manager.loss
         )
-        self._check_data_group(df=group_df)
-        assert self._config.split  # don't know if needed ? try to raise an exception ?
-        # assert self._config.label
-
-        for split in self._config.split:
+        for split in self.splitter.split_iterator():
             logger.info(f"Prediction of split {split}")
             group_df, group_parameters = self.get_group_info(
-                self._config.data_group, split
+                self._config.maps_manager.data_group, split
             )
             # Find label code if not given
-            if self._config.is_given_label_code(self.maps_manager.label, label_code):
+            if self._config.data.is_given_label_code(
+                self.maps_manager.label, label_code
+            ):
                 generate_label_code(
-                    self.maps_manager.network_task, group_df, self._config.label
+                    self.maps_manager.network_task, group_df, self._config.data.label
                 )
             # Erase previous TSV files on master process
-            if not self._config.selection_metrics:
+            if not self._config.validation.selection_metrics:
                 split_selection_metrics = find_selection_metrics(
                     self.maps_manager.maps_path,
                     split,
                 )
             else:
-                split_selection_metrics = self._config.selection_metrics
+                split_selection_metrics = self._config.validation.selection_metrics
             for selection in split_selection_metrics:
                 tsv_dir = (
                     self.maps_manager.maps_path
                     / f"split-{split}"
                    / f"best-{selection}"
-                    / self._config.data_group
+                    / self._config.maps_manager.data_group
                 )
-                tsv_pattern = f"{self._config.data_group}*.tsv"
+                tsv_pattern = f"{self._config.maps_manager.data_group}*.tsv"
                 for tsv_file in tsv_dir.glob(tsv_pattern):
                     tsv_file.unlink()
-            self._config.check_label(self.maps_manager.label)
+
+            self._config.data.check_label(self.maps_manager.label)
             if self.maps_manager.multi_network:
-                self._predict_multi(
-                    group_parameters,
-                    group_df,
-                    transforms,
-                    label_code,
-                    criterion,
-                    split,
-                    split_selection_metrics,
-                )
+                for network in range(self.maps_manager.num_networks):
+                    self._predict_single(
+                        group_parameters,
+                        group_df,
+                        self._config.transforms,
+                        label_code,
+                        criterion,
+                        split,
+                        split_selection_metrics,
+                        network,
+                    )
             else:
                 self._predict_single(
                     group_parameters,
                     group_df,
-                    transforms,
+                    self._config.transforms,
                     label_code,
                     criterion,
                     split,
                     split_selection_metrics,
                 )
             if cluster.master:
-                self.validator._ensemble_prediction(
-                    self.maps_manager,
-                    self._config.data_group,
-                    split,
-                    self._config.selection_metrics,
-                    self._config.use_labels,
-                    self._config.skip_leak_check,
-                )
-
-    def _predict_multi(
-        self,
-        group_parameters,
-        group_df,
-        transforms,
-        label_code,
-        criterion,
-        split,
-        split_selection_metrics,
-    ):
-        """_summary_
-        Parameters
-        ----------
-        group_parameters : _type_
-            _description_
-        group_df : _type_
-            _description_
-        all_transforms : _type_
-            _description_
-        use_labels : _type_
-            _description_
-        label : _type_
-            _description_
-        label_code : _type_
-            _description_
-        batch_size : _type_
-            _description_
-        n_proc : _type_
-            _description_
-        criterion : _type_
-            _description_
-        data_group : _type_
-            _description_
-        split : _type_
-            _description_
-        split_selection_metrics : _type_
-            _description_
-        gpu : _type_
-            _description_
-        amp : _type_
-            _description_
-        save_tensor : _type_
-            _description_
-        save_latent_tensor : _type_
-            _description_
-        save_nifti : _type_
-            _description_
-        selection_metrics : _type_
-            _description_
-        Examples
-        --------
-        >>> _input_
-        _output_
-        Notes
-        -----
-        _notes_
-        See Also
-        --------
-        - _related_
-        """
-        assert isinstance(self._config, PredictConfig)
-        # assert self._config.label
-
-        for network in range(self.maps_manager.num_networks):
-            data_test = return_dataset(
-                group_parameters["caps_directory"],
-                group_df,
-                self.maps_manager.preprocessing_dict,
-                transforms_config=transforms,
-                multi_cohort=group_parameters["multi_cohort"],
-                label_presence=self._config.use_labels,
-                label=self._config.label,
-                label_code=(
-                    self.maps_manager.label_code
-                    if label_code == "default"
-                    else label_code
-                ),
-                cnn_index=network,
-            )
-            test_loader = DataLoader(
-                data_test,
-                batch_size=(
-                    self._config.batch_size
-                    if self._config.batch_size is not None
-                    else self.maps_manager.batch_size
-                ),
-                shuffle=False,
-                sampler=DistributedSampler(
-                    data_test,
-                    num_replicas=cluster.world_size,
-                    rank=cluster.rank,
-                    shuffle=False,
-                ),
-                num_workers=self._config.n_proc
-                if self._config.n_proc is not None
-                else self.maps_manager.n_proc,
-            )
-            self.validator._test_loader(
-                maps_manager=self.maps_manager,
-                dataloader=test_loader,
-                criterion=criterion,
-                data_group=self._config.data_group,
-                split=split,
-                selection_metrics=split_selection_metrics,
-                use_labels=self._config.use_labels,
-                gpu=self._config.gpu,
-                amp=self._config.amp,
-                network=network,
-            )
-            if self._config.save_tensor:
-                logger.debug("Saving tensors")
-                self.validator._compute_output_tensors(
+                self._ensemble_prediction(
                     self.maps_manager,
-                    data_test,
-                    self._config.data_group,
+                    self._config.maps_manager.data_group,
                     split,
-                    self._config.selection_metrics,
-                    gpu=self._config.gpu,
-                    network=network,
-                )
-            if self._config.save_nifti:
-                self._compute_output_nifti(
-                    data_test,
-                    split,
-                    network=network,
-                )
-            if self._config.save_latent_tensor:
-                self._compute_latent_tensors(
-                    dataset=data_test,
-                    split=split,
-                    network=network,
+                    self._config.validation.selection_metrics,
+                    self._config.data.use_labels,
+                    self._config.validation.skip_leak_check,
                 )
 
     def _predict_single(
@@ -335,78 +145,31 @@
         self,
         group_parameters,
         group_df,
         transforms,
         label_code,
         criterion,
         split,
         split_selection_metrics,
+        network: Optional[int] = None,
     ):
-        """_summary_
-        Parameters
-        ----------
-        group_parameters : _type_
-            _description_
-        group_df : _type_
-            _description_
-        all_transforms : _type_
-            _description_
-        use_labels : _type_
-            _description_
-        label : _type_
-            _description_
-        label_code : _type_
-            _description_
-        batch_size : _type_
-            _description_
-        n_proc : _type_
-            _description_
-        criterion : _type_
-            _description_
-        data_group : _type_
-            _description_
-        split : _type_
-            _description_
-        split_selection_metrics : _type_
-            _description_
-        gpu : _type_
-            _description_
-        amp : _type_
-            _description_
-        save_tensor : _type_
-            _description_
-        save_latent_tensor : _type_
-            _description_
-        save_nifti : _type_
-            _description_
-        selection_metrics : _type_
-            _description_
-        Examples
-        --------
-        >>> _input_
-        _output_
-        Notes
-        -----
-        _notes_
-        See Also
-        --------
-        - _related_
-        """
+        """_summary_"""
         assert isinstance(self._config, PredictConfig)
-        # assert self._config.label
+        # assert self._config.data.label
         data_test = return_dataset(
             group_parameters["caps_directory"],
             group_df,
             self.maps_manager.preprocessing_dict,
-            transforms_config=transforms,
+            transforms_config=self._config.transforms,
             multi_cohort=group_parameters["multi_cohort"],
-            label_presence=self._config.use_labels,
-            label=self._config.label,
+            label_presence=self._config.data.use_labels,
+            label=self._config.data.label,
             label_code=(
                 self.maps_manager.label_code
                 if label_code == "default"
                 else label_code
            ),
+            cnn_index=network,
         )
         test_loader = DataLoader(
             data_test,
             batch_size=(
-                self._config.batch_size
-                if self._config.batch_size is not None
+                self._config.dataloader.batch_size
+                if self._config.dataloader.batch_size is not None
                 else self.maps_manager.batch_size
             ),
             shuffle=False,
@@ -416,40 +179,44 @@ def _predict_single(
                 rank=cluster.rank,
                 shuffle=False,
             ),
-            num_workers=self._config.n_proc
-            if self._config.n_proc is not None
+            num_workers=self._config.dataloader.n_proc
+            if self._config.dataloader.n_proc is not None
             else self.maps_manager.n_proc,
         )
-        self.validator._test_loader(
-            self.maps_manager,
-            test_loader,
-            criterion,
-            self._config.data_group,
-            split,
-            split_selection_metrics,
-            use_labels=self._config.use_labels,
-            gpu=self._config.gpu,
-            amp=self._config.amp,
+        self._test_loader(
+            maps_manager=self.maps_manager,
+            dataloader=test_loader,
+            criterion=criterion,
+            data_group=self._config.maps_manager.data_group,
+            split=split,
+            selection_metrics=split_selection_metrics,
+            use_labels=self._config.data.use_labels,
+            gpu=self._config.computational.gpu,
+            amp=self._config.computational.amp,
+            network=network,
         )
-        if self._config.save_tensor:
+        if self._config.maps_manager.save_tensor:
             logger.debug("Saving tensors")
-            self.validator._compute_output_tensors(
-                self.maps_manager,
-                data_test,
-                self._config.data_group,
-                split,
-                self._config.selection_metrics,
-                gpu=self._config.gpu,
+            self._compute_output_tensors(
+                maps_manager=self.maps_manager,
+                dataset=data_test,
+                data_group=self._config.maps_manager.data_group,
+                split=split,
+                selection_metrics=self._config.validation.selection_metrics,
+                gpu=self._config.computational.gpu,
+                network=network,
             )
-        if self._config.save_nifti:
+        if self._config.maps_manager.save_nifti:
             self._compute_output_nifti(
-                data_test,
-                split,
+                dataset=data_test,
+                split=split,
+                network=network,
             )
-        if self._config.save_latent_tensor:
+        if self._config.maps_manager.save_latent_tensor:
             self._compute_latent_tensors(
                 dataset=data_test,
                 split=split,
+                network=network,
             )
 
     def _compute_latent_tensors(
@@ -478,13 +245,13 @@ def _compute_latent_tensors(
         network : _type_ (optional, default=None)
             Index of the network tested (only used in multi-network setting).
         """
-        for selection_metric in self._config.selection_metrics:
+        for selection_metric in self._config.validation.selection_metrics:
             # load the best trained model during the training
             model, _ = self.maps_manager._init_model(
                 transfer_path=self.maps_manager.maps_path,
                 split=split,
                 transfer_selection=selection_metric,
-                gpu=self._config.gpu,
+                gpu=self._config.computational.gpu,
                 network=network,
                 nb_unfrozen_layer=self.maps_manager.nb_unfrozen_layer,
             )
@@ -498,7 +265,7 @@ def _compute_latent_tensors(
                 self.maps_manager.maps_path
                 / f"split-{split}"
                 / f"best-{selection_metric}"
-                / self._config.data_group
+                / self._config.maps_manager.data_group
                 / "latent_tensors"
             )
             if cluster.master:
@@ -555,13 +322,13 @@ def _compute_output_nifti(
         import nibabel as nib
         from numpy import eye
 
-        for selection_metric in self._config.selection_metrics:
+        for selection_metric in self._config.validation.selection_metrics:
             # load the best trained model during the training
             model, _ = self.maps_manager._init_model(
                 transfer_path=self.maps_manager.maps_path,
                 split=split,
                 transfer_selection=selection_metric,
-                gpu=self._config.gpu,
+                gpu=self._config.computational.gpu,
                 network=network,
                 nb_unfrozen_layer=self.maps_manager.nb_unfrozen_layer,
             )
@@ -575,7 +342,7 @@ def _compute_output_nifti(
                 self.maps_manager.maps_path
                 / f"split-{split}"
                 / f"best-{selection_metric}"
-                / self._config.data_group
+                / self._config.maps_manager.data_group
                 / "nifti_images"
             )
             if cluster.master:
@@ -608,77 +375,10 @@ def _compute_output_nifti(
 
     def interpret(self):
         """Performs the interpretation task on a subset of caps_directory defined in a TSV file.
        The mean interpretation is always saved, to save the individual interpretations set save_individual to True.
-        Parameters
-        ----------
-        data_group : str
-            Name of the data group interpreted.
-        name : str
-            Name of the interpretation procedure.
-        method : str
-            Method used for extraction (ex: gradients, grad-cam...).
-        caps_directory : Path (optional, default=None)
-            Path to the CAPS folder. For more information please refer to
-            [clinica documentation](https://aramislab.paris.inria.fr/clinica/docs/public/latest/CAPS/Introduction/).
-            Default will load the value of an existing data group.
-        tsv_path : Path (optional, default=None)
-            Path to a TSV file containing the list of participants and sessions to test.
-            Default will load the DataFrame of an existing data group.
-        split_list : list[int] (optional, default=None)
-            List of splits to interpret. Default perform interpretation on all splits available.
-        selection_metrics : list[str] (optional, default=None)
-            List of selection metrics to interpret.
-            Default performs the interpretation on all selection metrics available.
-        multi_cohort : bool (optional, default=False)
-            If True considers that tsv_path is the path to a multi-cohort TSV.
-        diagnoses : list[str] (optional, default=())
-            List of diagnoses to load if tsv_path is a split_directory.
-            Default uses the same as in training step.
-        target_node : int (optional, default=0)
-            Node from which the interpretation is computed.
-        save_individual : bool (optional, default=False)
-            If True saves the individual map of each participant / session couple.
-        batch_size : int (optional, default=None)
-            If given, sets the value of batch_size, else use the same as in training step.
-        n_proc : int (optional, default=None)
-            If given, sets the value of num_workers, else use the same as in training step.
-        gpu : bool (optional, default=None)
-            If given, a new value for the device of the model will be computed.
-        amp : bool (optional, default=False)
-            If enabled, uses Automatic Mixed Precision (requires GPU usage).
-        overwrite : bool (optional, default=False)
-            If True erase the occurrences of data_group.
-        overwrite_name : bool (optional, default=False)
-            If True erase the occurrences of name.
-        level : int (optional, default=None)
-            Layer number in the convolutional part after which the feature map is chosen.
-        save_nifti : bool (optional, default=False)
-            If True, save the interpretation map in nifti format.
-        Raises
-        ------
-        NotImplementedError
-            If the method is not implemented
-        NotImplementedError
-            If the interpretaion of multi network is asked
-        MAPSError
-            If the interpretation has already been determined.
""" assert isinstance(self._config, InterpretConfig) - self._config.diagnoses = ( - self.maps_manager.diagnoses - if self._config.diagnoses is None or len(self._config.diagnoses) == 0 - else self._config.diagnoses - ) - self._config.batch_size = ( - self.maps_manager.batch_size - if not self._config.batch_size - else self._config.batch_size - ) - self._config.n_proc = ( - self.maps_manager.n_proc if not self._config.n_proc else self._config.n_proc - ) - - self._config.adapt_cross_val_with_maps_manager_info(self.maps_manager) + self._config.adapt_with_maps_manager_info(self.maps_manager) if self.maps_manager.multi_network: raise NotImplementedError( @@ -690,14 +390,13 @@ def interpret(self): size_reduction=self.maps_manager.size_reduction, size_reduction_factor=self.maps_manager.size_reduction_factor, ) - group_df = self._config.create_groupe_df() + group_df = self._config.data.create_groupe_df() self._check_data_group(group_df) - assert self._config.split - for split in self._config.split: + for split in self.splitter.split_iterator(): logger.info(f"Interpretation of split {split}") df_group, parameters_group = self.get_group_info( - self._config.data_group, split + self._config.maps_manager.data_group, split ) data_test = return_dataset( parameters_group["caps_directory"], @@ -711,30 +410,30 @@ def interpret(self): ) test_loader = DataLoader( data_test, - batch_size=self._config.batch_size, + batch_size=self._config.dataloader.batch_size, shuffle=False, - num_workers=self._config.n_proc, + num_workers=self._config.dataloader.n_proc, ) - if not self._config.selection_metrics: - self._config.selection_metrics = find_selection_metrics( + if not self._config.validation.selection_metrics: + self._config.validation.selection_metrics = find_selection_metrics( self.maps_manager.maps_path, split, ) - for selection_metric in self._config.selection_metrics: + for selection_metric in self._config.validation.selection_metrics: logger.info(f"Interpretation of metric {selection_metric}") results_path = ( self.maps_manager.maps_path / f"split-{split}" / f"best-{selection_metric}" - / self._config.data_group - / f"interpret-{self._config.name}" + / self._config.maps_manager.data_group + / f"interpret-{self._config.interpret.name}" ) if (results_path).is_dir(): - if self._config.overwrite_name: + if self._config.interpret.overwrite_name: shutil.rmtree(results_path) else: raise MAPSError( - f"Interpretation name {self._config.name} is already written. " + f"Interpretation name {self._config.interpret.name} is already written. " f"Please choose another name or set overwrite_name to True." 
                        )
                 results_path.mkdir(parents=True)
@@ -742,28 +441,28 @@
                 transfer_path=self.maps_manager.maps_path,
                 split=split,
                 transfer_selection=selection_metric,
-                gpu=self._config.gpu,
+                gpu=self._config.computational.gpu,
             )
-            interpreter = self._config.get_method()(model)
+            interpreter = self._config.interpret.get_method()(model)
             cum_maps = [0] * data_test.elem_per_image
             for data in test_loader:
                 images = data["image"].to(model.device)
                 map_pt = interpreter.generate_gradients(
                     images,
-                    self._config.target_node,
-                    level=self._config.level,
-                    amp=self._config.amp,
+                    self._config.interpret.target_node,
+                    level=self._config.interpret.level,
+                    amp=self._config.computational.amp,
                 )
                 for i in range(len(data["participant_id"])):
                     mode_id = data[f"{self.maps_manager.mode}_id"][i]
                     cum_maps[mode_id] += map_pt[i]
-                    if self._config.save_individual:
+                    if self._config.interpret.save_individual:
                         single_path = (
                             results_path
                             / f"{data['participant_id'][i]}_{data['session_id'][i]}_{self.maps_manager.mode}-{data[f'{self.maps_manager.mode}_id'][i]}_map.pt"
                        )
                         torch.save(map_pt[i], single_path)
-                        if self._config.save_nifti:
+                        if self._config.maps_manager.save_nifti:
                             import nibabel as nib
                             from numpy import eye
@@ -781,7 +480,7 @@ def interpret(self):
                     mode_map,
                     results_path / f"mean_{self.maps_manager.mode}-{i}_map.pt",
                 )
-                if self._config.save_nifti:
+                if self._config.maps_manager.save_nifti:
                     import nibabel as nib
                     from numpy import eye
@@ -801,22 +500,6 @@ def _check_data_group(
 
         Parameters
         ----------
-        data_group : str
-            name of the data group
-        caps_directory : str (optional, default=None)
-            input CAPS directory
-        df : pd.DataFrame (optional, default=None)
-            Table of participant_id / session_id of the data group
-        multi_cohort : bool (optional, default=False)
-            indicates if the input data comes from several CAPS
-        overwrite : bool (optional, default=False)
-            If True former definition of data group is erased
-        label : str (optional, default=None)
-            label name if applicable
-        split_list : list[int] (optional, default=None)
-            _description_
-        skip_leak_check : bool (optional, default=False)
-            _description_
 
         Raises
         ------
@@ -828,17 +511,21 @@ def _check_data_group(
             when caps_directory or df are not given and data group does not exist
         """
-        group_dir = self.maps_manager.maps_path / "groups" / self._config.data_group
+        group_dir = (
+            self.maps_manager.maps_path
+            / "groups"
+            / self._config.maps_manager.data_group
+        )
         logger.debug(f"Group path {group_dir}")
         if group_dir.is_dir():  # Data group already exists
-            if self._config.overwrite:
-                if self._config.data_group in ["train", "validation"]:
+            if self._config.maps_manager.overwrite:
+                if self._config.maps_manager.data_group in ["train", "validation"]:
                     raise MAPSError("Cannot overwrite train or validation data group.")
                 else:
-                    # if not split_list:
-                    #     split_list = self.maps_manager.find_splits()
+                    if not self._config.split.split:
+                        self._config.split.split = self.maps_manager.find_splits()
                     assert self._config.split
-                    for split in self._config.split:
+                    for split in self._config.split.split:
                         selection_metrics = find_selection_metrics(
                             self.maps_manager.maps_path,
                             split,
@@ -848,40 +535,40 @@ def _check_data_group(
                             self.maps_manager.maps_path
                             / f"split-{split}"
                            / f"best-{selection}"
-                            / self._config.data_group
+                            / self._config.maps_manager.data_group
                         )
                         if results_path.is_dir():
                             shutil.rmtree(results_path)
         elif df is not None or (
-            self._config.caps_directory is not None
-            and self._config.caps_directory != Path("")
+            self._config.data.caps_directory is not None
+            and self._config.data.caps_directory != Path("")
         ):
             raise ClinicaDLArgumentError(
-                f"Data group {self._config.data_group} is already defined. "
+                f"Data group {self._config.maps_manager.data_group} is already defined. "
                 f"Please do not give any caps_directory, tsv_path or multi_cohort to use it. "
-                f"To erase {self._config.data_group} please set overwrite to True."
+                f"To erase {self._config.maps_manager.data_group} please set overwrite to True."
             )
         elif not group_dir.is_dir() and (
-            self._config.caps_directory is None or df is None
+            self._config.data.caps_directory is None or df is None
         ):  # Data group does not exist yet / was overwritten + missing data
             raise ClinicaDLArgumentError(
-                f"The data group {self._config.data_group} does not already exist. "
+                f"The data group {self._config.maps_manager.data_group} does not already exist. "
                 f"Please specify a caps_directory and a tsv_path to create this data group."
            )
         elif (
             not group_dir.is_dir()
         ):  # Data group does not exist yet / was overwritten + all data is provided
-            if self._config.skip_leak_check:
+            if self._config.validation.skip_leak_check:
                 logger.info("Skipping data leakage check")
             else:
-                self._check_leakage(self._config.data_group, df)
+                self._check_leakage(self._config.maps_manager.data_group, df)
             self._write_data_group(
-                self._config.data_group,
+                self._config.maps_manager.data_group,
                 df,
-                self._config.caps_directory,
-                self._config.multi_cohort,
-                label=self._config.label,
+                self._config.data.caps_directory,
+                self._config.data.multi_cohort,
+                label=self._config.data.label,
             )
 
     def get_group_info(
@@ -997,8 +684,8 @@ def _write_data_group(
             group_path.mkdir(parents=True)
 
         columns = ["participant_id", "session_id", "cohort"]
-        if self._config.label in df.columns.values:
-            columns += [self._config.label]
+        if self._config.data.label in df.columns.values:
+            columns += [self._config.data.label]
         if label is not None and label in df.columns.values:
             columns += [label]
 
@@ -1088,3 +775,379 @@ def get_interpretation(
             weights_only=True,
         )
         return map_pt
+
+    def test(
+        self,
+        mode: str,
+        metrics_module: MetricModule,
+        n_classes: int,
+        network_task,
+        model: Network,
+        dataloader: DataLoader,
+        criterion: _Loss,
+        use_labels: bool = True,
+        amp: bool = False,
+        report_ci=False,
+    ) -> Tuple[pd.DataFrame, Dict[str, float]]:
+        """
+        Computes the predictions and evaluation metrics.
+
+        Parameters
+        ----------
+        model: Network
+            The model trained.
+        dataloader: DataLoader
+            Wrapper of a CapsDataset.
+        criterion: _Loss
+            Function to calculate the loss.
+        use_labels: bool
+            If True the true_label will be written in output DataFrame
+            and metrics dict will be created.
+        amp: bool
+            If True, enables Pytorch's automatic mixed precision.
+
+        Returns
+        -------
+        the results and metrics on the image level.
+ """ + model.eval() + dataloader.dataset.eval() + + results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) + total_loss = {} + with torch.no_grad(): + for i, data in enumerate(dataloader): + # initialize the loss list to save the loss components + with autocast("cuda", enabled=amp): + outputs, loss_dict = model(data, criterion, use_labels=use_labels) + + if i == 0: + for loss_component in loss_dict.keys(): + total_loss[loss_component] = 0 + for loss_component in total_loss.keys(): + total_loss[loss_component] += loss_dict[loss_component].float() + + # Generate detailed DataFrame + for idx in range(len(data["participant_id"])): + row = generate_test_row( + network_task, + mode, + metrics_module, + n_classes, + idx, + data, + outputs.float(), + ) + row_df = pd.DataFrame( + row, columns=columns(network_task, mode, n_classes) + ) + results_df = pd.concat([results_df, row_df]) + + del outputs, loss_dict + dataframes = [None] * dist.get_world_size() + dist.gather_object( + results_df, dataframes if dist.get_rank() == 0 else None, dst=0 + ) + if dist.get_rank() == 0: + results_df = pd.concat(dataframes) + del dataframes + results_df.reset_index(inplace=True, drop=True) + + if not use_labels: + metrics_dict = None + else: + metrics_dict = compute_metrics( + network_task, results_df, metrics_module, report_ci=report_ci + ) + for loss_component in total_loss.keys(): + dist.reduce(total_loss[loss_component], dst=0) + loss_value = total_loss[loss_component].item() / cluster.world_size + + if report_ci: + metrics_dict["Metric_names"].append(loss_component) + metrics_dict["Metric_values"].append(loss_value) + metrics_dict["Lower_CI"].append("N/A") + metrics_dict["Upper_CI"].append("N/A") + metrics_dict["SE"].append("N/A") + + else: + metrics_dict[loss_component] = loss_value + + torch.cuda.empty_cache() + + return results_df, metrics_dict + + def test_da( + self, + mode: str, + metrics_module: MetricModule, + n_classes: int, + network_task: Union[str, Task], + model: Network, + dataloader: DataLoader, + criterion: _Loss, + alpha: float = 0, + use_labels: bool = True, + target: bool = True, + report_ci=False, + ) -> Tuple[pd.DataFrame, Dict[str, float]]: + """ + Computes the predictions and evaluation metrics. + + Args: + model: the model trained. + dataloader: wrapper of a CapsDataset. + criterion: function to calculate the loss. + use_labels: If True the true_label will be written in output DataFrame + and metrics dict will be created. + Returns: + the results and metrics on the image level. 
+ """ + model.eval() + dataloader.dataset.eval() + results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) + total_loss = 0 + with torch.no_grad(): + for i, data in enumerate(dataloader): + outputs, loss_dict = model.compute_outputs_and_loss_test( + data, criterion, alpha, target + ) + total_loss += loss_dict["loss"].item() + + # Generate detailed DataFrame + for idx in range(len(data["participant_id"])): + row = generate_test_row( + network_task, + mode, + metrics_module, + n_classes, + idx, + data, + outputs, + ) + row_df = pd.DataFrame( + row, columns=columns(network_task, mode, n_classes) + ) + results_df = pd.concat([results_df, row_df]) + + del outputs, loss_dict + results_df.reset_index(inplace=True, drop=True) + + if not use_labels: + metrics_dict = None + else: + metrics_dict = compute_metrics( + network_task, results_df, metrics_module, report_ci=report_ci + ) + if report_ci: + metrics_dict["Metric_names"].append("loss") + metrics_dict["Metric_values"].append(total_loss) + metrics_dict["Lower_CI"].append("N/A") + metrics_dict["Upper_CI"].append("N/A") + metrics_dict["SE"].append("N/A") + + else: + metrics_dict["loss"] = total_loss + + torch.cuda.empty_cache() + + return results_df, metrics_dict + + def _test_loader( + self, + maps_manager: MapsManager, + dataloader, + criterion, + data_group: str, + split: int, + selection_metrics, + use_labels=True, + gpu=None, + amp=False, + network=None, + report_ci=True, + ): + """ + Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files. + + Args: + dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset. + criterion (torch.nn.modules.loss._Loss): optimization criterion used during training. + data_group (str): name of the data group used for the testing task. + split (int): Index of the split used to train the model tested. + selection_metrics (list[str]): List of metrics used to select the best models which are tested. + use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed. + gpu (bool): If given, a new value for the device of the model will be computed. + amp (bool): If enabled, uses Automatic Mixed Precision (requires GPU usage). + network (int): Index of the network tested (only used in multi-network setting). 
+ """ + for selection_metric in selection_metrics: + if cluster.master: + log_dir = ( + maps_manager.maps_path + / f"split-{split}" + / f"best-{selection_metric}" + / data_group + ) + maps_manager.write_description_log( + log_dir, + data_group, + dataloader.dataset.config.data.caps_dict, + dataloader.dataset.config.data.data_df, + ) + + # load the best trained model during the training + model, _ = maps_manager._init_model( + transfer_path=maps_manager.maps_path, + split=split, + transfer_selection=selection_metric, + gpu=gpu, + network=network, + ) + model = DDP( + model, + fsdp=maps_manager.fully_sharded_data_parallel, + amp=maps_manager.amp, + ) + + prediction_df, metrics = self.test( + mode=maps_manager.mode, + metrics_module=maps_manager.metrics_module, + n_classes=maps_manager.n_classes, + network_task=maps_manager.network_task, + model=model, + dataloader=dataloader, + criterion=criterion, + use_labels=use_labels, + amp=amp, + report_ci=report_ci, + ) + if use_labels: + if network is not None: + metrics[f"{maps_manager.mode}_id"] = network + + loss_to_log = ( + metrics["Metric_values"][-1] if report_ci else metrics["loss"] + ) + + logger.info( + f"{maps_manager.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}" + ) + + if cluster.master: + # Replace here + maps_manager._mode_level_to_tsv( + prediction_df, + metrics, + split, + selection_metric, + data_group=data_group, + ) + + @torch.no_grad() + def _compute_output_tensors( + self, + maps_manager: MapsManager, + dataset, + data_group, + split, + selection_metrics, + nb_images=None, + gpu=None, + network=None, + ): + """ + Compute the output tensors and saves them in the MAPS. + + Args: + dataset (clinicadl.caps_dataset.data.CapsDataset): wrapper of the data set. + data_group (str): name of the data group used for the task. + split (int): split number. + selection_metrics (list[str]): metrics used for model selection. + nb_images (int): number of full images to write. Default computes the outputs of the whole data set. + gpu (bool): If given, a new value for the device of the model will be computed. + network (int): Index of the network tested (only used in multi-network setting). 
+ """ + for selection_metric in selection_metrics: + # load the best trained model during the training + model, _ = maps_manager._init_model( + transfer_path=maps_manager.maps_path, + split=split, + transfer_selection=selection_metric, + gpu=gpu, + network=network, + nb_unfrozen_layer=maps_manager.nb_unfrozen_layer, + ) + model = DDP( + model, + fsdp=maps_manager.fully_sharded_data_parallel, + amp=maps_manager.amp, + ) + model.eval() + + tensor_path = ( + maps_manager.maps_path + / f"split-{split}" + / f"best-{selection_metric}" + / data_group + / "tensors" + ) + if cluster.master: + tensor_path.mkdir(parents=True, exist_ok=True) + dist.barrier() + + if nb_images is None: # Compute outputs for the whole data set + nb_modes = len(dataset) + else: + nb_modes = nb_images * dataset.elem_per_image + + for i in [ + *range(cluster.rank, nb_modes, cluster.world_size), + *range(int(nb_modes % cluster.world_size <= cluster.rank)), + ]: + data = dataset[i] + image = data["image"] + x = image.unsqueeze(0).to(model.device) + with autocast("cuda", enabled=maps_manager.std_amp): + output = model(x) + output = output.squeeze(0).cpu().float() + participant_id = data["participant_id"] + session_id = data["session_id"] + mode_id = data[f"{maps_manager.mode}_id"] + input_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_input.pt" + output_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_output.pt" + torch.save(image, tensor_path / input_filename) + torch.save(output, tensor_path / output_filename) + logger.debug(f"File saved at {[input_filename, output_filename]}") + + def _ensemble_prediction( + self, + maps_manager: MapsManager, + data_group, + split, + selection_metrics, + use_labels=True, + skip_leak_check=False, + ): + """Computes the results on the image-level.""" + + if not selection_metrics: + selection_metrics = find_selection_metrics(maps_manager.maps_path, split) + + for selection_metric in selection_metrics: + ##################### + # Soft voting + if maps_manager.num_networks > 1 and not skip_leak_check: + maps_manager._ensemble_to_tsv( + split, + selection=selection_metric, + data_group=data_group, + use_labels=use_labels, + ) + elif maps_manager.mode != "image" and not skip_leak_check: + maps_manager._mode_to_image_tsv( + split, + selection=selection_metric, + data_group=data_group, + use_labels=use_labels, + ) diff --git a/clinicadl/predict/utils.py b/clinicadl/predictor/utils.py similarity index 100% rename from clinicadl/predict/utils.py rename to clinicadl/predictor/utils.py diff --git a/clinicadl/splitter/validation.py b/clinicadl/predictor/validation.py similarity index 100% rename from clinicadl/splitter/validation.py rename to clinicadl/predictor/validation.py diff --git a/clinicadl/random_search/random_search_utils.py b/clinicadl/random_search/random_search_utils.py index ed164ea0c..f8f3bca9a 100644 --- a/clinicadl/random_search/random_search_utils.py +++ b/clinicadl/random_search/random_search_utils.py @@ -124,7 +124,6 @@ def random_sampling(rs_options: Dict[str, Any]) -> Dict[str, Any]: "mode": "fixed", "multi_cohort": "fixed", "multi_network": "choice", - "ssda_netork": "fixed", "n_fcblocks": "randint", "n_splits": "fixed", "n_proc": "fixed", diff --git a/clinicadl/resources/config/train_config.toml b/clinicadl/resources/config/train_config.toml index f4f2afe30..9e5f54657 100644 --- a/clinicadl/resources/config/train_config.toml +++ b/clinicadl/resources/config/train_config.toml @@ -4,7 +4,6 @@ [Model] architecture = "default" # ex : 
Conv5_FC3 multi_network = false -ssda_network = false [Architecture] # CNN diff --git a/clinicadl/splitter/config.py b/clinicadl/splitter/config.py index 53413fdda..59fdbaad8 100644 --- a/clinicadl/splitter/config.py +++ b/clinicadl/splitter/config.py @@ -7,8 +7,8 @@ from pydantic.types import NonNegativeInt from clinicadl.caps_dataset.data_config import DataConfig +from clinicadl.predictor.validation import ValidationConfig from clinicadl.splitter.split_utils import find_splits -from clinicadl.splitter.validation import ValidationConfig logger = getLogger("clinicadl.split_config") diff --git a/clinicadl/splitter/splitter.py b/clinicadl/splitter/splitter.py index 3bbdde461..d39b14a5b 100644 --- a/clinicadl/splitter/splitter.py +++ b/clinicadl/splitter/splitter.py @@ -1,4 +1,5 @@ import abc +import shutil from logging import getLogger from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -6,6 +7,8 @@ import pandas as pd from clinicadl.splitter.config import SplitterConfig +from clinicadl.utils import cluster +from clinicadl.utils.exceptions import MAPSError logger = getLogger("clinicadl.split_manager") @@ -14,7 +17,7 @@ class Splitter: def __init__( self, config: SplitterConfig, - split_list: Optional[List[int]] = None, + # split_list: Optional[List[int]] = None, ): """_summary_ @@ -29,19 +32,19 @@ def __init__( """ self.config = config - self.split_list = split_list + # self.config.split.split = split_list - self.caps_dict = self.config.data.caps_dict # TODO : check if useful ? + # self.caps_dict = self.config.data.caps_dict # TODO : check if useful ? def max_length(self) -> int: """Maximum number of splits""" return self.config.split.n_splits def __len__(self): - if not self.split_list: + if not self.config.split.split: return self.config.split.n_splits else: - return len(self.split_list) + return len(self.config.split.split) @property def allowed_splits_list(self): @@ -203,13 +206,32 @@ def _get_tsv_paths(self, cohort_path, *args) -> Tuple[Path, Path]: def split_iterator(self): """Returns an iterable to iterate on all splits wanted.""" - if not self.split_list: + + if not self.config.split.split: return range(self.config.split.n_splits) else: - return self.split_list + return self.config.split.split def _check_item(self, item): if item not in self.allowed_splits_list: raise IndexError( f"Split index {item} out of allowed splits {self.allowed_splits_list}." ) + + def check_split_list(self, maps_path, overwrite): + existing_splits = [] + for split in self.split_iterator(): + split_path = maps_path / f"split-{split}" + if split_path.is_dir(): + if overwrite: + if cluster.master: + shutil.rmtree(split_path) + else: + existing_splits.append(split) + + if len(existing_splits) > 0: + raise MAPSError( + f"Splits {existing_splits} already exist. Please " + f"specify a list of splits not intersecting the previous list, " + f"or use overwrite to erase previously trained splits." + ) diff --git a/clinicadl/maps_manager/tmp_config.py b/clinicadl/tmp_config.py similarity index 97% rename from clinicadl/maps_manager/tmp_config.py rename to clinicadl/tmp_config.py index a31af7edb..620db133e 100644 --- a/clinicadl/maps_manager/tmp_config.py +++ b/clinicadl/tmp_config.py @@ -58,6 +58,7 @@ class TmpConfig(BaseModel): arguments needed : caps_directory, maps_path, loss """ + # ??? 
output_size: Optional[int] = None n_classes: Optional[int] = None network_task: Optional[str] = None @@ -70,18 +71,21 @@ class TmpConfig(BaseModel): std_amp: Optional[bool] = None preprocessing_dict: Optional[dict] = None + # CALLBACKS emissions_calculator: bool = False track_exp: Optional[ExperimentTracking] = None + # COMPUTATIONAL amp: bool = False fully_sharded_data_parallel: bool = False gpu: bool = True + # SPLIT n_splits: NonNegativeInt = 0 split: Optional[Tuple[NonNegativeInt, ...]] = None tsv_path: Optional[Path] = None # not needed in predict ? - # DataConfig + # DATA caps_directory: Path baseline: bool = False diagnoses: Tuple[str, ...] = ("AD", "CN") @@ -94,55 +98,68 @@ class TmpConfig(BaseModel): data_tsv: Optional[Path] = None n_subjects: int = 300 + # DATALOADER batch_size: PositiveInt = 8 n_proc: PositiveInt = 2 sampler: Sampler = Sampler.RANDOM + # EARLY STOPPING patience: NonNegativeInt = 0 tolerance: NonNegativeFloat = 0.0 + patience_epochs: NonNegativeInt = 0 + # LEARNING RATE adaptive_learning_rate: bool = False + # MAPS MANAGER maps_path: Path data_group: Optional[str] = None overwrite: bool = False save_nifti: bool = False + # NETWORK architecture: str = "default" dropout: NonNegativeFloat = 0.0 loss: str multi_network: bool = False + # OPTIMIZATION accumulation_steps: PositiveInt = 1 epochs: PositiveInt = 20 profiler: bool = False + # OPTIMIZER learning_rate: PositiveFloat = 1e-4 optimizer: Optimizer = Optimizer.ADAM weight_decay: NonNegativeFloat = 1e-4 + # REPRODUCIBILITY compensation: Compensation = Compensation.MEMORY deterministic: bool = False save_all_models: bool = False seed: int = 0 config_file: Optional[Path] = None + # SSDA caps_target: Path = Path("") preprocessing_json_target: Path = Path("") ssda_network: bool = False tsv_target_lab: Path = Path("") tsv_target_unlab: Path = Path("") + # TRANSFER LEARNING nb_unfrozen_layer: NonNegativeInt = 0 transfer_path: Optional[Path] = None transfer_selection_metric: str = "loss" + # TRANSFORMS data_augmentation: Tuple[Transform, ...] = () train_transformations: Optional[Tuple[Transform, ...]] = None normalize: bool = True size_reduction: bool = False size_reduction_factor: SizeReductionFactor = SizeReductionFactor.TWO + # VALIDATION evaluation_steps: NonNegativeInt = 0 selection_metrics: Tuple[str, ...] 
= () valid_longitudinal: bool = False @@ -282,7 +299,7 @@ def adapt_cross_val_with_maps_manager_info( ): # maps_manager is of type MapsManager but need to be in a MapsConfig type in the future # TEMPORARY if not self.split: - self.split = find_splits(maps_manager.maps_path, maps_manager.split_name) + self.split = find_splits(maps_manager.maps_path) logger.debug(f"List of splits {self.split}") def create_groupe_df(self): diff --git a/clinicadl/trainer/config/classification.py b/clinicadl/trainer/config/classification.py index 5e71d032e..f09021559 100644 --- a/clinicadl/trainer/config/classification.py +++ b/clinicadl/trainer/config/classification.py @@ -5,7 +5,7 @@ from clinicadl.caps_dataset.data_config import DataConfig as BaseDataConfig from clinicadl.network.config import NetworkConfig as BaseNetworkConfig -from clinicadl.splitter.validation import ValidationConfig as BaseValidationConfig +from clinicadl.predictor.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import ClassificationLoss, ClassificationMetric, Task diff --git a/clinicadl/trainer/config/reconstruction.py b/clinicadl/trainer/config/reconstruction.py index bf39886d4..d4b90ee2d 100644 --- a/clinicadl/trainer/config/reconstruction.py +++ b/clinicadl/trainer/config/reconstruction.py @@ -4,7 +4,7 @@ from pydantic import PositiveFloat, PositiveInt, computed_field, field_validator from clinicadl.network.config import NetworkConfig as BaseNetworkConfig -from clinicadl.splitter.validation import ValidationConfig as BaseValidationConfig +from clinicadl.predictor.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import ( Normalization, diff --git a/clinicadl/trainer/config/regression.py b/clinicadl/trainer/config/regression.py index 37e690f01..f094d5552 100644 --- a/clinicadl/trainer/config/regression.py +++ b/clinicadl/trainer/config/regression.py @@ -5,7 +5,7 @@ from clinicadl.caps_dataset.data_config import DataConfig as BaseDataConfig from clinicadl.network.config import NetworkConfig as BaseNetworkConfig -from clinicadl.splitter.validation import ValidationConfig as BaseValidationConfig +from clinicadl.predictor.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import RegressionLoss, RegressionMetric, Task diff --git a/clinicadl/trainer/config/train.py b/clinicadl/trainer/config/train.py index c44febe6b..30a92c92a 100644 --- a/clinicadl/trainer/config/train.py +++ b/clinicadl/trainer/config/train.py @@ -14,13 +14,12 @@ from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig from clinicadl.config.config.lr_scheduler import LRschedulerConfig from clinicadl.config.config.reproducibility import ReproducibilityConfig -from clinicadl.config.config.ssda import SSDAConfig from clinicadl.maps_manager.config import MapsManagerConfig from clinicadl.network.config import NetworkConfig from clinicadl.optimizer.optimization import OptimizationConfig from clinicadl.optimizer.optimizer import OptimizerConfig +from clinicadl.predictor.validation import ValidationConfig from clinicadl.splitter.config import SplitConfig -from clinicadl.splitter.validation import ValidationConfig from clinicadl.trainer.transfer_learning import TransferLearningConfig from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.computational.computational import 
ComputationalConfig @@ -50,7 +49,6 @@ class TrainConfig(BaseModel, ABC): optimizer: OptimizerConfig reproducibility: ReproducibilityConfig split: SplitConfig - ssda: SSDAConfig transfer_learning: TransferLearningConfig transforms: TransformsConfig validation: ValidationConfig @@ -77,7 +75,6 @@ def __init__(self, **kwargs): optimizer=kwargs, reproducibility=kwargs, split=kwargs, - ssda=kwargs, transfer_learning=kwargs, transforms=kwargs, validation=kwargs, @@ -97,7 +94,6 @@ def _update(self, config_dict: Dict[str, Any]) -> None: self.optimizer.__dict__.update(config_dict) self.reproducibility.__dict__.update(config_dict) self.split.__dict__.update(config_dict) - self.ssda.__dict__.update(config_dict) self.transfer_learning.__dict__.update(config_dict) self.transforms.__dict__.update(config_dict) self.validation.__dict__.update(config_dict) diff --git a/clinicadl/trainer/tasks_utils.py b/clinicadl/trainer/tasks_utils.py index dc28d0acd..e3946790c 100644 --- a/clinicadl/trainer/tasks_utils.py +++ b/clinicadl/trainer/tasks_utils.py @@ -603,6 +603,7 @@ def generate_sampler( network_task: Union[str, Task], dataset: CapsDataset, sampler_option: str = "random", + label_code: Optional[dict] = None, n_bins: int = 5, dp_degree: Optional[int] = None, rank: Optional[int] = None, @@ -622,7 +623,7 @@ def calculate_weights_classification(df): labels = df[dataset.config.data.label].unique() - codes = {dataset.config.data.label_code[label] for label in labels} + codes = {label_code[label] for label in labels} count = np.zeros(len(codes)) for idx in df.index: diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 66ceb0dd1..775ecd2c6 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -1,6 +1,6 @@ from __future__ import annotations # noqa: I001 -import shutil + from contextlib import nullcontext from datetime import datetime from logging import getLogger @@ -33,7 +33,8 @@ patch_to_read_json, ) from clinicadl.trainer.tasks_utils import create_training_config -from clinicadl.validator.validator import Validator +from clinicadl.predictor.predictor import Predictor +from clinicadl.predictor.config import PredictConfig from clinicadl.splitter.splitter import Splitter from clinicadl.splitter.config import SplitterConfig from clinicadl.transforms.config import TransformsConfig @@ -67,7 +68,12 @@ def __init__( self.config = config self.maps_manager = self._init_maps_manager(config) - self.validator = Validator() + predict_config = PredictConfig(**config.get_dict()) + self.validator = Predictor(predict_config) + + # instantiate the splitter from the same training config + splitter_config = SplitterConfig(**self.config.get_dict()) + self.splitter = Splitter(splitter_config) self._check_args() def _init_maps_manager(self, config) -> MapsManager: @@ -86,7 +92,12 @@ def _init_maps_manager(self, config) -> MapsManager: ) # TODO : precise which parameters in config are useful @classmethod - def from_json(cls, config_file: str | Path, maps_path: str | Path) -> Trainer: + def from_json( + cls, + config_file: str | Path, + maps_path: str | Path, + split: Optional[list[int]] = None, + ) -> Trainer: """ Creates a Trainer from a json configuration file.
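For illustration, a minimal sketch of how the extended classmethod could be called once this patch is applied; the paths and split indices below are hypothetical, not taken from the repository:

from pathlib import Path

from clinicadl.trainer.trainer import Trainer

# Hypothetical usage of the new `split` parameter of Trainer.from_json:
trainer = Trainer.from_json(
    config_file=Path("maps/maps.json"),  # configuration saved by a previous run (placeholder path)
    maps_path=Path("maps"),              # output MAPS directory (placeholder path)
    split=[0, 2],                        # restrict training to splits 0 and 2; omit to train all splits
)
trainer.train()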
@@ -113,6 +124,7 @@ def from_json(cls, config_file: str | Path, maps_path: str | Path) -> Trainer: raise FileNotFoundError(f"No file found at {str(config_file)}.") config_dict = patch_to_read_json(read_json(config_file)) # TODO : remove patch config_dict["maps_dir"] = maps_path + config_dict["split"] = split if split else () config_object = create_training_config(config_dict["network_task"])( **config_dict ) @@ -147,7 +159,7 @@ def from_maps(cls, maps_path: str | Path) -> Trainer: ) return cls.from_json(maps_path / "maps.json", maps_path) - def resume(self, splits: List[int]) -> None: + def resume(self) -> None: """ Resume a prematurely stopped training. @@ -157,13 +169,13 @@ def resume(self, splits: List[int]) -> None: The splits that must be resumed. """ stopped_splits = set(find_stopped_splits(self.config.maps_manager.maps_dir)) - finished_splits = set(find_finished_splits(self.maps_manager.maps_path)) - # TODO : check these two lines. Why do we need a split_manager? + finished_splits = set(find_finished_splits(self.config.maps_manager.maps_dir)) + # TODO : check these two lines. Why do we need a self.splitter? splitter_config = SplitterConfig(**self.config.get_dict()) - split_manager = Splitter(splitter_config, split_list=splits) + self.splitter = Splitter(splitter_config) - split_iterator = split_manager.split_iterator() + split_iterator = self.splitter.split_iterator() ### absent_splits = set(split_iterator) - stopped_splits - finished_splits @@ -184,9 +196,20 @@ def resume(self, splits: List[int]) -> None: def _check_args(self): self.config.reproducibility.seed = get_seed(self.config.reproducibility.seed) - # if (len(self.config.data.label_code) == 0): + # if len(self.config.data.label_code) == 0: # self.config.data.label_code = self.maps_manager.label_code # TODO: deal with label_code and replace self.maps_manager.label_code + from clinicadl.trainer.tasks_utils import generate_label_code + + if ( + "label_code" not in self.config.data.model_dump() + or len(self.config.data.label_code) == 0 + or self.config.data.label_code is None + ): # Allows to set custom label code in TOML + train_df = self.splitter[0]["train"] + self.config.data.label_code = generate_label_code( + self.config.network_task, train_df, self.config.data.label + ) def train( self, @@ -211,53 +234,51 @@ def train( If splits specified in input already exist and overwrite is False. 
""" - self.check_split_list(split_list=split_list, overwrite=overwrite) - - if self.config.ssda.ssda_network: - self._train_ssda(split_list, resume=False) - - else: - splitter_config = SplitterConfig(**self.config.get_dict()) - split_manager = Splitter(splitter_config, split_list=split_list) - - for split in split_manager.split_iterator(): - logger.info(f"Training split {split}") - seed_everything( - self.config.reproducibility.seed, - self.config.reproducibility.deterministic, - self.config.reproducibility.compensation, - ) - - split_df_dict = split_manager[split] + # splitter_config = SplitterConfig(**self.config.get_dict()) + # self.splitter = Splitter(splitter_config) + # self.splitter.check_split_list(self.config.maps_manager.maps_dir, self.config.maps_manager.overwrite) + self.splitter.check_split_list( + self.config.maps_manager.maps_dir, + overwrite, # overwrite change so careful it is not the maps manager overwrite parameters here + ) + for split in self.splitter.split_iterator(): + logger.info(f"Training split {split}") + seed_everything( + self.config.reproducibility.seed, + self.config.reproducibility.deterministic, + self.config.reproducibility.compensation, + ) - if self.config.model.multi_network: - resume, first_network = self.init_first_network(False, split) - for network in range(first_network, self.maps_manager.num_networks): - self._train_single( - split, split_df_dict, network=network, resume=resume - ) - else: - self._train_single(split, split_df_dict, resume=False) + split_df_dict = self.splitter[split] - def check_split_list(self, split_list, overwrite): - existing_splits = [] - splitter_config = SplitterConfig(**self.config.get_dict()) - split_manager = Splitter(splitter_config, split_list=split_list) - for split in split_manager.split_iterator(): - split_path = self.maps_manager.maps_path / f"split-{split}" - if split_path.is_dir(): - if overwrite: - if cluster.master: - shutil.rmtree(split_path) - else: - existing_splits.append(split) - - if len(existing_splits) > 0: - raise MAPSError( - f"Splits {existing_splits} already exist. Please " - f"specify a list of splits not intersecting the previous list, " - f"or use overwrite to erase previously trained splits." - ) + if self.config.model.multi_network: + resume, first_network = self.init_first_network(False, split) + for network in range(first_network, self.maps_manager.num_networks): + self._train_single( + split, split_df_dict, network=network, resume=resume + ) + else: + self._train_single(split, split_df_dict, resume=False) + + # def check_split_list(self, split_list, overwrite): + # existing_splits = [] + # splitter_config = SplitterConfig(**self.config.get_dict()) + # self.splitter = Splitter(splitter_config) + # for split in self.splitter.split_iterator(): + # split_path = self.maps_manager.maps_path / f"split-{split}" + # if split_path.is_dir(): + # if overwrite: + # if cluster.master: + # shutil.rmtree(split_path) + # else: + # existing_splits.append(split) + + # if len(existing_splits) > 0: + # raise MAPSError( + # f"Splits {existing_splits} already exist. Please " + # f"specify a list of splits not intersecting the previous list, " + # f"or use overwrite to erase previously trained splits." 
+ # ) def _resume( self, @@ -279,8 +300,8 @@ def _resume( """ missing_splits = [] splitter_config = SplitterConfig(**self.config.get_dict()) - split_manager = Splitter(splitter_config, split_list=split_list) - for split in split_manager.split_iterator(): + self.splitter = Splitter(splitter_config) + for split in self.splitter.split_iterator(): if not (self.maps_manager.maps_path / f"split-{split}" / "tmp").is_dir(): missing_splits.append(split) @@ -290,26 +311,23 @@ def _resume( f"Please try train command on these splits and resume only others." ) - if self.config.ssda.ssda_network: - self._train_ssda(split_list, resume=True) - else: - for split in split_manager.split_iterator(): - logger.info(f"Training split {split}") - seed_everything( - self.config.reproducibility.seed, - self.config.reproducibility.deterministic, - self.config.reproducibility.compensation, - ) + for split in self.splitter.split_iterator(): + logger.info(f"Training split {split}") + seed_everything( + self.config.reproducibility.seed, + self.config.reproducibility.deterministic, + self.config.reproducibility.compensation, + ) - split_df_dict = split_manager[split] - if self.config.model.multi_network: - resume, first_network = self.init_first_network(True, split) - for network in range(first_network, self.maps_manager.num_networks): - self._train_single( - split, split_df_dict, network=network, resume=resume - ) - else: - self._train_single(split, split_df_dict, resume=True) + split_df_dict = self.splitter[split] + if self.config.model.multi_network: + resume, first_network = self.init_first_network(True, split) + for network in range(first_network, self.maps_manager.num_networks): + self._train_single( + split, split_df_dict, network=network, resume=resume + ) + else: + self._train_single(split, split_df_dict, resume=True) def init_first_network(self, resume: bool, split: int): first_network = 0 @@ -347,7 +365,7 @@ def get_dataloader( transforms_config=self.config.transforms, multi_cohort=self.config.data.multi_cohort, label=self.config.data.label, - label_code=self.maps_manager.label_code, + label_code=self.config.data.label_code, cnn_index=cnn_index, ) if homemade_sampler: @@ -355,6 +373,7 @@ def get_dataloader( network_task=self.maps_manager.network_task, dataset=dataset, sampler_option=sampler_option, + label_code=self.config.data.label_code, dp_degree=dp_degree, rank=rank, ) @@ -452,218 +471,6 @@ def _train_single( self.maps_manager._erase_tmp(split) - def _train_ssda( - self, - split_list: Optional[List[int]] = None, - resume: bool = False, - ) -> None: - """ - Trains a single CNN for a source and target domain using semi-supervised domain adaptation. - - Parameters - ---------- - split_list : Optional[List[int]] (optional, default=None) - List of splits on which the training task is performed. - If None, performs training on all splits of the cross-validation. - resume : bool (optional, default=False) - If True, the job is resumed from checkpoint. 
- """ - - splitter_config = SplitterConfig(**self.config.get_dict()) - - split_manager = Splitter(splitter_config, split_list=split_list) - split_manager_target_lab = Splitter(splitter_config, split_list=split_list) - - for split in split_manager.split_iterator(): - logger.info(f"Training split {split}") - seed_everything( - self.config.reproducibility.seed, - self.config.reproducibility.deterministic, - self.config.reproducibility.compensation, - ) - - split_df_dict = split_manager[split] - split_df_dict_target_lab = split_manager_target_lab[split] - - logger.debug("Loading source training data...") - data_train_source = return_dataset( - self.config.data.caps_directory, - split_df_dict["train"], - self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - - logger.debug("Loading target labelled training data...") - data_train_target_labeled = return_dataset( - Path(self.config.ssda.caps_target), # TO CHECK - split_df_dict_target_lab["train"], - self.config.ssda.preprocessing_dict_target, - transforms_config=self.config.transforms, - multi_cohort=False, # A checker - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - from torch.utils.data import ConcatDataset - - combined_dataset = ConcatDataset( - [data_train_source, data_train_target_labeled] - ) - - logger.debug("Loading target unlabelled training data...") - data_target_unlabeled = return_dataset( - Path(self.config.ssda.caps_target), - pd.read_csv(self.config.ssda.tsv_target_unlab, sep="\t"), - self.config.ssda.preprocessing_dict_target, - transforms_config=self.config.transforms, - multi_cohort=False, # A checker - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - - logger.debug("Loading validation source data...") - data_valid_source = return_dataset( - self.config.data.caps_directory, - split_df_dict["validation"], - self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - logger.debug("Loading validation target labelled data...") - data_valid_target_labeled = return_dataset( - Path(self.config.ssda.caps_target), - split_df_dict_target_lab["validation"], - self.config.ssda.preprocessing_dict_target, - transforms_config=self.config.transforms, - multi_cohort=False, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - train_source_sampler = generate_sampler( - self.maps_manager.network_task, - data_train_source, - self.config.dataloader.sampler, - ) - - logger.info( - f"Getting train and validation loader with batch size {self.config.dataloader.batch_size}" - ) - - ## Oversampling of the target dataset - from torch.utils.data import SubsetRandomSampler - - # Create index lists for target labeled dataset - labeled_indices = list(range(len(data_train_target_labeled))) - - # Oversample the indices for the target labelled dataset to match the size of the labeled source dataset - data_train_source_size = ( - len(data_train_source) // self.config.dataloader.batch_size - ) - labeled_oversampled_indices = labeled_indices * ( - data_train_source_size // len(labeled_indices) - ) - - # Append remaining indices to match the size of the largest dataset - labeled_oversampled_indices += labeled_indices[ - : data_train_source_size % len(labeled_indices) - ] - - # Create 
SubsetRandomSamplers using the oversampled indices - labeled_sampler = SubsetRandomSampler(labeled_oversampled_indices) - - train_source_loader = DataLoader( - data_train_source, - batch_size=self.config.dataloader.batch_size, - sampler=train_source_sampler, - # shuffle=True, # len(data_train_source) < len(data_train_target_labeled), - num_workers=self.config.dataloader.n_proc, - worker_init_fn=pl_worker_init_function, - drop_last=True, - ) - logger.info( - f"Train source loader size is {len(train_source_loader)*self.config.dataloader.batch_size}" - ) - train_target_loader = DataLoader( - data_train_target_labeled, - batch_size=1, # To limit the need of oversampling - # sampler=train_target_sampler, - sampler=labeled_sampler, - num_workers=self.config.dataloader.n_proc, - worker_init_fn=pl_worker_init_function, - # shuffle=True, # len(data_train_target_labeled) < len(data_train_source), - drop_last=True, - ) - logger.info( - f"Train target labeled loader size oversample is {len(train_target_loader)}" - ) - - data_train_target_labeled.df = data_train_target_labeled.df[ - ["participant_id", "session_id", "diagnosis", "cohort", "domain"] - ] - - train_target_unl_loader = DataLoader( - data_target_unlabeled, - batch_size=self.config.dataloader.batch_size, - num_workers=self.config.dataloader.n_proc, - # sampler=unlabeled_sampler, - worker_init_fn=pl_worker_init_function, - shuffle=True, - drop_last=True, - ) - - logger.info( - f"Train target unlabeled loader size is {len(train_target_unl_loader)*self.config.dataloader.batch_size}" - ) - - valid_loader_source = DataLoader( - data_valid_source, - batch_size=self.config.dataloader.batch_size, - shuffle=False, - num_workers=self.config.dataloader.n_proc, - ) - logger.info( - f"Validation loader source size is {len(valid_loader_source)*self.config.dataloader.batch_size}" - ) - - valid_loader_target = DataLoader( - data_valid_target_labeled, - batch_size=self.config.dataloader.batch_size, # To check - shuffle=False, - num_workers=self.config.dataloader.n_proc, - ) - logger.info( - f"Validation loader target size is {len(valid_loader_target)*self.config.dataloader.batch_size}" - ) - - self._train_ssdann( - train_source_loader, - train_target_loader, - train_target_unl_loader, - valid_loader_target, - valid_loader_source, - split, - resume=resume, - ) - - self.validator._ensemble_prediction( - self.maps_manager, - "train", - split, - self.config.validation.selection_metrics, - ) - self.validator._ensemble_prediction( - self.maps_manager, - "validation", - split, - self.config.validation.selection_metrics, - ) - - self.maps_manager._erase_tmp(split) - def _train( self, train_loader: DataLoader, @@ -985,412 +792,6 @@ def _train( self.callback_handler.on_train_end(parameters=self.maps_manager.parameters) - def _train_ssdann( - self, - train_source_loader: DataLoader, - train_target_loader: DataLoader, - train_target_unl_loader: DataLoader, - valid_loader: DataLoader, - valid_source_loader: DataLoader, - split: int, - network: Optional[Any] = None, - resume: bool = False, - evaluate_source: bool = True, # TO MODIFY - ): - """ - _summary_ - - Parameters - ---------- - train_source_loader : torch.utils.data.DataLoader - _description_ - train_target_loader : torch.utils.data.DataLoader - _description_ - train_target_unl_loader : torch.utils.data.DataLoader - _description_ - valid_loader : torch.utils.data.DataLoader - _description_ - valid_source_loader : torch.utils.data.DataLoader - _description_ - split : int - _description_ - network : Optional[Any] 
(optional, default=None) - _description_ - resume : bool (optional, default=False) - _description_ - evaluate_source : bool (optional, default=True) - _description_ - - Raises - ------ - Exception - _description_ - """ - model, beginning_epoch = self.maps_manager._init_model( - split=split, - resume=resume, - transfer_path=self.config.transfer_learning.transfer_path, - transfer_selection=self.config.transfer_learning.transfer_selection_metric, - ) - - criterion = get_criterion( - self.maps_manager.network_task, self.config.model.loss - ) - logger.debug(f"Criterion for {self.config.network_task} is {criterion}") - optimizer = self._init_optimizer(model, split=split, resume=resume) - - logger.debug(f"Optimizer used for training is optimizer") - - model.train() - train_source_loader.dataset.train() - train_target_loader.dataset.train() - train_target_unl_loader.dataset.train() - - early_stopping = EarlyStopping( - "min", - min_delta=self.config.early_stopping.tolerance, - patience=self.config.early_stopping.patience, - ) - - metrics_valid_target = {"loss": None} - metrics_valid_source = {"loss": None} - - log_writer = LogWriter( - self.maps_manager.maps_path, - evaluation_metrics(self.maps_manager.network_task) + ["loss"], - split, - resume=resume, - beginning_epoch=beginning_epoch, - network=network, - ) - epoch = log_writer.beginning_epoch - - retain_best = RetainBest( - selection_metrics=list(self.config.validation.selection_metrics) - ) - import numpy as np - - while epoch < self.config.optimization.epochs and not early_stopping.step( - metrics_valid_target["loss"] - ): - logger.info(f"Beginning epoch {epoch}.") - - model.zero_grad() - evaluation_flag, step_flag = True, True - - for i, (data_source, data_target, data_target_unl) in enumerate( - zip(train_source_loader, train_target_loader, train_target_unl_loader) - ): - p = ( - float(epoch * len(train_target_loader)) - / 10 - / len(train_target_loader) - ) - alpha = 2.0 / (1.0 + np.exp(-10 * p)) - 1 - # alpha = 0 - _, _, loss_dict = model.compute_outputs_and_loss( - data_source, data_target, data_target_unl, criterion, alpha - ) # TO CHECK - logger.debug(f"Train loss dictionary {loss_dict}") - loss = loss_dict["loss"] - loss.backward() - if (i + 1) % self.config.optimization.accumulation_steps == 0: - step_flag = False - optimizer.step() - optimizer.zero_grad() - - del loss - - # Evaluate the model only when no gradients are accumulated - if ( - self.config.validation.evaluation_steps != 0 - and (i + 1) % self.config.validation.evaluation_steps == 0 - ): - evaluation_flag = False - - # Evaluate on target data - logger.info("Evaluation on target data") - ( - _, - metrics_train_target, - ) = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=train_target_loader, - criterion=criterion, - alpha=alpha, - target=True, - ) # TO CHECK - - ( - _, - metrics_valid_target, - ) = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_loader, - criterion=criterion, - alpha=alpha, - target=True, - ) - - model.train() - train_target_loader.dataset.train() - - log_writer.step( - epoch, - i, - metrics_train_target, - metrics_valid_target, - len(train_target_loader), - "training_target.tsv", - ) - logger.info( - f"{self.config.data.mode} level 
training loss for target data is {metrics_train_target['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.config.data.mode} level validation loss for target data is {metrics_valid_target['loss']} " - f"at the end of iteration {i}" - ) - - # Evaluate on source data - logger.info("Evaluation on source data") - ( - _, - metrics_train_source, - ) = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=train_source_loader, - criterion=criterion, - alpha=alpha, - ) - ( - _, - metrics_valid_source, - ) = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_source_loader, - criterion=criterion, - alpha=alpha, - ) - - model.train() - train_source_loader.dataset.train() - - log_writer.step( - epoch, - i, - metrics_train_source, - metrics_valid_source, - len(train_source_loader), - ) - logger.info( - f"{self.config.data.mode} level training loss for source data is {metrics_train_source['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.config.data.mode} level validation loss for source data is {metrics_valid_source['loss']} " - f"at the end of iteration {i}" - ) - - # If no step has been performed, raise Exception - if step_flag: - raise Exception( - "The model has not been updated once in the epoch. The accumulation step may be too large." - ) - - # If no evaluation has been performed, warn the user - elif evaluation_flag and self.config.validation.evaluation_steps != 0: - logger.warning( - f"Your evaluation steps {self.config.validation.evaluation_steps} are too big " - f"compared to the size of the dataset. " - f"The model is evaluated only once at the end epochs." - ) - - # Update weights one last time if gradients were computed without update - if (i + 1) % self.config.optimization.accumulation_steps != 0: - optimizer.step() - optimizer.zero_grad() - # Always test the results and save them once at the end of the epoch - model.zero_grad() - logger.debug(f"Last checkpoint at the end of the epoch {epoch}") - - if evaluate_source: - logger.info( - f"Evaluate source data at the end of the epoch {epoch} with alpha: {alpha}." 
- ) - _, metrics_train_source = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=train_source_loader, - criterion=criterion, - alpha=alpha, - target=True, - report_ci=False, - ) - _, metrics_valid_source = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_source_loader, - criterion=criterion, - alpha=alpha, - target=True, - report_ci=False, - ) - - log_writer.step( - epoch, - i, - metrics_train_source, - metrics_valid_source, - len(train_source_loader), - ) - - logger.info( - f"{self.config.data.mode} level training loss for source data is {metrics_train_source['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.config.data.mode} level validation loss for source data is {metrics_valid_source['loss']} " - f"at the end of iteration {i}" - ) - - _, metrics_train_target = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=train_target_loader, - criterion=criterion, - alpha=alpha, - target=True, - ) - _, metrics_valid_target = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_loader, - criterion=criterion, - alpha=alpha, - target=True, - ) - - model.train() - train_source_loader.dataset.train() - train_target_loader.dataset.train() - - log_writer.step( - epoch, - i, - metrics_train_target, - metrics_valid_target, - len(train_target_loader), - "training_target.tsv", - ) - - logger.info( - f"{self.config.data.mode} level training loss for target data is {metrics_train_target['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.config.data.mode} level validation loss for target data is {metrics_valid_target['loss']} " - f"at the end of iteration {i}" - ) - - # Save checkpoints and best models - best_dict = retain_best.step(metrics_valid_target) - self.maps_manager._write_weights( - { - "model": model.state_dict(), - "epoch": epoch, - "name": self.config.model.architecture, - }, - best_dict, - split, - network=network, - save_all_models=False, - ) - self.maps_manager._write_weights( - { - "optimizer": optimizer.state_dict(), # TO MODIFY - "epoch": epoch, - "name": self.config.optimizer, - }, - None, - split, - filename="optimizer.pth.tar", - save_all_models=False, - ) - - epoch += 1 - - self.validator._test_loader_ssda( - self.maps_manager, - train_target_loader, - criterion, - data_group="train", - split=split, - selection_metrics=self.config.validation.selection_metrics, - network=network, - target=True, - alpha=0, - ) - self.validator._test_loader_ssda( - self.maps_manager, - valid_loader, - criterion, - data_group="validation", - split=split, - selection_metrics=self.config.validation.selection_metrics, - network=network, - target=True, - alpha=0, - ) - - if save_outputs(self.maps_manager.network_task): - self.validator._compute_output_tensors( - self.maps_manager, - train_target_loader.dataset, - "train", - split, - self.config.validation.selection_metrics, - nb_images=1, - network=network, - ) - 
self.validator._compute_output_tensors(
-                self.maps_manager,
-                train_target_loader.dataset,
-                "validation",
-                split,
-                self.config.validation.selection_metrics,
-                nb_images=1,
-                network=network,
-            )
-
     def _init_callbacks(self) -> None:
         """
         Initializes training callbacks.
diff --git a/clinicadl/utils/cli_param/option.py b/clinicadl/utils/cli_param/option.py
index 6ff86cda2..75438ceda 100644
--- a/clinicadl/utils/cli_param/option.py
+++ b/clinicadl/utils/cli_param/option.py
@@ -58,13 +58,6 @@
     multiple=True,
     default=None,
 )
-ssda_network = click.option(
-    "--ssda_network",
-    type=bool,
-    default=False,
-    show_default=True,
-    help="ssda training.",
-)
 valid_longitudinal = click.option(
     "--valid_longitudinal/--valid_baseline",
     type=bool,
diff --git a/clinicadl/utils/iotools/train_utils.py b/clinicadl/utils/iotools/train_utils.py
index e4347de3b..71595811d 100644
--- a/clinicadl/utils/iotools/train_utils.py
+++ b/clinicadl/utils/iotools/train_utils.py
@@ -198,3 +198,65 @@ def merge_cli_and_config_file_options(task: Task, **kwargs) -> Dict[str, Any]:
             pass
     ###
     return options
+
+
+def merge_cli_and_maps_json_options(maps_json: Path, **kwargs) -> Dict[str, Any]:
+    """
+    Merges options from the CLI (passed by the user) and from the maps.json
+    file of a MAPS directory (if it exists).
+
+    Priority is given to options passed by the user via the CLI. If an option is
+    not provided on the command line, it is looked for in the maps.json file.
+    If an option is neither passed by the user nor found in the maps.json file,
+    it will not be in the output.
+
+    Parameters
+    ----------
+    maps_json : Path
+        Path to the maps.json file of the MAPS directory.
+
+    Returns
+    -------
+    Dict[str, Any]
+        A dictionary with training options.
+    """
+    from clinicadl.caps_dataset.caps_dataset_utils import read_json
+
+    options = read_json(maps_json)
+    for arg in kwargs:
+        if (
+            click.get_current_context().get_parameter_source(arg)
+            == ParameterSource.COMMANDLINE
+        ):
+            options[arg] = kwargs[arg]
+
+    return options
+
+
+def merge_options_and_maps_json_options(maps_json: Path, **kwargs) -> Dict[str, Any]:
+    """
+    Merges options passed as keyword arguments with the options found in the
+    maps.json file of a MAPS directory.
+
+    Priority is always given to the keyword arguments: an option read from the
+    maps.json file is kept only if it is not overridden by a keyword argument.
+    If an option is neither passed as a keyword argument nor found in the
+    maps.json file, it will not be in the output.
+
+    Parameters
+    ----------
+    maps_json : Path
+        Path to the maps.json file of the MAPS directory.
+
+    Returns
+    -------
+    Dict[str, Any]
+        A dictionary with training options.
+ """ + from clinicadl.caps_dataset.caps_dataset_utils import read_json + + options = read_json(maps_json) + for arg in kwargs: + options[arg] = kwargs[arg] + + return options diff --git a/clinicadl/utils/iotools/trainer_utils.py b/clinicadl/utils/iotools/trainer_utils.py index b77229ea6..ac1b6a3bf 100644 --- a/clinicadl/utils/iotools/trainer_utils.py +++ b/clinicadl/utils/iotools/trainer_utils.py @@ -19,8 +19,7 @@ def create_parameters_dict(config): parameters["transfer_path"] = False if parameters["data_augmentation"] == (): parameters["data_augmentation"] = False - parameters["preprocessing_dict_target"] = parameters["preprocessing_json_target"] - del parameters["preprocessing_json_target"] + del parameters["preprocessing_json"] # if "tsv_path" in parameters: # parameters["tsv_path"] = parameters["tsv_path"] diff --git a/clinicadl/validator/config.py b/clinicadl/validator/config.py deleted file mode 100644 index 2f8c8a30a..000000000 --- a/clinicadl/validator/config.py +++ /dev/null @@ -1,47 +0,0 @@ -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Optional, Union - -from pydantic import ( - BaseModel, - ConfigDict, - computed_field, - field_validator, -) - -from clinicadl.utils.factories import DefaultFromLibrary - - -class ValidatorConfig(BaseModel): - """Base config class to configure the validator.""" - - maps_path: Path - mode: str - network_task: str - num_networks: Optional[int] = None - fsdp: Optional[bool] = None - amp: Optional[bool] = None - metrics_module: Optional = None - n_classes: Optional[int] = None - nb_unfrozen_layers: Optional[int] = None - std_amp: Optional[bool] = None - - # pydantic config - model_config = ConfigDict( - validate_assignment=True, - use_enum_values=True, - validate_default=True, - ) - - @computed_field - @property - @abstractmethod - def metric(self) -> str: - """The name of the metric.""" - - @field_validator("get_not_nans", mode="after") - @classmethod - def validator_get_not_nans(cls, v): - assert not v, "get_not_nans not supported in ClinicaDL. Please set to False." 
- - return v diff --git a/clinicadl/validator/validator.py b/clinicadl/validator/validator.py deleted file mode 100644 index c8f5e9451..000000000 --- a/clinicadl/validator/validator.py +++ /dev/null @@ -1,496 +0,0 @@ -from logging import getLogger -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import pandas as pd -import torch -import torch.distributed as dist -from torch.amp import autocast -from torch.nn.modules.loss import _Loss -from torch.utils.data import DataLoader - -from clinicadl.maps_manager.maps_manager import MapsManager -from clinicadl.metrics.metric_module import MetricModule -from clinicadl.metrics.utils import find_selection_metrics -from clinicadl.network.network import Network -from clinicadl.trainer.tasks_utils import columns, compute_metrics, generate_test_row -from clinicadl.utils import cluster -from clinicadl.utils.computational.ddp import DDP, init_ddp -from clinicadl.utils.enum import ( - ClassificationLoss, - ClassificationMetric, - ReconstructionLoss, - ReconstructionMetric, - RegressionLoss, - RegressionMetric, - Task, -) -from clinicadl.utils.exceptions import ( - ClinicaDLArgumentError, - ClinicaDLConfigurationError, - MAPSError, -) - -logger = getLogger("clinicadl.maps_manager") -level_list: List[str] = ["warning", "info", "debug"] - - -# TODO save weights on CPU for better compatibility - - -class Validator: - def test( - self, - mode: str, - metrics_module: MetricModule, - n_classes: int, - network_task, - model: Network, - dataloader: DataLoader, - criterion: _Loss, - use_labels: bool = True, - amp: bool = False, - report_ci=False, - ) -> Tuple[pd.DataFrame, Dict[str, float]]: - """ - Computes the predictions and evaluation metrics. - - Parameters - ---------- - model: Network - The model trained. - dataloader: DataLoader - Wrapper of a CapsDataset. - criterion: _Loss - Function to calculate the loss. - use_labels: bool - If True the true_label will be written in output DataFrame - and metrics dict will be created. - amp: bool - If True, enables Pytorch's automatic mixed precision. - - Returns - ------- - the results and metrics on the image level. 
- """ - model.eval() - dataloader.dataset.eval() - - results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) - total_loss = {} - with torch.no_grad(): - for i, data in enumerate(dataloader): - # initialize the loss list to save the loss components - with autocast("cuda", enabled=amp): - outputs, loss_dict = model(data, criterion, use_labels=use_labels) - - if i == 0: - for loss_component in loss_dict.keys(): - total_loss[loss_component] = 0 - for loss_component in total_loss.keys(): - total_loss[loss_component] += loss_dict[loss_component].float() - - # Generate detailed DataFrame - for idx in range(len(data["participant_id"])): - row = generate_test_row( - network_task, - mode, - metrics_module, - n_classes, - idx, - data, - outputs.float(), - ) - row_df = pd.DataFrame( - row, columns=columns(network_task, mode, n_classes) - ) - results_df = pd.concat([results_df, row_df]) - - del outputs, loss_dict - dataframes = [None] * dist.get_world_size() - dist.gather_object( - results_df, dataframes if dist.get_rank() == 0 else None, dst=0 - ) - if dist.get_rank() == 0: - results_df = pd.concat(dataframes) - del dataframes - results_df.reset_index(inplace=True, drop=True) - - if not use_labels: - metrics_dict = None - else: - metrics_dict = compute_metrics( - network_task, results_df, metrics_module, report_ci=report_ci - ) - for loss_component in total_loss.keys(): - dist.reduce(total_loss[loss_component], dst=0) - loss_value = total_loss[loss_component].item() / cluster.world_size - - if report_ci: - metrics_dict["Metric_names"].append(loss_component) - metrics_dict["Metric_values"].append(loss_value) - metrics_dict["Lower_CI"].append("N/A") - metrics_dict["Upper_CI"].append("N/A") - metrics_dict["SE"].append("N/A") - - else: - metrics_dict[loss_component] = loss_value - - torch.cuda.empty_cache() - - return results_df, metrics_dict - - def test_da( - self, - mode: str, - metrics_module: MetricModule, - n_classes: int, - network_task: Union[str, Task], - model: Network, - dataloader: DataLoader, - criterion: _Loss, - alpha: float = 0, - use_labels: bool = True, - target: bool = True, - report_ci=False, - ) -> Tuple[pd.DataFrame, Dict[str, float]]: - """ - Computes the predictions and evaluation metrics. - - Args: - model: the model trained. - dataloader: wrapper of a CapsDataset. - criterion: function to calculate the loss. - use_labels: If True the true_label will be written in output DataFrame - and metrics dict will be created. - Returns: - the results and metrics on the image level. 
- """ - model.eval() - dataloader.dataset.eval() - results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) - total_loss = 0 - with torch.no_grad(): - for i, data in enumerate(dataloader): - outputs, loss_dict = model.compute_outputs_and_loss_test( - data, criterion, alpha, target - ) - total_loss += loss_dict["loss"].item() - - # Generate detailed DataFrame - for idx in range(len(data["participant_id"])): - row = generate_test_row( - network_task, - mode, - metrics_module, - n_classes, - idx, - data, - outputs, - ) - row_df = pd.DataFrame( - row, columns=columns(network_task, mode, n_classes) - ) - results_df = pd.concat([results_df, row_df]) - - del outputs, loss_dict - results_df.reset_index(inplace=True, drop=True) - - if not use_labels: - metrics_dict = None - else: - metrics_dict = compute_metrics( - network_task, results_df, metrics_module, report_ci=report_ci - ) - if report_ci: - metrics_dict["Metric_names"].append("loss") - metrics_dict["Metric_values"].append(total_loss) - metrics_dict["Lower_CI"].append("N/A") - metrics_dict["Upper_CI"].append("N/A") - metrics_dict["SE"].append("N/A") - - else: - metrics_dict["loss"] = total_loss - - torch.cuda.empty_cache() - - return results_df, metrics_dict - - def _test_loader( - self, - maps_manager: MapsManager, - dataloader, - criterion, - data_group: str, - split: int, - selection_metrics, - use_labels=True, - gpu=None, - amp=False, - network=None, - report_ci=True, - ): - """ - Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files. - - Args: - dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset. - criterion (torch.nn.modules.loss._Loss): optimization criterion used during training. - data_group (str): name of the data group used for the testing task. - split (int): Index of the split used to train the model tested. - selection_metrics (list[str]): List of metrics used to select the best models which are tested. - use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed. - gpu (bool): If given, a new value for the device of the model will be computed. - amp (bool): If enabled, uses Automatic Mixed Precision (requires GPU usage). - network (int): Index of the network tested (only used in multi-network setting). 
- """ - for selection_metric in selection_metrics: - if cluster.master: - log_dir = ( - maps_manager.maps_path - / f"split-{split}" - / f"best-{selection_metric}" - / data_group - ) - maps_manager.write_description_log( - log_dir, - data_group, - dataloader.dataset.config.data.caps_dict, - dataloader.dataset.config.data.data_df, - ) - - # load the best trained model during the training - model, _ = maps_manager._init_model( - transfer_path=maps_manager.maps_path, - split=split, - transfer_selection=selection_metric, - gpu=gpu, - network=network, - ) - model = DDP( - model, - fsdp=maps_manager.fully_sharded_data_parallel, - amp=maps_manager.amp, - ) - - prediction_df, metrics = self.test( - mode=maps_manager.mode, - metrics_module=maps_manager.metrics_module, - n_classes=maps_manager.n_classes, - network_task=maps_manager.network_task, - model=model, - dataloader=dataloader, - criterion=criterion, - use_labels=use_labels, - amp=amp, - report_ci=report_ci, - ) - if use_labels: - if network is not None: - metrics[f"{maps_manager.mode}_id"] = network - - loss_to_log = ( - metrics["Metric_values"][-1] if report_ci else metrics["loss"] - ) - - logger.info( - f"{maps_manager.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}" - ) - - if cluster.master: - # Replace here - maps_manager._mode_level_to_tsv( - prediction_df, - metrics, - split, - selection_metric, - data_group=data_group, - ) - - def _test_loader_ssda( - self, - maps_manager: MapsManager, - dataloader, - criterion, - alpha, - data_group, - split, - selection_metrics, - use_labels=True, - gpu=None, - network=None, - target=False, - report_ci=True, - ): - """ - Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files. - - Args: - dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset. - criterion (torch.nn.modules.loss._Loss): optimization criterion used during training. - data_group (str): name of the data group used for the testing task. - split (int): Index of the split used to train the model tested. - selection_metrics (list[str]): List of metrics used to select the best models which are tested. - use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed. - gpu (bool): If given, a new value for the device of the model will be computed. - network (int): Index of the network tested (only used in multi-network setting). 
- """ - for selection_metric in selection_metrics: - log_dir = ( - maps_manager.maps_path - / f"split-{split}" - / f"best-{selection_metric}" - / data_group - ) - maps_manager.write_description_log( - log_dir, - data_group, - dataloader.dataset.caps_dict, - dataloader.dataset.df, - ) - - # load the best trained model during the training - model, _ = maps_manager._init_model( - transfer_path=maps_manager.maps_path, - split=split, - transfer_selection=selection_metric, - gpu=gpu, - network=network, - ) - prediction_df, metrics = self.test_da( - network_task=maps_manager.network_task, - model=model, - dataloader=dataloader, - criterion=criterion, - target=target, - report_ci=report_ci, - mode=maps_manager.mode, - metrics_module=maps_manager.metrics_module, - n_classes=maps_manager.n_classes, - ) - if use_labels: - if network is not None: - metrics[f"{maps_manager.mode}_id"] = network - - if report_ci: - loss_to_log = metrics["Metric_values"][-1] - else: - loss_to_log = metrics["loss"] - - logger.info( - f"{maps_manager.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}" - ) - - # Replace here - maps_manager._mode_level_to_tsv( - prediction_df, metrics, split, selection_metric, data_group=data_group - ) - - @torch.no_grad() - def _compute_output_tensors( - self, - maps_manager: MapsManager, - dataset, - data_group, - split, - selection_metrics, - nb_images=None, - gpu=None, - network=None, - ): - """ - Compute the output tensors and saves them in the MAPS. - - Args: - dataset (clinicadl.caps_dataset.data.CapsDataset): wrapper of the data set. - data_group (str): name of the data group used for the task. - split (int): split number. - selection_metrics (list[str]): metrics used for model selection. - nb_images (int): number of full images to write. Default computes the outputs of the whole data set. - gpu (bool): If given, a new value for the device of the model will be computed. - network (int): Index of the network tested (only used in multi-network setting). 
- """ - for selection_metric in selection_metrics: - # load the best trained model during the training - model, _ = maps_manager._init_model( - transfer_path=maps_manager.maps_path, - split=split, - transfer_selection=selection_metric, - gpu=gpu, - network=network, - nb_unfrozen_layer=maps_manager.nb_unfrozen_layer, - ) - model = DDP( - model, - fsdp=maps_manager.fully_sharded_data_parallel, - amp=maps_manager.amp, - ) - model.eval() - - tensor_path = ( - maps_manager.maps_path - / f"split-{split}" - / f"best-{selection_metric}" - / data_group - / "tensors" - ) - if cluster.master: - tensor_path.mkdir(parents=True, exist_ok=True) - dist.barrier() - - if nb_images is None: # Compute outputs for the whole data set - nb_modes = len(dataset) - else: - nb_modes = nb_images * dataset.elem_per_image - - for i in [ - *range(cluster.rank, nb_modes, cluster.world_size), - *range(int(nb_modes % cluster.world_size <= cluster.rank)), - ]: - data = dataset[i] - image = data["image"] - x = image.unsqueeze(0).to(model.device) - with autocast("cuda", enabled=maps_manager.std_amp): - output = model(x) - output = output.squeeze(0).cpu().float() - participant_id = data["participant_id"] - session_id = data["session_id"] - mode_id = data[f"{maps_manager.mode}_id"] - input_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_input.pt" - output_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_output.pt" - torch.save(image, tensor_path / input_filename) - torch.save(output, tensor_path / output_filename) - logger.debug(f"File saved at {[input_filename, output_filename]}") - - def _ensemble_prediction( - self, - maps_manager: MapsManager, - data_group, - split, - selection_metrics, - use_labels=True, - skip_leak_check=False, - ): - """Computes the results on the image-level.""" - - if not selection_metrics: - selection_metrics = find_selection_metrics(maps_manager.maps_path, split) - - for selection_metric in selection_metrics: - ##################### - # Soft voting - if maps_manager.num_networks > 1 and not skip_leak_check: - maps_manager._ensemble_to_tsv( - split, - selection=selection_metric, - data_group=data_group, - use_labels=use_labels, - ) - elif maps_manager.mode != "image" and not skip_leak_check: - maps_manager._mode_to_image_tsv( - split, - selection=selection_metric, - data_group=data_group, - use_labels=use_labels, - ) diff --git a/tests/test_interpret.py b/tests/test_interpret.py index 7b4c9358b..ef6f394f8 100644 --- a/tests/test_interpret.py +++ b/tests/test_interpret.py @@ -7,7 +7,7 @@ import pytest from clinicadl.interpret.config import InterpretConfig -from clinicadl.predict.predict_manager import PredictManager +from clinicadl.predictor.predictor import Predictor @pytest.fixture(params=["classification", "regression"]) @@ -77,14 +77,21 @@ def run_interpret(cnn_input, tmp_out_dir, ref_dir): assert train_error for method in list(InterpretationMethod): - interpret_config = InterpretConfig( - maps_dir=maps_path, - data_group="train", - name=f"test-{method}", - method_cls=method, + from clinicadl.utils.iotools.train_utils import ( + merge_options_and_maps_json_options, ) - interpret_manager = PredictManager(interpret_config) + + dict_ = { + "maps_dir": maps_path, + "data_group": "train", + "name": f"test-{method}", + "method_cls": method, + } + # options = merge_options_and_maps_json_options(maps_path / "maps.json", **dict_) + interpret_config = InterpretConfig(**dict_) + + interpret_manager = Predictor(interpret_config) interpret_manager.interpret() 
interpret_map = interpret_manager.get_interpretation( - "train", f"test-{interpret_config.method}" + "train", f"test-{interpret_config.interpret.method}" ) diff --git a/tests/test_predict.py b/tests/test_predict.py index 849f0e20d..e515ef41c 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -7,8 +7,8 @@ import pytest from clinicadl.metrics.utils import get_metrics -from clinicadl.predict.predict_manager import PredictManager -from clinicadl.predict.utils import get_prediction +from clinicadl.predictor.predictor import Predictor +from clinicadl.predictor.utils import get_prediction from .testing_tools import compare_folders, modify_maps @@ -101,7 +101,7 @@ def test_predict(cmdopt, tmp_path, test_name): # with open(json_path, "w") as f: # f.write(json_data) - from clinicadl.predict.config import PredictConfig + from clinicadl.predictor.config import PredictConfig predict_config = PredictConfig( maps_dir=model_folder, @@ -113,7 +113,7 @@ def test_predict(cmdopt, tmp_path, test_name): overwrite=True, diagnoses=["CN"], ) - predict_manager = PredictManager(predict_config) + predict_manager = Predictor(predict_config) predict_manager.predict() for mode in modes: diff --git a/tests/test_train_ae.py b/tests/test_train_ae.py index c7fbcb276..d4611e188 100644 --- a/tests/test_train_ae.py +++ b/tests/test_train_ae.py @@ -107,6 +107,11 @@ def test_train_ae(cmdopt, tmp_path, test_name): no_gpu=cmdopt["no-gpu"], adapt_base_dir=cmdopt["adapt-base-dir"], ) + json_data_out = modify_maps( + maps=json_data_out, + base_dir=base_dir, + ssda=True, + ) assert json_data_out == json_data_ref # ["mode"] == mode assert compare_folders( diff --git a/tests/test_train_cnn.py b/tests/test_train_cnn.py index 761fedbee..2a29a3166 100644 --- a/tests/test_train_cnn.py +++ b/tests/test_train_cnn.py @@ -125,7 +125,13 @@ def test_train_cnn(cmdopt, tmp_path, test_name): base_dir=base_dir, no_gpu=cmdopt["no-gpu"], adapt_base_dir=cmdopt["adapt-base-dir"], + ssda=True, ) + json_data_out = modify_maps( + maps=json_data_out, + base_dir=base_dir, + ssda=True, + ) assert json_data_out == json_data_ref # ["mode"] == mode assert compare_folders( diff --git a/tests/test_train_from_json.py b/tests/test_train_from_json.py index 06b307b0f..f1bdaff01 100644 --- a/tests/test_train_from_json.py +++ b/tests/test_train_from_json.py @@ -74,6 +74,7 @@ def test_determinism(cmdopt, tmp_path): # Reproduce experiment (train from json) config_json = tmp_out_dir / "maps_roi_cnn/maps.json" + flag_error = not system( f"clinicadl train from_json {str(config_json)} {str(reproduced_maps_dir)} -s 0" ) diff --git a/tests/test_transfer_learning.py b/tests/test_transfer_learning.py index d49cbd61f..6a7850f9b 100644 --- a/tests/test_transfer_learning.py +++ b/tests/test_transfer_learning.py @@ -169,6 +169,7 @@ def test_transfer_learning(cmdopt, tmp_path, test_name): json_data_ref["gpu"] = json_data_out["gpu"] json_data_ref["transfer_path"] = json_data_out["transfer_path"] json_data_ref["tsv_path"] = json_data_out["tsv_path"] + json_data_out["ssda_network"] = json_data_ref["ssda_network"] ### assert json_data_out == json_data_ref # ["mode"] == mode diff --git a/tests/testing_tools.py b/tests/testing_tools.py index 4044d1022..885096374 100644 --- a/tests/testing_tools.py +++ b/tests/testing_tools.py @@ -174,6 +174,8 @@ def modify_maps( base_dir: Path, no_gpu: bool = False, adapt_base_dir: bool = False, + modify_split: bool = False, + ssda: bool = False, ) -> Dict[str, Any]: """ Modifies a MAPS dictionary if the user passed --no-gpu or 
--adapt-base-dir flags. @@ -208,6 +210,12 @@ def modify_maps( ) except KeyError: # maps with only caps directory pass + + if modify_split: + maps["split"] = (0,) + + if ssda: + maps["ssda_network"] = False return maps diff --git a/tests/unittests/nn/networks/test_ssda.py b/tests/unittests/nn/networks/test_ssda.py deleted file mode 100644 index 06da85ff2..000000000 --- a/tests/unittests/nn/networks/test_ssda.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch - -from clinicadl.nn.networks.ssda import Conv5_FC3_SSDA - - -def test_UNet(): - input_ = torch.randn(2, 1, 64, 63, 62) - network = Conv5_FC3_SSDA(input_size=(1, 64, 63, 62), output_size=3) - output = network(input_) - for out in output: - assert out.shape == torch.Size((2, 3)) diff --git a/tests/unittests/train/test_utils.py b/tests/unittests/train/test_utils.py index 6b33787eb..2914d2d9b 100644 --- a/tests/unittests/train/test_utils.py +++ b/tests/unittests/train/test_utils.py @@ -7,7 +7,6 @@ expected_classification = { "architecture": "default", "multi_network": False, - "ssda_network": False, "dropout": 0.0, "latent_space_size": 128, "feature_size": 1024, @@ -65,7 +64,6 @@ expected_regression = { "architecture": "default", "multi_network": False, - "ssda_network": False, "dropout": 0.0, "latent_space_size": 128, "feature_size": 1024, @@ -121,7 +119,6 @@ expected_reconstruction = { "architecture": "default", "multi_network": False, - "ssda_network": False, "dropout": 0.0, "latent_space_size": 128, "feature_size": 1024, diff --git a/tests/unittests/train/trainer/test_training_config.py b/tests/unittests/train/trainer/test_training_config.py index 503b88ddf..c6b130cb8 100644 --- a/tests/unittests/train/trainer/test_training_config.py +++ b/tests/unittests/train/trainer/test_training_config.py @@ -5,9 +5,8 @@ from clinicadl.caps_dataset.data_config import DataConfig from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig -from clinicadl.config.config.ssda import SSDAConfig from clinicadl.network.config import NetworkConfig -from clinicadl.splitter.validation import ValidationConfig +from clinicadl.predictor.validation import ValidationConfig from clinicadl.trainer.transfer_learning import TransferLearningConfig from clinicadl.transforms.config import TransformsConfig @@ -70,31 +69,6 @@ def test_model_config(): ) -def test_ssda_config(caps_example): - preprocessing_json_target = ( - caps_example / "tensor_extraction" / "preprocessing.json" - ) - c = SSDAConfig( - ssda_network=True, - preprocessing_json_target=preprocessing_json_target, - ) - expected_preprocessing_dict = { - "preprocessing": "t1-linear", - "mode": "image", - "use_uncropped_image": False, - "prepare_dl": False, - "extract_json": "t1-linear_mode-image.json", - "file_type": { - "pattern": "*space-MNI152NLin2009cSym_desc-Crop_res-1x1x1_T1w.nii.gz", - "description": "T1W Image registered using t1-linear and cropped (matrix size 169\u00d7208\u00d7179, 1 mm isotropic voxels)", - "needed_pipeline": "t1-linear", - }, - } - assert c.preprocessing_dict_target == expected_preprocessing_dict - c = SSDAConfig() - assert c.preprocessing_dict_target == {} - - def test_transferlearning_config(): c = TransferLearningConfig(transfer_path=False) assert c.transfer_path is None
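Taken together, the test changes above migrate callers from `PredictManager` to the new `Predictor` driven by `PredictConfig` and `InterpretConfig`. The following is a minimal sketch of the renamed flow, using only identifiers visible in the diffs above; the MAPS path is a hypothetical placeholder, `data_group="train"` follows tests/test_interpret.py, and no `PredictConfig` fields beyond those exercised in tests/test_predict.py are assumed.

```python
from pathlib import Path

from clinicadl.predictor.config import PredictConfig
from clinicadl.predictor.predictor import Predictor

# Hypothetical MAPS directory produced by a previous training run.
maps_dir = Path("results/maps_image_cnn")

# Mirrors the call in tests/test_predict.py; other fields keep their defaults.
predict_config = PredictConfig(
    maps_dir=maps_dir,
    data_group="train",
    overwrite=True,
    diagnoses=["CN"],
)

predictor = Predictor(predict_config)
predictor.predict()  # writes prediction TSV files into the MAPS tree

# Interpretation goes through the same class, as in tests/test_interpret.py
# ("test-gradients" is a hypothetical name for illustration):
#   interpret_config = InterpretConfig(maps_dir=maps_dir, data_group="train",
#                                      name="test-gradients", method_cls=method)
#   Predictor(interpret_config).interpret()
```

If the defaults are not sufficient in practice, the additional keyword arguments elided between the two hunks of tests/test_predict.py would need to be supplied as well.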