From 68a30e1fa90669c08f0f4ff5ec62bc8efcd3a61a Mon Sep 17 00:00:00 2001 From: camillebrianceau <57992134+camillebrianceau@users.noreply.github.com> Date: Tue, 8 Oct 2024 11:05:54 +0200 Subject: [PATCH] Extract Splitter from the MapsManager (#661) * remove __init_split_manager * remove Kfold and SingleSplit * init split manager to init splitter * add a SplitterConfig * remove split_name parameters --- clinicadl/API_test.py | 13 + clinicadl/caps_dataset/data_config.py | 3 +- clinicadl/commandline/arguments.py | 4 +- .../{cross_validation.py => split.py} | 10 +- .../commandline/modules_options/validation.py | 2 +- .../commandline/pipelines/predict/cli.py | 4 +- .../pipelines/train/classification/cli.py | 10 +- .../pipelines/train/from_json/cli.py | 4 +- .../pipelines/train/reconstruction/cli.py | 10 +- .../pipelines/train/regression/cli.py | 10 +- .../commandline/pipelines/train/resume/cli.py | 4 +- clinicadl/interpret/config.py | 6 +- clinicadl/maps_manager/maps_manager.py | 201 ++++++------- clinicadl/maps_manager/tmp_config.py | 5 +- clinicadl/metrics/utils.py | 21 +- clinicadl/predict/config.py | 6 +- clinicadl/predict/predict_manager.py | 20 +- clinicadl/predict/utils.py | 11 +- clinicadl/random_search/random_search.py | 4 +- .../random_search/random_search_config.py | 2 +- clinicadl/splitter/__init__.py | 0 clinicadl/splitter/config.py | 71 +++++ clinicadl/splitter/split_utils.py | 60 +--- clinicadl/splitter/splitter.py | 215 +++++++++++++ .../{validation => splitter}/validation.py | 5 +- clinicadl/trainer/config/classification.py | 4 +- clinicadl/trainer/config/reconstruction.py | 4 +- clinicadl/trainer/config/regression.py | 4 +- clinicadl/trainer/config/train.py | 23 +- clinicadl/trainer/trainer.py | 240 ++++----------- .../utils/early_stopping/early_stopping.py | 2 +- clinicadl/utils/iotools/trainer_utils.py | 7 +- clinicadl/utils/meta_maps/getter.py | 4 +- clinicadl/validation/cross_validation.py | 41 --- .../validation/split_manager/__init__.py | 2 - clinicadl/validation/split_manager/kfold.py | 52 ---- .../validation/split_manager/single_split.py | 44 --- .../validation/split_manager/split_manager.py | 284 ------------------ clinicadl/validator/config.py | 1 - clinicadl/validator/validator.py | 10 +- tests/test_predict.py | 2 - tests/test_resume.py | 6 +- .../test_random_search_config.py | 2 +- .../test_classification_config.py | 2 +- .../test_reconstruction_config.py | 2 +- .../regression/test_regression_config.py | 2 +- .../train/trainer/test_training_config.py | 20 +- 47 files changed, 567 insertions(+), 892 deletions(-) rename clinicadl/commandline/modules_options/{cross_validation.py => split.py} (64%) create mode 100644 clinicadl/splitter/__init__.py create mode 100644 clinicadl/splitter/config.py create mode 100644 clinicadl/splitter/splitter.py rename clinicadl/{validation => splitter}/validation.py (86%) delete mode 100644 clinicadl/validation/cross_validation.py delete mode 100644 clinicadl/validation/split_manager/__init__.py delete mode 100644 clinicadl/validation/split_manager/kfold.py delete mode 100644 clinicadl/validation/split_manager/single_split.py delete mode 100644 clinicadl/validation/split_manager/split_manager.py diff --git a/clinicadl/API_test.py b/clinicadl/API_test.py index 5f17c044c..0581b879a 100644 --- a/clinicadl/API_test.py +++ b/clinicadl/API_test.py @@ -68,3 +68,16 @@ self.config.validation.selection_metrics, ) ###### end ############ + + +for split in splitter.split_iterator(): + for network in range( + first_network, 
self.maps_manager.num_networks + ): # for multi_network + ###### actual _train_single method of the Trainer ############ + test_loader = trainer.get_dataloader(dataset, split, network, "test", config) + validator.predict(test_loader) + +interpret_config = InterpretConfig(**kwargs) +predict_manager = PredictManager(interpret_config) +predict_manager.interpret() diff --git a/clinicadl/caps_dataset/data_config.py b/clinicadl/caps_dataset/data_config.py index 5fdeb568e..35aed91b5 100644 --- a/clinicadl/caps_dataset/data_config.py +++ b/clinicadl/caps_dataset/data_config.py @@ -10,6 +10,7 @@ ClinicaDLArgumentError, ClinicaDLTSVError, ) +from clinicadl.utils.iotools.clinica_utils import check_caps_folder from clinicadl.utils.iotools.data_utils import check_multi_cohort_tsv, load_data_test from clinicadl.utils.iotools.utils import read_preprocessing @@ -85,8 +86,6 @@ def check_data_tsv(cls, v) -> Path: @computed_field @property def caps_dict(self) -> Dict[str, Path]: - from clinicadl.utils.iotools.clinica_utils import check_caps_folder - if self.multi_cohort: if self.caps_directory.suffix != ".tsv": raise ClinicaDLArgumentError( diff --git a/clinicadl/commandline/arguments.py b/clinicadl/commandline/arguments.py index 2d85b1fb0..76c6ad8c6 100644 --- a/clinicadl/commandline/arguments.py +++ b/clinicadl/commandline/arguments.py @@ -19,9 +19,7 @@ merged_tsv = click.argument("merged_tsv", type=click.Path(exists=True, path_type=Path)) # TSV TOOLS -tsv_directory = click.argument( - "tsv_directory", type=click.Path(exists=True, path_type=Path) -) +tsv_path = click.argument("tsv_path", type=click.Path(exists=True, path_type=Path)) old_tsv_dir = click.argument( "old_tsv_dir", type=click.Path(exists=True, path_type=Path) ) diff --git a/clinicadl/commandline/modules_options/cross_validation.py b/clinicadl/commandline/modules_options/split.py similarity index 64% rename from clinicadl/commandline/modules_options/cross_validation.py rename to clinicadl/commandline/modules_options/split.py index c1c745ce3..f7c0a8882 100644 --- a/clinicadl/commandline/modules_options/cross_validation.py +++ b/clinicadl/commandline/modules_options/split.py @@ -2,13 +2,13 @@ from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.validation.cross_validation import CrossValidationConfig +from clinicadl.splitter.config import SplitConfig # Cross Validation n_splits = click.option( "--n_splits", - type=get_type("n_splits", CrossValidationConfig), - default=get_default("n_splits", CrossValidationConfig), + type=get_type("n_splits", SplitConfig), + default=get_default("n_splits", SplitConfig), help="If a value is given for k will load data of a k-fold CV. " "Default value (0) will load a single split.", show_default=True, @@ -16,8 +16,8 @@ split = click.option( "--split", "-s", - type=int, # get_type("split", config.CrossValidationConfig), - default=get_default("split", CrossValidationConfig), + type=int, # get_type("split", config.ValidationConfig), + default=get_default("split", SplitConfig), multiple=True, help="Train the list of given splits. 
By default, all the splits are trained.", show_default=True, diff --git a/clinicadl/commandline/modules_options/validation.py b/clinicadl/commandline/modules_options/validation.py index 4e2e973e3..858dd956e 100644 --- a/clinicadl/commandline/modules_options/validation.py +++ b/clinicadl/commandline/modules_options/validation.py @@ -2,7 +2,7 @@ from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.validation.validation import ValidationConfig +from clinicadl.splitter.validation import ValidationConfig # Validation valid_longitudinal = click.option( diff --git a/clinicadl/commandline/pipelines/predict/cli.py b/clinicadl/commandline/pipelines/predict/cli.py index c4cdaf1a1..fa7303008 100644 --- a/clinicadl/commandline/pipelines/predict/cli.py +++ b/clinicadl/commandline/pipelines/predict/cli.py @@ -3,10 +3,10 @@ from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import ( computational, - cross_validation, data, dataloader, maps_manager, + split, validation, ) from clinicadl.commandline.pipelines.predict import options @@ -29,7 +29,7 @@ @data.diagnoses @validation.skip_leak_check @validation.selection_metrics -@cross_validation.split +@split.split @computational.gpu @computational.amp @dataloader.n_proc diff --git a/clinicadl/commandline/pipelines/train/classification/cli.py b/clinicadl/commandline/pipelines/train/classification/cli.py index d552c318b..539f6cd42 100644 --- a/clinicadl/commandline/pipelines/train/classification/cli.py +++ b/clinicadl/commandline/pipelines/train/classification/cli.py @@ -4,7 +4,6 @@ from clinicadl.commandline.modules_options import ( callbacks, computational, - cross_validation, data, dataloader, early_stopping, @@ -13,6 +12,7 @@ optimization, optimizer, reproducibility, + split, ssda, transforms, validation, @@ -33,7 +33,7 @@ # Mandatory arguments @arguments.caps_directory @arguments.preprocessing_json -@arguments.tsv_directory +@arguments.tsv_path @arguments.output_maps # Options # Computational @@ -70,8 +70,8 @@ @ssda.tsv_target_unlab @ssda.preprocessing_json_target # Cross validation -@cross_validation.n_splits -@cross_validation.split +@split.n_splits +@split.split # Optimization @optimizer.optimizer @optimizer.weight_decay @@ -115,4 +115,4 @@ def cli(**kwargs): options = merge_cli_and_config_file_options(Task.CLASSIFICATION, **kwargs) config = ClassificationConfig(**options) trainer = Trainer(config) - trainer.train(split_list=config.cross_validation.split, overwrite=True) + trainer.train(split_list=config.split.split, overwrite=True) diff --git a/clinicadl/commandline/pipelines/train/from_json/cli.py b/clinicadl/commandline/pipelines/train/from_json/cli.py index ab613d330..c0130a9b9 100644 --- a/clinicadl/commandline/pipelines/train/from_json/cli.py +++ b/clinicadl/commandline/pipelines/train/from_json/cli.py @@ -4,7 +4,7 @@ from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import ( - cross_validation, + split, ) from clinicadl.trainer.trainer import Trainer @@ -12,7 +12,7 @@ @click.command(name="from_json", no_args_is_help=True) @arguments.config_file @arguments.output_maps -@cross_validation.split +@split.split def cli(**kwargs): """ Replicate a deep learning training based on a previously created JSON file. 
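
The same mechanical change runs through all the pipeline commands in this patch: the `cross_validation` options module becomes `split`, the `tsv_directory` argument becomes `tsv_path`, and the parsed values are read from `config.split` instead of `config.cross_validation`. A minimal sketch of the resulting wiring, assuming the mandatory arguments (caps_directory, preprocessing_json, tsv_path, output_maps) are supplied as in the classification pipeline above; this command is an illustration, not part of the patch:

    import click

    from clinicadl.commandline.modules_options import split
    from clinicadl.trainer.config.classification import ClassificationConfig
    from clinicadl.trainer.trainer import Trainer


    @click.command()
    @split.n_splits  # --n_splits, typed and defaulted from SplitConfig
    @split.split     # --split/-s, repeatable to select a subset of folds
    def my_train(**kwargs):
        # mandatory arguments are omitted here for brevity
        config = ClassificationConfig(**kwargs)
        trainer = Trainer(config)
        trainer.train(split_list=config.split.split, overwrite=True)
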
diff --git a/clinicadl/commandline/pipelines/train/reconstruction/cli.py b/clinicadl/commandline/pipelines/train/reconstruction/cli.py index d0a40fa40..d63bf63f8 100644 --- a/clinicadl/commandline/pipelines/train/reconstruction/cli.py +++ b/clinicadl/commandline/pipelines/train/reconstruction/cli.py @@ -4,7 +4,6 @@ from clinicadl.commandline.modules_options import ( callbacks, computational, - cross_validation, data, dataloader, early_stopping, @@ -13,6 +12,7 @@ optimization, optimizer, reproducibility, + split, ssda, transforms, validation, @@ -33,7 +33,7 @@ # Mandatory arguments @arguments.caps_directory @arguments.preprocessing_json -@arguments.tsv_directory +@arguments.tsv_path @arguments.output_maps # Options # Computational @@ -70,8 +70,8 @@ @ssda.tsv_target_unlab @ssda.preprocessing_json_target # Cross validation -@cross_validation.n_splits -@cross_validation.split +@split.n_splits +@split.split # Optimization @optimizer.optimizer @optimizer.weight_decay @@ -112,4 +112,4 @@ def cli(**kwargs): options = merge_cli_and_config_file_options(Task.RECONSTRUCTION, **kwargs) config = ReconstructionConfig(**options) trainer = Trainer(config) - trainer.train(split_list=config.cross_validation.split, overwrite=True) + trainer.train(split_list=config.split.split, overwrite=True) diff --git a/clinicadl/commandline/pipelines/train/regression/cli.py b/clinicadl/commandline/pipelines/train/regression/cli.py index ffeb218a8..ff6dd68ca 100644 --- a/clinicadl/commandline/pipelines/train/regression/cli.py +++ b/clinicadl/commandline/pipelines/train/regression/cli.py @@ -4,7 +4,6 @@ from clinicadl.commandline.modules_options import ( callbacks, computational, - cross_validation, data, dataloader, early_stopping, @@ -13,6 +12,7 @@ optimization, optimizer, reproducibility, + split, ssda, transforms, validation, @@ -31,7 +31,7 @@ # Mandatory arguments @arguments.caps_directory @arguments.preprocessing_json -@arguments.tsv_directory +@arguments.tsv_path @arguments.output_maps # Options # Computational @@ -68,8 +68,8 @@ @ssda.tsv_target_unlab @ssda.preprocessing_json_target # Cross validation -@cross_validation.n_splits -@cross_validation.split +@split.n_splits +@split.split # Optimization @optimizer.optimizer @optimizer.weight_decay @@ -111,4 +111,4 @@ def cli(**kwargs): options = merge_cli_and_config_file_options(Task.REGRESSION, **kwargs) config = RegressionConfig(**options) trainer = Trainer(config) - trainer.train(split_list=config.cross_validation.split, overwrite=True) + trainer.train(split_list=config.split.split, overwrite=True) diff --git a/clinicadl/commandline/pipelines/train/resume/cli.py b/clinicadl/commandline/pipelines/train/resume/cli.py index 8734bf95d..1fc34a0f4 100644 --- a/clinicadl/commandline/pipelines/train/resume/cli.py +++ b/clinicadl/commandline/pipelines/train/resume/cli.py @@ -2,14 +2,14 @@ from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import ( - cross_validation, + split, ) from clinicadl.trainer.trainer import Trainer @click.command(name="resume", no_args_is_help=True) @arguments.input_maps -@cross_validation.split +@split.split def cli(input_maps_directory, split): """Resume training job in specified maps. 
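
The hunk above stops at the docstring of the resume command, so its body is not shown. A plausible body, assuming a hypothetical helper `trainer_from_maps` that rebuilds a Trainer from an existing MAPS directory (the actual construction is not part of this patch), would simply forward the chosen folds to `Trainer.resume`, whose updated implementation appears in the trainer.py hunk below:

    def cli(input_maps_directory, split):
        """Resume training job in specified maps."""
        # trainer_from_maps is a hypothetical helper, not defined in this patch
        trainer = trainer_from_maps(input_maps_directory)
        trainer.resume(splits=list(split))  # Trainer.resume(self, splits: List[int])
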
diff --git a/clinicadl/interpret/config.py b/clinicadl/interpret/config.py index abbf89b64..41c8dcea9 100644 --- a/clinicadl/interpret/config.py +++ b/clinicadl/interpret/config.py @@ -8,10 +8,10 @@ from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp from clinicadl.maps_manager.config import MapsManagerConfig +from clinicadl.splitter.config import SplitConfig +from clinicadl.splitter.validation import ValidationConfig from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.enum import InterpretationMethod -from clinicadl.validation.cross_validation import CrossValidationConfig -from clinicadl.validation.validation import ValidationConfig logger = getLogger("clinicadl.interpret_config") @@ -49,8 +49,8 @@ class InterpretConfig( InterpretBaseConfig, DataConfig, ValidationConfig, - CrossValidationConfig, ComputationalConfig, DataLoaderConfig, + SplitConfig, ): """Config class to perform Transfer Learning.""" diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py index 3b32486b5..76cb544fe 100644 --- a/clinicadl/maps_manager/maps_manager.py +++ b/clinicadl/maps_manager/maps_manager.py @@ -1,4 +1,5 @@ import json +import shutil import subprocess from datetime import datetime from logging import getLogger @@ -17,6 +18,8 @@ check_selection_metric, ) from clinicadl.predict.utils import get_prediction +from clinicadl.splitter.config import SplitterConfig +from clinicadl.splitter.splitter import Splitter from clinicadl.trainer.tasks_utils import ( ensemble_prediction, evaluation_metrics, @@ -48,7 +51,7 @@ def __init__( self, maps_path: Path, parameters: Optional[Dict[str, Any]] = None, - verbose: str = "info", + verbose: Optional[str] = "info", ): """ @@ -104,16 +107,11 @@ def __init__( f"Please choose between classification, regression and reconstruction." 
) - self.split_name = ( - self._check_split_wording() - ) # Used only for retro-compatibility - # Initiate MAPS else: self._check_args(parameters) parameters["tsv_path"] = Path(parameters["tsv_path"]) - self.split_name = "split" # Used only for retro-compatibility if cluster.master: if (maps_path.is_dir() and maps_path.is_file()) or ( # Non-folder file maps_path.is_dir() and list(maps_path.iterdir()) # Non empty folder @@ -173,8 +171,9 @@ def _check_args(self, parameters): size_reduction=self.size_reduction, size_reduction_factor=self.size_reduction_factor, ) + splitter_config = SplitterConfig(**self.parameters) + split_manager = Splitter(splitter_config) - split_manager = self._init_split_manager(None) train_df = split_manager[0]["train"] if "label" not in self.parameters: self.parameters["label"] = None @@ -320,16 +319,12 @@ def _write_training_data(self): def _write_train_val_groups(self): """Defines the training and validation groups at the initialization""" logger.debug("Writing training and validation groups...") - split_manager = self._init_split_manager() + splitter_config = SplitterConfig(**self.parameters) + split_manager = Splitter(splitter_config) for split in split_manager.split_iterator(): for data_group in ["train", "validation"]: df = split_manager[split][data_group] - group_path = ( - self.maps_path - / "groups" - / data_group - / f"{self.split_name}-{split}" - ) + group_path = self.maps_path / "groups" / data_group / f"split-{split}" group_path.mkdir(parents=True, exist_ok=True) columns = ["participant_id", "session_id", "cohort"] @@ -420,10 +415,7 @@ def _mode_level_to_tsv( data_group: the name referring to the data group on which evaluation is performed. """ performance_dir = ( - self.maps_path - / f"{self.split_name}-{split}" - / f"best-{selection}" - / data_group + self.maps_path / f"split-{split}" / f"best-{selection}" / data_group ) performance_dir.mkdir(parents=True, exist_ok=True) performance_path = ( @@ -480,7 +472,6 @@ def _ensemble_to_tsv( validation_dataset = "validation" test_df = get_prediction( self.maps_path, - self.split_name, data_group, split, selection, @@ -489,7 +480,6 @@ def _ensemble_to_tsv( ) validation_df = get_prediction( self.maps_path, - self.split_name, validation_dataset, split, selection, @@ -498,10 +488,7 @@ def _ensemble_to_tsv( ) performance_dir = ( - self.maps_path - / f"{self.split_name}-{split}" - / f"best-{selection}" - / data_group + self.maps_path / f"split-{split}" / f"best-{selection}" / data_group ) performance_dir.mkdir(parents=True, exist_ok=True) @@ -549,7 +536,6 @@ def _mode_to_image_tsv( """ sub_df = get_prediction( self.maps_path, - self.split_name, data_group, split, selection, @@ -559,10 +545,7 @@ def _mode_to_image_tsv( sub_df.rename(columns={f"{self.mode}_id": "image_id"}, inplace=True) performance_dir = ( - self.maps_path - / f"{self.split_name}-{split}" - / f"best-{selection}" - / data_group + self.maps_path / f"split-{split}" / f"best-{selection}" / data_group ) sub_df.to_csv( performance_dir / f"{data_group}_image_level_prediction.tsv", @@ -587,13 +570,13 @@ def _mode_to_image_tsv( ############################### def _init_model( self, - transfer_path: Path = None, - transfer_selection=None, - nb_unfrozen_layer=0, - split=None, - resume=False, - gpu=None, - network=None, + transfer_path: Optional[Path] = None, + transfer_selection: Optional[str] = None, + nb_unfrozen_layer: int = 0, + split: Optional[int] = None, + resume: bool = False, + gpu: Optional[bool] = None, + network: Optional[int] = None, ): """ 
Instantiate the model @@ -636,10 +619,7 @@ def _init_model( if resume: checkpoint_path = ( - self.maps_path - / f"{self.split_name}-{split}" - / "tmp" - / "checkpoint.pth.tar" + self.maps_path / f"split-{split}" / "tmp" / "checkpoint.pth.tar" ) checkpoint_state = torch.load( checkpoint_path, map_location=device, weights_only=True @@ -674,72 +654,6 @@ def _init_model( return model, current_epoch - def _init_split_manager(self, split_list=None, ssda_bool: bool = False): - from clinicadl.validation import split_manager - - split_class = getattr(split_manager, self.validation) - args = list( - split_class.__init__.__code__.co_varnames[ - : split_class.__init__.__code__.co_argcount - ] - ) - args.remove("self") - args.remove("split_list") - kwargs = {"split_list": split_list} - for arg in args: - kwargs[arg] = self.parameters[arg] - - if ssda_bool: - kwargs["caps_directory"] = self.caps_target - kwargs["tsv_path"] = self.tsv_target_lab - - return split_class(**kwargs) - - def _init_split_manager_ssda(self, caps_dir, tsv_dir, split_list=None): - # A intégrer directement dans _init_split_manager - from clinicadl.validation import split_manager - - split_class = getattr(split_manager, self.validation) - args = list( - split_class.__init__.__code__.co_varnames[ - : split_class.__init__.__code__.co_argcount - ] - ) - args.remove("self") - args.remove("split_list") - kwargs = {"split_list": split_list} - for arg in args: - kwargs[arg] = self.parameters[arg] - - kwargs["caps_directory"] = Path(caps_dir) - kwargs["tsv_path"] = Path(tsv_dir) - - return split_class(**kwargs) - - # def _init_task_manager( - # self, df: Optional[pd.DataFrame] = None, n_classes: Optional[int] = None - # ): - # from clinicadl.utils.task_manager import ( - # ClassificationManager, - # ReconstructionManager, - # RegressionManager, - # ) - - # if self.network_task == "classification": - # if n_classes is not None: - # return ClassificationManager(self.mode, n_classes=n_classes) - # else: - # return ClassificationManager(self.mode, df=df, label=self.label) - # elif self.network_task == "regression": - # return RegressionManager(self.mode) - # elif self.network_task == "reconstruction": - # return ReconstructionManager(self.mode) - # else: - # raise NotImplementedError( - # f"Task {self.network_task} is not implemented in ClinicaDL. " - # f"Please choose between classification, regression and reconstruction." - # ) - ############################### # Getters # ############################### @@ -758,10 +672,7 @@ def _print_description_log( selection_metric (str): Metric used for best weights selection. 
""" log_dir = ( - self.maps_path - / f"{self.split_name}-{split}" - / f"best-{selection_metric}" - / data_group + self.maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group ) log_path = log_dir / "description.log" with log_path.open(mode="r") as f: @@ -831,7 +742,7 @@ def get_state_dict( (Dict): dictionary of results (weights, epoch number, metrics values) """ selection_metric = check_selection_metric( - self.maps_path, self.split_name, split, selection_metric + self.maps_path, split, selection_metric ) if self.multi_network: if network is None: @@ -841,14 +752,14 @@ def get_state_dict( else: model_path = ( self.maps_path - / f"{self.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / f"network-{network}_model.pth.tar" ) else: model_path = ( self.maps_path - / f"{self.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / "model.pth.tar" ) @@ -868,3 +779,67 @@ def std_amp(self) -> bool: then calls the internal FSDP AMP mechanisms. """ return self.amp and not self.fully_sharded_data_parallel + + def _erase_tmp(self, split: int): + """ + Erases checkpoints of the model and optimizer at the end of training. + + Parameters + ---------- + split : int + The split on which the model has been trained. + """ + tmp_path = self.maps_path / f"split-{split}" / "tmp" + shutil.rmtree(tmp_path) + + def _write_weights( + self, + state: Dict[str, Any], + metrics_dict: Optional[Dict[str, bool]], + split: int, + network: Optional[int] = None, + filename: str = "checkpoint.pth.tar", + save_all_models: bool = False, + ) -> None: + """ + Update checkpoint and save the best model according to a set of + metrics. + + Parameters + ---------- + state : Dict[str, Any] + The state of the training (model weights, epoch, etc.). + metrics_dict : Optional[Dict[str, bool]] + The output of RetainBest step. If None, only the checkpoint + is saved. + split : int + The split number. + network : int (optional, default=None) + The network number (multi-network framework). + filename : str (optional, default="checkpoint.pth.tar") + The name of the checkpoint file. + save_all_models : bool (optional, default=False) + Whether to save model weights at every epoch. + If False, only the best model will be saved. + """ + checkpoint_dir = self.maps_path / f"split-{split}" / "tmp" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + checkpoint_path = checkpoint_dir / filename + torch.save(state, checkpoint_path) + + if save_all_models: + all_models_dir = self.maps_path / f"split-{split}" / "all_models" + all_models_dir.mkdir(parents=True, exist_ok=True) + torch.save(state, all_models_dir / f"model_epoch_{state['epoch']}.pth.tar") + + best_filename = "model.pth.tar" + if network is not None: + best_filename = f"network-{network}_model.pth.tar" + + # Save model according to several metrics + if metrics_dict is not None: + for metric_name, metric_bool in metrics_dict.items(): + metric_path = self.maps_path / f"split-{split}" / f"best-{metric_name}" + if metric_bool: + metric_path.mkdir(parents=True, exist_ok=True) + shutil.copyfile(checkpoint_path, metric_path / best_filename) diff --git a/clinicadl/maps_manager/tmp_config.py b/clinicadl/maps_manager/tmp_config.py index 84e9e464c..a31af7edb 100644 --- a/clinicadl/maps_manager/tmp_config.py +++ b/clinicadl/maps_manager/tmp_config.py @@ -81,6 +81,7 @@ class TmpConfig(BaseModel): split: Optional[Tuple[NonNegativeInt, ...]] = None tsv_path: Optional[Path] = None # not needed in predict ? 
+ # DataConfig caps_directory: Path baseline: bool = False diagnoses: Tuple[str, ...] = ("AD", "CN") @@ -182,12 +183,10 @@ def check_args(self): ) if self.network_task == "classification": - from clinicadl.splitter.split_utils import init_split_manager - if self.n_splits > 1 and self.validation == "SingleSplit": self.validation = "KFoldSplit" - split_manager = init_split_manager( + split_manager = init_splitter( validation=self.validation, parameters=self.model_dump(), split_list=None, diff --git a/clinicadl/metrics/utils.py b/clinicadl/metrics/utils.py index c39cf80f8..dec32c524 100644 --- a/clinicadl/metrics/utils.py +++ b/clinicadl/metrics/utils.py @@ -7,10 +7,10 @@ from clinicadl.utils.exceptions import ClinicaDLArgumentError, MAPSError -def find_selection_metrics(maps_path: Path, split_name: str, split): +def find_selection_metrics(maps_path: Path, split): """Find which selection metrics are available in MAPS for a given split.""" - split_path = maps_path / f"{split_name}-{split}" + split_path = maps_path / f"split-{split}" if not split_path.is_dir(): raise KeyError( f"Training of split {split} was not performed." @@ -24,11 +24,9 @@ def find_selection_metrics(maps_path: Path, split_name: str, split): ] -def check_selection_metric( - maps_path: Path, split_name: str, split, selection_metric=None -): +def check_selection_metric(maps_path: Path, split, selection_metric=None): """Check that a given selection metric is available for a given split.""" - available_metrics = find_selection_metrics(maps_path, split_name, split) + available_metrics = find_selection_metrics(maps_path, split) if not selection_metric: if len(available_metrics) > 1: @@ -49,7 +47,6 @@ def check_selection_metric( def get_metrics( maps_path: Path, - split_name: str, data_group: str, split: int = 0, selection_metric: Optional[str] = None, @@ -68,15 +65,11 @@ def get_metrics( Returns: (dict[str:float]): Values of the metrics """ - selection_metric = check_selection_metric( - maps_path, split_name, split, selection_metric - ) + selection_metric = check_selection_metric(maps_path, split, selection_metric) if verbose: - print_description_log( - maps_path, split_name, data_group, split, selection_metric - ) + print_description_log(maps_path, data_group, split, selection_metric) prediction_dir = ( - maps_path / f"{split_name}-{split}" / f"best-{selection_metric}" / data_group + maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group ) if not prediction_dir.is_dir(): raise MAPSError( diff --git a/clinicadl/predict/config.py b/clinicadl/predict/config.py index 9304eefd8..a96b4b104 100644 --- a/clinicadl/predict/config.py +++ b/clinicadl/predict/config.py @@ -5,10 +5,10 @@ from clinicadl.maps_manager.config import ( MapsManagerConfig as MapsManagerBaseConfig, ) +from clinicadl.splitter.config import SplitConfig +from clinicadl.splitter.validation import ValidationConfig from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.exceptions import ClinicaDLArgumentError # type: ignore -from clinicadl.validation.cross_validation import CrossValidationConfig -from clinicadl.validation.validation import ValidationConfig logger = getLogger("clinicadl.predict_config") @@ -33,8 +33,8 @@ class PredictConfig( MapsManagerConfig, DataConfig, ValidationConfig, - CrossValidationConfig, ComputationalConfig, DataLoaderConfig, + SplitConfig, ): """Config class to perform Transfer Learning.""" diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predict/predict_manager.py index 
c197a96de..55515dc8e 100644 --- a/clinicadl/predict/predict_manager.py +++ b/clinicadl/predict/predict_manager.py @@ -148,7 +148,6 @@ def predict( if not self._config.selection_metrics: split_selection_metrics = find_selection_metrics( self.maps_manager.maps_path, - self.maps_manager.split_name, split, ) else: @@ -156,7 +155,7 @@ def predict( for selection in split_selection_metrics: tsv_dir = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection}" / self._config.data_group ) @@ -497,7 +496,7 @@ def _compute_latent_tensors( model.eval() tensor_path = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / self._config.data_group / "latent_tensors" @@ -574,7 +573,7 @@ def _compute_output_nifti( model.eval() nifti_path = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / self._config.data_group / "nifti_images" @@ -719,14 +718,13 @@ def interpret(self): if not self._config.selection_metrics: self._config.selection_metrics = find_selection_metrics( self.maps_manager.maps_path, - self.maps_manager.split_name, split, ) for selection_metric in self._config.selection_metrics: logger.info(f"Interpretation of metric {selection_metric}") results_path = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / self._config.data_group / f"interpret-{self._config.name}" @@ -843,13 +841,12 @@ def _check_data_group( for split in self._config.split: selection_metrics = find_selection_metrics( self.maps_manager.maps_path, - self.maps_manager.split_name, split, ) for selection in selection_metrics: results_path = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection}" / self._config.data_group ) @@ -927,12 +924,12 @@ def get_group_info( "Information on train or validation data can only be " "loaded if a split number is given" ) - elif not (group_path / f"{self.maps_manager.split_name}-{split}").is_dir(): + elif not (group_path / f"split-{split}").is_dir(): raise MAPSError( f"Split {split} is not available for data group {data_group}." ) else: - group_path = group_path / f"{self.maps_manager.split_name}-{split}" + group_path = group_path / f"split-{split}" df = pd.read_csv(group_path / "data.tsv", sep="\t") json_path = group_path / "maps.json" @@ -1054,7 +1051,6 @@ def get_interpretation( selection_metric = check_selection_metric( self.maps_manager.maps_path, - self.maps_manager.split_name, split, selection_metric, ) @@ -1064,7 +1060,7 @@ def get_interpretation( ) map_dir = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / data_group / f"interpret-{name}" diff --git a/clinicadl/predict/utils.py b/clinicadl/predict/utils.py index e547fb5b1..c66372764 100644 --- a/clinicadl/predict/utils.py +++ b/clinicadl/predict/utils.py @@ -10,7 +10,6 @@ def get_prediction( maps_path: Path, - split_name: str, data_group: str, split: int = 0, selection_metric: Optional[str] = None, @@ -31,15 +30,11 @@ def get_prediction( (DataFrame): Results indexed by columns 'participant_id' and 'session_id' which identifies the image in the BIDS / CAPS. 
""" - selection_metric = check_selection_metric( - maps_path, split_name, split, selection_metric - ) + selection_metric = check_selection_metric(maps_path, split, selection_metric) if verbose: - print_description_log( - maps_path, split_name, data_group, split, selection_metric - ) + print_description_log(maps_path, data_group, split, selection_metric) prediction_dir = ( - maps_path / f"{split_name}-{split}" / f"best-{selection_metric}" / data_group + maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group ) if not prediction_dir.is_dir(): raise MAPSError( diff --git a/clinicadl/random_search/random_search.py b/clinicadl/random_search/random_search.py index 196c4aea2..7929e9382 100755 --- a/clinicadl/random_search/random_search.py +++ b/clinicadl/random_search/random_search.py @@ -19,7 +19,7 @@ def launch_search(launch_directory: Path, job_name): options = get_space_dict(launch_directory) # temporary, TODO - options["tsv_directory"] = options["tsv_path"] + options["tsv_path"] = options["tsv_path"] options["maps_dir"] = maps_directory options["preprocessing_json"] = options["preprocessing_dict"]["extract_json"] @@ -38,4 +38,4 @@ def launch_search(launch_directory: Path, job_name): output_maps_directory=maps_directory, **options ) trainer = Trainer(training_config) - trainer.train(split_list=training_config.cross_validation.split, overwrite=True) + trainer.train(split_list=training_config.split.split, overwrite=True) diff --git a/clinicadl/random_search/random_search_config.py b/clinicadl/random_search/random_search_config.py index 82c8015eb..2e1d728a9 100644 --- a/clinicadl/random_search/random_search_config.py +++ b/clinicadl/random_search/random_search_config.py @@ -84,7 +84,7 @@ class TrainConfig(base_training_config): The user must specified at least the following arguments: - caps_directory - preprocessing_json - - tsv_directory + - tsv_path - output_maps_directory - convolutions_dict - n_fcblocks diff --git a/clinicadl/splitter/__init__.py b/clinicadl/splitter/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/clinicadl/splitter/config.py b/clinicadl/splitter/config.py new file mode 100644 index 000000000..53413fdda --- /dev/null +++ b/clinicadl/splitter/config.py @@ -0,0 +1,71 @@ +from abc import ABC, abstractmethod +from logging import getLogger +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +from pydantic import BaseModel, ConfigDict, field_validator +from pydantic.types import NonNegativeInt + +from clinicadl.caps_dataset.data_config import DataConfig +from clinicadl.splitter.split_utils import find_splits +from clinicadl.splitter.validation import ValidationConfig + +logger = getLogger("clinicadl.split_config") + + +class SplitConfig(BaseModel): + """ + Abstract config class for the validation procedure. + + selection_metrics is specific to the task, thus it needs + to be specified in a subclass. + """ + + n_splits: NonNegativeInt = 0 + split: Optional[Tuple[NonNegativeInt, ...]] = None + tsv_path: Optional[Path] = None # not needed in interpret ! 
+ + # pydantic config + model_config = ConfigDict(validate_assignment=True) + + @field_validator("split", mode="before") + def validator_split(cls, v): + if isinstance(v, list): + return tuple(v) + return v # TODO : check that split exists (and check coherence with n_splits) + + def adapt_cross_val_with_maps_manager_info( + self, maps_manager + ): # maps_manager is of type MapsManager but need to be in a MapsConfig type in the future + # TEMPORARY + if not self.split: + self.split = tuple(find_splits(maps_manager.maps_path)) + logger.debug(f"List of splits {self.split}") + + +class SplitterConfig(BaseModel, ABC): + """ + + Abstract config class for the training pipeline. + Some configurations are specific to the task (e.g. loss function), + thus they need to be specified in a subclass. + """ + + data: DataConfig + split: SplitConfig + validation: ValidationConfig + # pydantic config + model_config = ConfigDict(validate_assignment=True) + + def __init__(self, **kwargs): + super().__init__( + data=kwargs, + split=kwargs, + validation=kwargs, + ) + + def _update(self, config_dict: Dict[str, Any]) -> None: + """Updates the configs with a dict given by the user.""" + self.data.__dict__.update(config_dict) + self.split.__dict__.update(config_dict) + self.validation.__dict__.update(config_dict) diff --git a/clinicadl/splitter/split_utils.py b/clinicadl/splitter/split_utils.py index d42047dcc..3e0f09388 100644 --- a/clinicadl/splitter/split_utils.py +++ b/clinicadl/splitter/split_utils.py @@ -1,43 +1,39 @@ from pathlib import Path -from typing import List, Optional +from typing import List -from clinicadl.utils.exceptions import ClinicaDLArgumentError - -def find_splits(maps_path: Path, split_name: str) -> List[int]: +def find_splits(maps_path: Path) -> List[int]: """Find which splits that were trained in the MAPS.""" splits = [ int(split.name.split("-")[1]) for split in list(maps_path.iterdir()) - if split.name.startswith(f"{split_name}-") + if split.name.startswith("split-") ] return splits -def find_stopped_splits(maps_path: Path, split_name: str) -> List[int]: +def find_stopped_splits(maps_path: Path) -> List[int]: """Find which splits for which training was not completed.""" - existing_split_list = find_splits(maps_path, split_name) + existing_split_list = find_splits(maps_path) stopped_splits = [ split for split in existing_split_list - if (maps_path / f"{split_name}-{split}" / "tmp") - in list((maps_path / f"{split_name}-{split}").iterdir()) + if (maps_path / f"split-{split}" / "tmp") + in list((maps_path / f"split-{split}").iterdir()) ] return stopped_splits -def find_finished_splits(maps_path: Path, split_name: str) -> List[int]: +def find_finished_splits(maps_path: Path) -> List[int]: """Find which splits for which training was completed.""" finished_splits = list() - existing_split_list = find_splits(maps_path, split_name) - stopped_splits = find_stopped_splits(maps_path, split_name) + existing_split_list = find_splits(maps_path) + stopped_splits = find_stopped_splits(maps_path) for split in existing_split_list: if split not in stopped_splits: performance_dir_list = [ performance_dir - for performance_dir in list( - (maps_path / f"{split_name}-{split}").iterdir() - ) + for performance_dir in list((maps_path / f"split-{split}").iterdir()) if "best-" in performance_dir.name ] if len(performance_dir_list) > 0: @@ -47,7 +43,6 @@ def find_finished_splits(maps_path: Path, split_name: str) -> List[int]: def print_description_log( maps_path: Path, - split_name: str, data_group: str, split: int, 
selection_metric: str, @@ -60,38 +55,7 @@ def print_description_log( split (int): Index of the split used for training. selection_metric (str): Metric used for best weights selection. """ - log_dir = ( - maps_path / f"{split_name}-{split}" / f"best-{selection_metric}" / data_group - ) + log_dir = maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group log_path = log_dir / "description.log" with log_path.open(mode="r") as f: content = f.read() - - -def init_split_manager( - validation, - parameters, - split_list=None, - ssda_bool: bool = False, - caps_target: Optional[Path] = None, - tsv_target_lab: Optional[Path] = None, -): - from clinicadl.validation import split_manager - - split_class = getattr(split_manager, validation) - args = list( - split_class.__init__.__code__.co_varnames[ - : split_class.__init__.__code__.co_argcount - ] - ) - args.remove("self") - args.remove("split_list") - kwargs = {"split_list": split_list} - for arg in args: - kwargs[arg] = parameters[arg] - - if ssda_bool: - kwargs["caps_directory"] = caps_target - kwargs["tsv_path"] = tsv_target_lab - - return split_class(**kwargs) diff --git a/clinicadl/splitter/splitter.py b/clinicadl/splitter/splitter.py new file mode 100644 index 000000000..3bbdde461 --- /dev/null +++ b/clinicadl/splitter/splitter.py @@ -0,0 +1,215 @@ +import abc +from logging import getLogger +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import pandas as pd + +from clinicadl.splitter.config import SplitterConfig + +logger = getLogger("clinicadl.split_manager") + + +class Splitter: + def __init__( + self, + config: SplitterConfig, + split_list: Optional[List[int]] = None, + ): + """_summary_ + + Parameters + ---------- + data_config : DataConfig + _description_ + validation_config : ValidationConfig + _description_ + split_list : Optional[List[int]] (optional, default=None) + _description_ + + """ + self.config = config + self.split_list = split_list + + self.caps_dict = self.config.data.caps_dict # TODO : check if useful ? + + def max_length(self) -> int: + """Maximum number of splits""" + return self.config.split.n_splits + + def __len__(self): + if not self.split_list: + return self.config.split.n_splits + else: + return len(self.split_list) + + @property + def allowed_splits_list(self): + """ + List of possible splits if no restriction was applied + + Returns: + list[int]: list of all possible splits + """ + return [i for i in range(self.config.split.n_splits)] + + def __getitem__(self, item) -> Dict: + """ + Returns a dictionary of DataFrames with train and validation data. + + Args: + item (int): Index of the split wanted. + Returns: + Dict[str:pd.DataFrame]: dictionary with two keys (train and validation). 
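+
+        Example:
+            splitter[0] -> {"train": <pd.DataFrame>, "validation": <pd.DataFrame>}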
+ """ + self._check_item(item) + + if self.config.data.multi_cohort: + tsv_df = pd.read_csv(self.config.split.tsv_path, sep="\t") + train_df = pd.DataFrame() + valid_df = pd.DataFrame() + found_diagnoses = set() + for idx in range(len(tsv_df)): + cohort_name = tsv_df.at[idx, "cohort"] + cohort_path = Path(tsv_df.at[idx, "path"]) + cohort_diagnoses = ( + tsv_df.at[idx, "diagnoses"].replace(" ", "").split(",") + ) + if bool(set(cohort_diagnoses) & set(self.config.data.diagnoses)): + target_diagnoses = list( + set(cohort_diagnoses) & set(self.config.data.diagnoses) + ) + + cohort_train_df, cohort_valid_df = self.concatenate_diagnoses( + item, cohort_path=cohort_path, cohort_diagnoses=target_diagnoses + ) + cohort_train_df["cohort"] = cohort_name + cohort_valid_df["cohort"] = cohort_name + train_df = pd.concat([train_df, cohort_train_df]) + valid_df = pd.concat([valid_df, cohort_valid_df]) + found_diagnoses = found_diagnoses | ( + set(cohort_diagnoses) & set(self.config.data.diagnoses) + ) + + if found_diagnoses != set(self.config.data.diagnoses): + raise ValueError( + f"The diagnoses found in the multi cohort dataset {found_diagnoses} " + f"do not correspond to the diagnoses wanted {set(self.config.data.diagnoses)}." + ) + train_df.reset_index(inplace=True, drop=True) + valid_df.reset_index(inplace=True, drop=True) + else: + train_df, valid_df = self.concatenate_diagnoses(item) + train_df["cohort"] = "single" + valid_df["cohort"] = "single" + + return { + "train": train_df, + "validation": valid_df, + } + + @staticmethod + def get_dataframe_from_tsv_path(tsv_path: Path) -> pd.DataFrame: + df = pd.read_csv(tsv_path, sep="\t") + list_columns = df.columns.values + + if ( + "diagnosis" not in list_columns + # or "age" not in list_columns + # or "sex" not in list_columns + ): + parents_path = tsv_path.resolve().parent + labels_path = parents_path / "labels.tsv" + while ( + not labels_path.is_file() + and ((parents_path / "kfold.json").is_file()) + or (parents_path / "split.json").is_file() + ): + parents_path = parents_path.parent + try: + labels_df = pd.read_csv(labels_path, sep="\t") + df = pd.merge( + df, + labels_df, + how="inner", + on=["participant_id", "session_id"], + ) + except Exception: + pass + return df + + @staticmethod + def load_data(tsv_path: Path, cohort_diagnoses: List[str]) -> pd.DataFrame: + df = Splitter.get_dataframe_from_tsv_path(tsv_path) + df = df[df.diagnosis.isin((cohort_diagnoses))] + df.reset_index(inplace=True, drop=True) + return df + + def concatenate_diagnoses( + self, + split, + cohort_path: Optional[Path] = None, + cohort_diagnoses: Optional[List[str]] = None, + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + """Concatenated the diagnoses needed to form the train and validation sets.""" + + if cohort_diagnoses is None: + cohort_diagnoses = list(self.config.data.diagnoses) + + tmp_cohort_path = ( + cohort_path if cohort_path is not None else self.config.split.tsv_path + ) + train_path, valid_path = self._get_tsv_paths( + tmp_cohort_path, + split, + ) + + logger.debug(f"Training data loaded at {train_path}") + if self.config.data.baseline: + train_path = train_path / "train_baseline.tsv" + else: + train_path = train_path / "train.tsv" + train_df = self.load_data(train_path, cohort_diagnoses) + + logger.debug(f"Validation data loaded at {valid_path}") + if self.config.validation.valid_longitudinal: + valid_path = valid_path / "validation.tsv" + else: + valid_path = valid_path / "validation_baseline.tsv" + valid_df = self.load_data(valid_path, cohort_diagnoses) + + return 
train_df, valid_df + + def _get_tsv_paths(self, cohort_path, *args) -> Tuple[Path, Path]: + """ + Computes the paths to the TSV files needed depending on the split structure. + + Args: + cohort_path (str): path to the split structure of a cohort. + split (int): Index of the split. + Returns: + train_path (str): path to the directory containing training data. + valid_path (str): path to the directory containing validation data. + """ + if args is not None: + for split in args: + train_path = cohort_path / f"split-{split}" + valid_path = cohort_path / f"split-{split}" + return train_path, valid_path + else: + train_path = cohort_path + valid_path = cohort_path + return train_path, valid_path + + def split_iterator(self): + """Returns an iterable to iterate on all splits wanted.""" + if not self.split_list: + return range(self.config.split.n_splits) + else: + return self.split_list + + def _check_item(self, item): + if item not in self.allowed_splits_list: + raise IndexError( + f"Split index {item} out of allowed splits {self.allowed_splits_list}." + ) diff --git a/clinicadl/validation/validation.py b/clinicadl/splitter/validation.py similarity index 86% rename from clinicadl/validation/validation.py rename to clinicadl/splitter/validation.py index 3407a59e7..1452b47da 100644 --- a/clinicadl/validation/validation.py +++ b/clinicadl/splitter/validation.py @@ -19,11 +19,12 @@ class ValidationConfig(BaseModel): selection_metrics: Tuple[str, ...] = () valid_longitudinal: bool = False skip_leak_check: bool = False + # pydantic config model_config = ConfigDict(validate_assignment=True) @field_validator("selection_metrics", mode="before") - def list_to_tuples(cls, v): + def validator_split(cls, v): if isinstance(v, list): return tuple(v) - return v + return v # TODO : check that split exists (and check coherence with n_splits) diff --git a/clinicadl/trainer/config/classification.py b/clinicadl/trainer/config/classification.py index 9ac3a4aae..5e71d032e 100644 --- a/clinicadl/trainer/config/classification.py +++ b/clinicadl/trainer/config/classification.py @@ -5,9 +5,9 @@ from clinicadl.caps_dataset.data_config import DataConfig as BaseDataConfig from clinicadl.network.config import NetworkConfig as BaseNetworkConfig +from clinicadl.splitter.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import ClassificationLoss, ClassificationMetric, Task -from clinicadl.validation.validation import ValidationConfig as BaseValidationConfig logger = getLogger("clinicadl.classification_config") @@ -64,7 +64,7 @@ class ClassificationConfig(TrainConfig): The user must specified at least the following arguments: - caps_directory - preprocessing_json - - tsv_directory + - tsv_path - output_maps_directory """ diff --git a/clinicadl/trainer/config/reconstruction.py b/clinicadl/trainer/config/reconstruction.py index 08728885b..bf39886d4 100644 --- a/clinicadl/trainer/config/reconstruction.py +++ b/clinicadl/trainer/config/reconstruction.py @@ -4,6 +4,7 @@ from pydantic import PositiveFloat, PositiveInt, computed_field, field_validator from clinicadl.network.config import NetworkConfig as BaseNetworkConfig +from clinicadl.splitter.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import ( Normalization, @@ -11,7 +12,6 @@ ReconstructionMetric, Task, ) -from clinicadl.validation.validation import ValidationConfig as BaseValidationConfig logger = 
getLogger("clinicadl.reconstruction_config") @@ -53,7 +53,7 @@ class ReconstructionConfig(TrainConfig): The user must specified at least the following arguments: - caps_directory - preprocessing_json - - tsv_directory + - tsv_path - output_maps_directory """ diff --git a/clinicadl/trainer/config/regression.py b/clinicadl/trainer/config/regression.py index b19a3ba5c..37e690f01 100644 --- a/clinicadl/trainer/config/regression.py +++ b/clinicadl/trainer/config/regression.py @@ -5,9 +5,9 @@ from clinicadl.caps_dataset.data_config import DataConfig as BaseDataConfig from clinicadl.network.config import NetworkConfig as BaseNetworkConfig +from clinicadl.splitter.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import RegressionLoss, RegressionMetric, Task -from clinicadl.validation.validation import ValidationConfig as BaseValidationConfig logger = getLogger("clinicadl.reconstruction_config") logger = getLogger("clinicadl.regression_config") @@ -53,7 +53,7 @@ class RegressionConfig(TrainConfig): The user must specified at least the following arguments: - caps_directory - preprocessing_json - - tsv_directory + - tsv_path - output_maps_directory """ diff --git a/clinicadl/trainer/config/train.py b/clinicadl/trainer/config/train.py index 3a3d791a8..c44febe6b 100644 --- a/clinicadl/trainer/config/train.py +++ b/clinicadl/trainer/config/train.py @@ -19,13 +19,13 @@ from clinicadl.network.config import NetworkConfig from clinicadl.optimizer.optimization import OptimizationConfig from clinicadl.optimizer.optimizer import OptimizerConfig +from clinicadl.splitter.config import SplitConfig +from clinicadl.splitter.validation import ValidationConfig from clinicadl.trainer.transfer_learning import TransferLearningConfig from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.early_stopping.config import EarlyStoppingConfig from clinicadl.utils.enum import Task -from clinicadl.validation.cross_validation import CrossValidationConfig -from clinicadl.validation.validation import ValidationConfig logger = getLogger("clinicadl.training_config") @@ -40,7 +40,6 @@ class TrainConfig(BaseModel, ABC): callbacks: CallbacksConfig computational: ComputationalConfig - cross_validation: CrossValidationConfig data: DataConfig dataloader: DataLoaderConfig early_stopping: EarlyStoppingConfig @@ -50,6 +49,7 @@ class TrainConfig(BaseModel, ABC): optimization: OptimizationConfig optimizer: OptimizerConfig reproducibility: ReproducibilityConfig + split: SplitConfig ssda: SSDAConfig transfer_learning: TransferLearningConfig transforms: TransformsConfig @@ -67,7 +67,6 @@ def __init__(self, **kwargs): super().__init__( callbacks=kwargs, computational=kwargs, - cross_validation=kwargs, data=kwargs, dataloader=kwargs, early_stopping=kwargs, @@ -77,6 +76,7 @@ def __init__(self, **kwargs): optimization=kwargs, optimizer=kwargs, reproducibility=kwargs, + split=kwargs, ssda=kwargs, transfer_learning=kwargs, transforms=kwargs, @@ -87,7 +87,6 @@ def _update(self, config_dict: Dict[str, Any]) -> None: """Updates the configs with a dict given by the user.""" self.callbacks.__dict__.update(config_dict) self.computational.__dict__.update(config_dict) - self.cross_validation.__dict__.update(config_dict) self.data.__dict__.update(config_dict) self.dataloader.__dict__.update(config_dict) self.early_stopping.__dict__.update(config_dict) @@ -97,6 +96,7 @@ def 
_update(self, config_dict: Dict[str, Any]) -> None: self.optimization.__dict__.update(config_dict) self.optimizer.__dict__.update(config_dict) self.reproducibility.__dict__.update(config_dict) + self.split.__dict__.update(config_dict) self.ssda.__dict__.update(config_dict) self.transfer_learning.__dict__.update(config_dict) self.transforms.__dict__.update(config_dict) @@ -116,3 +116,16 @@ def update_with_toml(self, path: Union[str, Path]) -> None: path = Path(path) config_dict = extract_config_from_toml_file(path, self.network_task) self._update(config_dict) + + def get_dict(self): + out_dict = {} + + def get_full_dict(input_dict_: dict, output_dict: dict): + for key, value in input_dict_.items(): + if isinstance(value, dict): + get_full_dict(value, output_dict=output_dict) + else: + output_dict[key] = value + return output_dict + + return get_full_dict(self.model_dump(), out_dict) diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index b0a07bc8b..66ceb0dd1 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -5,13 +5,13 @@ from datetime import datetime from logging import getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Callable import pandas as pd import torch import torch.distributed as dist -from torch.amp import GradScaler -from torch.amp import autocast +from torch.amp.grad_scaler import GradScaler +from torch.amp.autocast_mode import autocast from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -34,7 +34,8 @@ ) from clinicadl.trainer.tasks_utils import create_training_config from clinicadl.validator.validator import Validator -from clinicadl.splitter.split_utils import init_split_manager +from clinicadl.splitter.splitter import Splitter +from clinicadl.splitter.config import SplitterConfig from clinicadl.transforms.config import TransformsConfig if TYPE_CHECKING: @@ -155,23 +156,13 @@ def resume(self, splits: List[int]) -> None: splits : List[int] The splits that must be resumed. """ - stopped_splits = set( - find_stopped_splits( - self.maps_manager.maps_path, self.maps_manager.split_name - ) - ) - finished_splits = set( - find_finished_splits( - self.maps_manager.maps_path, self.maps_manager.split_name - ) - ) + stopped_splits = set(find_stopped_splits(self.config.maps_manager.maps_dir)) + finished_splits = set(find_finished_splits(self.maps_manager.maps_path)) # TODO : check these two lines. Why do we need a split_manager? 
- # split_manager = init_split_manager( - # validation=self.maps_manager.validation, - # parameters=self.config.model_dump(), - # split_list=splits, - # ) - split_manager = self.maps_manager._init_split_manager(split_list=splits) + + splitter_config = SplitterConfig(**self.config.get_dict()) + split_manager = Splitter(splitter_config, split_list=splits) + split_iterator = split_manager.split_iterator() ### absent_splits = set(split_iterator) - stopped_splits - finished_splits @@ -226,10 +217,9 @@ def train( self._train_ssda(split_list, resume=False) else: - split_manager = self.maps_manager._init_split_manager(split_list) - # split_manager = init_split_manager( - # self.maps_manager.validation, self.config.model_dump(), split_list - # ) + splitter_config = SplitterConfig(**self.config.get_dict()) + split_manager = Splitter(splitter_config, split_list=split_list) + for split in split_manager.split_iterator(): logger.info(f"Training split {split}") seed_everything( @@ -251,14 +241,10 @@ def train( def check_split_list(self, split_list, overwrite): existing_splits = [] - split_manager = self.maps_manager._init_split_manager(split_list) - # split_manager = init_split_manager( - # self.maps_manager.validation, self.config.model_dump(), split_list - # ) + splitter_config = SplitterConfig(**self.config.get_dict()) + split_manager = Splitter(splitter_config, split_list=split_list) for split in split_manager.split_iterator(): - split_path = ( - self.maps_manager.maps_path / f"{self.maps_manager.split_name}-{split}" - ) + split_path = self.maps_manager.maps_path / f"split-{split}" if split_path.is_dir(): if overwrite: if cluster.master: @@ -292,16 +278,10 @@ def _resume( If splits specified in input do not exist. """ missing_splits = [] - # split_manager = init_split_manager( - # self.maps_manager.validation, self.config.model_dump(), split_list - # ) - split_manager = self.maps_manager._init_split_manager(split_list) + splitter_config = SplitterConfig(**self.config.get_dict()) + split_manager = Splitter(splitter_config, split_list=split_list) for split in split_manager.split_iterator(): - if not ( - self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" - / "tmp" - ).is_dir(): + if not (self.maps_manager.maps_path / f"split-{split}" / "tmp").is_dir(): missing_splits.append(split) if len(missing_splits) > 0: @@ -331,16 +311,14 @@ def _resume( else: self._train_single(split, split_df_dict, resume=True) - def init_first_network(self, resume, split): + def init_first_network(self, resume: bool, split: int): first_network = 0 if resume: training_logs = [ - int(network_folder.split("-")[1]) + int(str(network_folder).split("-")[1]) for network_folder in list( ( - self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" - / "training_logs" + self.maps_manager.maps_path / f"split-{split}" / "training_logs" ).iterdir() ) ] @@ -352,40 +330,29 @@ def init_first_network(self, resume, split): def get_dataloader( self, - input_dir: Path, data_df: pd.DataFrame, - preprocessing_dict: Dict[str, Any], - transforms_config: TransformsConfig, - label: Optional[str] = None, - label_code: Optional[Dict[str, int]] = None, cnn_index: Optional[int] = None, - label_presence: bool = True, - multi_cohort: bool = False, - network_task: Union[str, Task] = "classification", sampler_option: str = "random", - n_bins: int = 5, dp_degree: Optional[int] = None, rank: Optional[int] = None, - batch_size: Optional[int] = None, - n_proc: Optional[int] = None, - worker_init_fn: Optional[function] = None, 
+ worker_init_fn: Optional[Callable[[int], None]] = None, shuffle: Optional[bool] = None, num_replicas: Optional[int] = None, homemade_sampler: bool = False, ): dataset = return_dataset( - input_dir=input_dir, + input_dir=self.config.data.caps_directory, data_df=data_df, - preprocessing_dict=preprocessing_dict, - transforms_config=transforms_config, - multi_cohort=multi_cohort, - label=label, - label_code=label_code, + preprocessing_dict=self.config.data.preprocessing_dict, + transforms_config=self.config.transforms, + multi_cohort=self.config.data.multi_cohort, + label=self.config.data.label, + label_code=self.maps_manager.label_code, cnn_index=cnn_index, ) if homemade_sampler: sampler = generate_sampler( - network_task=network_task, + network_task=self.maps_manager.network_task, dataset=dataset, sampler_option=sampler_option, dp_degree=dp_degree, @@ -401,9 +368,9 @@ def get_dataloader( train_loader = DataLoader( dataset=dataset, - batch_size=batch_size, + batch_size=self.config.dataloader.batch_size, sampler=sampler, - num_workers=n_proc, + num_workers=self.config.dataloader.n_proc, worker_init_fn=worker_init_fn, shuffle=shuffle, ) @@ -433,20 +400,11 @@ def _train_single( logger.debug("Loading training data...") train_loader = self.get_dataloader( - input_dir=self.config.data.caps_directory, data_df=split_df_dict["train"], - preprocessing_dict=self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, cnn_index=network, - network_task=self.maps_manager.network_task, sampler_option=self.config.dataloader.sampler, - dp_degree=cluster.world_size, - rank=cluster.rank, - batch_size=self.config.dataloader.batch_size, - n_proc=self.config.dataloader.n_proc, + dp_degree=cluster.world_size, # type: ignore + rank=cluster.rank, # type: ignore worker_init_fn=pl_worker_init_function, homemade_sampler=True, ) @@ -455,19 +413,10 @@ def _train_single( logger.debug("Loading validation data...") valid_loader = self.get_dataloader( - input_dir=self.config.data.caps_directory, data_df=split_df_dict["validation"], - preprocessing_dict=self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, cnn_index=network, - network_task=self.maps_manager.network_task, - num_replicas=cluster.world_size, - rank=cluster.rank, - batch_size=self.config.dataloader.batch_size, - n_proc=self.config.dataloader.n_proc, + num_replicas=cluster.world_size, # type: ignore + rank=cluster.rank, # type: ignore shuffle=False, homemade_sampler=False, ) @@ -501,7 +450,7 @@ def _train_single( self.config.validation.selection_metrics, ) - self._erase_tmp(split) + self.maps_manager._erase_tmp(split) def _train_ssda( self, @@ -520,10 +469,10 @@ def _train_ssda( If True, the job is resumed from checkpoint. 
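+
+        Notes
+        -----
+        The source- and target-domain split managers below are now two
+        Splitter instances built from a single SplitterConfig.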
""" - split_manager = self.maps_manager._init_split_manager(split_list) - split_manager_target_lab = self.maps_manager._init_split_manager( - split_list, True - ) + splitter_config = SplitterConfig(**self.config.get_dict()) + + split_manager = Splitter(splitter_config, split_list=split_list) + split_manager_target_lab = Splitter(splitter_config, split_list=split_list) for split in split_manager.split_iterator(): logger.info(f"Training split {split}") @@ -713,7 +662,7 @@ def _train_ssda( self.config.validation.selection_metrics, ) - self._erase_tmp(split) + self.maps_manager._erase_tmp(split) def _train( self, @@ -722,7 +671,7 @@ def _train( split: int, network: Optional[int] = None, resume: bool = False, - callbacks: List[Callback] = [], + callbacks: list[Callback] = [], ): """ Core function shared by train and resume. @@ -792,9 +741,10 @@ def _train( beginning_epoch=beginning_epoch, network=network, ) - retain_best = RetainBest( - selection_metrics=list(self.config.validation.selection_metrics) - ) + # retain_best = RetainBest( + # selection_metrics=list(self.config.validation.selection_metrics) + # ) ??? + epoch = beginning_epoch retain_best = RetainBest( @@ -811,9 +761,7 @@ def _train( from torch.optim.lr_scheduler import ReduceLROnPlateau # Initialize the ReduceLROnPlateau scheduler - scheduler = ReduceLROnPlateau( - optimizer, mode="min", factor=0.1, verbose=True - ) + scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1) while epoch < self.config.optimization.epochs and not early_stopping.step( metrics_valid["loss"] @@ -970,14 +918,14 @@ def _train( if cluster.master: # Save checkpoints and best models best_dict = retain_best.step(metrics_valid) - self._write_weights( + self.maps_manager._write_weights( model_weights, best_dict, split, network=network, save_all_models=self.config.reproducibility.save_all_models, ) - self._write_weights( + self.maps_manager._write_weights( optimizer_weights, None, split, @@ -1375,7 +1323,7 @@ def _train_ssdann( # Save checkpoints and best models best_dict = retain_best.step(metrics_valid_target) - self._write_weights( + self.maps_manager._write_weights( { "model": model.state_dict(), "epoch": epoch, @@ -1386,7 +1334,7 @@ def _train_ssdann( network=network, save_all_models=False, ) - self._write_weights( + self.maps_manager._write_weights( { "optimizer": optimizer.state_dict(), # TO MODIFY "epoch": epoch, @@ -1504,7 +1452,7 @@ def _init_optimizer( if resume: checkpoint_path = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / "tmp" / "optimizer.pth.tar" ) @@ -1550,83 +1498,3 @@ def _init_profiler(self) -> torch.profiler.profile: profiler.step = lambda *args, **kwargs: None return profiler - - def _erase_tmp(self, split: int): - """ - Erases checkpoints of the model and optimizer at the end of training. - - Parameters - ---------- - split : int - The split on which the model has been trained. - """ - tmp_path = ( - self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" - / "tmp" - ) - shutil.rmtree(tmp_path) - - def _write_weights( - self, - state: Dict[str, Any], - metrics_dict: Optional[Dict[str, bool]], - split: int, - network: Optional[int] = None, - filename: str = "checkpoint.pth.tar", - save_all_models: bool = False, - ) -> None: - """ - Update checkpoint and save the best model according to a set of - metrics. - - Parameters - ---------- - state : Dict[str, Any] - The state of the training (model weights, epoch, etc.). 
- metrics_dict : Optional[Dict[str, bool]] - The output of RetainBest step. If None, only the checkpoint - is saved. - split : int - The split number. - network : int (optional, default=None) - The network number (multi-network framework). - filename : str (optional, default="checkpoint.pth.tar") - The name of the checkpoint file. - save_all_models : bool (optional, default=False) - Whether to save model weights at every epoch. - If False, only the best model will be saved. - """ - checkpoint_dir = ( - self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" - / "tmp" - ) - checkpoint_dir.mkdir(parents=True, exist_ok=True) - checkpoint_path = checkpoint_dir / filename - torch.save(state, checkpoint_path) - - if save_all_models: - all_models_dir = ( - self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" - / "all_models" - ) - all_models_dir.mkdir(parents=True, exist_ok=True) - torch.save(state, all_models_dir / f"model_epoch_{state['epoch']}.pth.tar") - - best_filename = "model.pth.tar" - if network is not None: - best_filename = f"network-{network}_model.pth.tar" - - # Save model according to several metrics - if metrics_dict is not None: - for metric_name, metric_bool in metrics_dict.items(): - metric_path = ( - self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" - / f"best-{metric_name}" - ) - if metric_bool: - metric_path.mkdir(parents=True, exist_ok=True) - shutil.copyfile(checkpoint_path, metric_path / best_filename) diff --git a/clinicadl/utils/early_stopping/early_stopping.py b/clinicadl/utils/early_stopping/early_stopping.py index 8ea73bd5e..73a2b67cf 100644 --- a/clinicadl/utils/early_stopping/early_stopping.py +++ b/clinicadl/utils/early_stopping/early_stopping.py @@ -1,5 +1,5 @@ class EarlyStopping(object): - def __init__(self, mode="min", min_delta=0, patience=10): + def __init__(self, mode: str = "min", min_delta: float = 0, patience: int = 10): self.mode = mode self.min_delta = min_delta self.patience = patience diff --git a/clinicadl/utils/iotools/trainer_utils.py b/clinicadl/utils/iotools/trainer_utils.py index 58c8103ef..b77229ea6 100644 --- a/clinicadl/utils/iotools/trainer_utils.py +++ b/clinicadl/utils/iotools/trainer_utils.py @@ -22,8 +22,9 @@ def create_parameters_dict(config): parameters["preprocessing_dict_target"] = parameters["preprocessing_json_target"] del parameters["preprocessing_json_target"] del parameters["preprocessing_json"] - parameters["tsv_path"] = parameters["tsv_directory"] - del parameters["tsv_directory"] + # if "tsv_path" in parameters: + # parameters["tsv_path"] = parameters["tsv_path"] + # del parameters["tsv_path"] parameters["compensation"] = parameters["compensation"].value parameters["size_reduction_factor"] = parameters["size_reduction_factor"].value if parameters["track_exp"]: @@ -72,7 +73,7 @@ def create_parameters_dict(config): def patch_to_read_json(config_dict): - config_dict["tsv_directory"] = config_dict["tsv_path"] + config_dict["tsv_path"] = config_dict["tsv_path"] if ("track_exp" in config_dict) and (config_dict["track_exp"] == ""): config_dict["track_exp"] = None config_dict["preprocessing_json"] = config_dict["preprocessing_dict"][ diff --git a/clinicadl/utils/meta_maps/getter.py b/clinicadl/utils/meta_maps/getter.py index 6ea53eac9..1fa524950 100644 --- a/clinicadl/utils/meta_maps/getter.py +++ b/clinicadl/utils/meta_maps/getter.py @@ -36,20 +36,18 @@ def meta_maps_analysis(launch_dir: Path, evaluation_metric="loss"): for job in jobs_list: performances_dict[job] = 
dict() maps_manager = MapsManager(launch_dir / job) - split_list = find_splits(maps_manager.maps_path, maps_manager.split_name) + split_list = find_splits(maps_manager.maps_path) split_set = split_set | set(split_list) for split in split_set: performances_dict[job][split] = dict() selection_metrics = find_selection_metrics( maps_manager.maps_path, - maps_manager.split_name, split, ) selection_set = selection_set | set(selection_metrics) for metric in selection_metrics: validation_metrics = get_metrics( maps_manager.maps_path, - maps_manager.split_name, "validation", split, metric, diff --git a/clinicadl/validation/cross_validation.py b/clinicadl/validation/cross_validation.py deleted file mode 100644 index 96fdfe1b9..000000000 --- a/clinicadl/validation/cross_validation.py +++ /dev/null @@ -1,41 +0,0 @@ -from logging import getLogger -from pathlib import Path -from typing import Optional, Tuple - -from pydantic import BaseModel, ConfigDict, field_validator -from pydantic.types import NonNegativeInt - -# from clinicadl.maps_manager.maps_manager import MapsManager -from clinicadl.splitter.split_utils import find_splits - -logger = getLogger("clinicadl.cross_validation_config") - - -class CrossValidationConfig( - BaseModel -): # TODO : put in data/cross-validation/splitter module - """ - Config class to configure the cross validation procedure. - - tsv_directory is an argument that must be passed by the user. - """ - - n_splits: NonNegativeInt = 0 - split: Optional[Tuple[NonNegativeInt, ...]] = None - tsv_directory: Optional[Path] = None # not needed in predict ? - # pydantic config - model_config = ConfigDict(validate_assignment=True) - - @field_validator("split", mode="before") - def validator_split(cls, v): - if isinstance(v, list): - return tuple(v) - return v # TODO : check that split exists (and check coherence with n_splits) - - def adapt_cross_val_with_maps_manager_info( - self, maps_manager - ): # maps_manager is of type MapsManager but need to be in a MapsConfig type in the future - # TEMPORARY - if not self.split: - self.split = find_splits(maps_manager.maps_path, maps_manager.split_name) - logger.debug(f"List of splits {self.split}") diff --git a/clinicadl/validation/split_manager/__init__.py b/clinicadl/validation/split_manager/__init__.py deleted file mode 100644 index 616807755..000000000 --- a/clinicadl/validation/split_manager/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .kfold import KFoldSplit -from .single_split import SingleSplit diff --git a/clinicadl/validation/split_manager/kfold.py b/clinicadl/validation/split_manager/kfold.py deleted file mode 100644 index a3c26baaa..000000000 --- a/clinicadl/validation/split_manager/kfold.py +++ /dev/null @@ -1,52 +0,0 @@ -from pathlib import Path - -from clinicadl.validation.split_manager.split_manager import SplitManager - - -class KFoldSplit(SplitManager): - def __init__( - self, - caps_directory, - tsv_path, - diagnoses, - n_splits, - baseline=False, - valid_longitudinal=False, - multi_cohort=False, - split_list=None, - ): - super().__init__( - caps_directory, - tsv_path, - diagnoses, - baseline, - valid_longitudinal, - multi_cohort, - split_list, - ) - self.n_splits = n_splits - - def max_length(self) -> int: - return self.n_splits - - def __len__(self): - if not self.split_list: - return self.n_splits - else: - return len(self.split_list) - - @property - def allowed_splits_list(self): - return [i for i in range(self.n_splits)] - - def split_iterator(self): - if not self.split_list: - return range(self.n_splits) - else: - 
return self.split_list - - def _get_tsv_paths(self, cohort_path: Path, *args): - for split in args: - train_path = cohort_path / f"split-{split}" - valid_path = cohort_path / f"split-{split}" - return train_path, valid_path diff --git a/clinicadl/validation/split_manager/single_split.py b/clinicadl/validation/split_manager/single_split.py deleted file mode 100644 index 6ff282bb2..000000000 --- a/clinicadl/validation/split_manager/single_split.py +++ /dev/null @@ -1,44 +0,0 @@ -from pathlib import Path - -from clinicadl.validation.split_manager.split_manager import SplitManager - - -class SingleSplit(SplitManager): - def __init__( - self, - caps_directory, - tsv_path, - diagnoses, - baseline=False, - valid_longitudinal=False, - multi_cohort=False, - split_list=None, - ): - super().__init__( - caps_directory, - tsv_path, - diagnoses, - baseline, - valid_longitudinal, - multi_cohort, - split_list, - ) - - def max_length(self) -> int: - return 1 - - def __len__(self): - return 1 - - @property - def allowed_splits_list(self): - return [0] - - def split_iterator(self): - return range(1) - - def _get_tsv_paths(self, cohort_path: Path, *args): - train_path = cohort_path - valid_path = cohort_path - - return train_path, valid_path diff --git a/clinicadl/validation/split_manager/split_manager.py b/clinicadl/validation/split_manager/split_manager.py deleted file mode 100644 index 5696e1571..000000000 --- a/clinicadl/validation/split_manager/split_manager.py +++ /dev/null @@ -1,284 +0,0 @@ -import abc -from logging import getLogger -from pathlib import Path - -import pandas as pd - -from clinicadl.utils.exceptions import ( - ClinicaDLArgumentError, - ClinicaDLConfigurationError, - ClinicaDLTSVError, -) -from clinicadl.utils.iotools.clinica_utils import check_caps_folder - -logger = getLogger("clinicadl.split_manager") - - -class SplitManager: - def __init__( - self, - caps_directory: Path, - tsv_path: Path, - diagnoses, - baseline=False, - valid_longitudinal=False, - multi_cohort=False, - split_list=None, - ): - """ - - Parameters - ---------- - caps_director: str (path) - Path to the caps directory - tsv_path: str - Path to the tsv that is going to be split - diagonoses: List[str] - List of diagnosis - baseline: bool - if True, split only on baseline sessions - valid_longitudinal: bool - if True, split validation on longitudinal sessions - multi-cohort: bool - split_list: List[str] - - """ - self._check_tsv_path(tsv_path, multi_cohort) - self.tsv_path = tsv_path - self.caps_dict = self._create_caps_dict(caps_directory, multi_cohort) - self.multi_cohort = multi_cohort - self.diagnoses = diagnoses - self.baseline = baseline - self.valid_longitudinal = valid_longitudinal - self.split_list = split_list - - @abc.abstractmethod - def max_length(self) -> int: - """Maximum number of splits""" - pass - - @abc.abstractmethod - def __len__(self): - pass - - @property - @abc.abstractmethod - def allowed_splits_list(self): - """ - List of possible splits if no restriction was applied - - Returns: - list[int]: list of all possible splits - """ - pass - - def __getitem__(self, item): - """ - Returns a dictionary of DataFrames with train and validation data. - - Args: - item (int): Index of the split wanted. - Returns: - Dict[str:pd.DataFrame]: dictionary with two keys (train and validation). 
- """ - self._check_item(item) - - if self.multi_cohort: - tsv_df = pd.read_csv(self.tsv_path, sep="\t") - train_df = pd.DataFrame() - valid_df = pd.DataFrame() - found_diagnoses = set() - for idx in range(len(tsv_df)): - cohort_name = tsv_df.loc[idx, "cohort"] - cohort_path = Path(tsv_df.loc[idx, "path"]) - cohort_diagnoses = ( - tsv_df.loc[idx, "diagnoses"].replace(" ", "").split(",") - ) - if bool(set(cohort_diagnoses) & set(self.diagnoses)): - target_diagnoses = list(set(cohort_diagnoses) & set(self.diagnoses)) - - cohort_train_df, cohort_valid_df = self.concatenate_diagnoses( - item, cohort_path=cohort_path, cohort_diagnoses=target_diagnoses - ) - cohort_train_df["cohort"] = cohort_name - cohort_valid_df["cohort"] = cohort_name - train_df = pd.concat([train_df, cohort_train_df]) - valid_df = pd.concat([valid_df, cohort_valid_df]) - found_diagnoses = found_diagnoses | ( - set(cohort_diagnoses) & set(self.diagnoses) - ) - - if found_diagnoses != set(self.diagnoses): - raise ValueError( - f"The diagnoses found in the multi cohort dataset {found_diagnoses} " - f"do not correspond to the diagnoses wanted {set(self.diagnoses)}." - ) - train_df.reset_index(inplace=True, drop=True) - valid_df.reset_index(inplace=True, drop=True) - else: - train_df, valid_df = self.concatenate_diagnoses(item) - train_df["cohort"] = "single" - valid_df["cohort"] = "single" - - return { - "train": train_df, - "validation": valid_df, - } - - def concatenate_diagnoses( - self, split, cohort_path: Path = None, cohort_diagnoses=None - ): - """Concatenated the diagnoses needed to form the train and validation sets.""" - tmp_cohort_path = cohort_path if cohort_path is not None else self.tsv_path - train_path, valid_path = self._get_tsv_paths( - tmp_cohort_path, - split, - ) - logger.debug(f"Training data loaded at {train_path}") - logger.debug(f"Validation data loaded at {valid_path}") - if cohort_diagnoses is None: - cohort_diagnoses = self.diagnoses - if self.baseline: - train_path = train_path / "train_baseline.tsv" - else: - train_path = train_path / "train.tsv" - if self.valid_longitudinal: - valid_path = valid_path / "validation.tsv" - else: - valid_path = valid_path / "validation_baseline.tsv" - - train_df = pd.read_csv(train_path, sep="\t") - valid_df = pd.read_csv(valid_path, sep="\t") - - list_columns = train_df.columns.values - - if ( - "diagnosis" not in list_columns - # or "age" not in list_columns - # or "sex" not in list_columns - ): - parents_path = train_path.resolve().parent - while ( - not (parents_path / "labels.tsv").is_file() - and ((parents_path / "kfold.json").is_file()) - or (parents_path / "split.json").is_file() - ): - parents_path = parents_path.parent - try: - labels_df = pd.read_csv(parents_path / "labels.tsv", sep="\t") - train_df = pd.merge( - train_df, - labels_df, - how="inner", - on=["participant_id", "session_id"], - ) - except Exception: - pass - - list_columns = valid_df.columns.values - if ( - "diagnosis" not in list_columns - # or "age" not in list_columns - # or "sex" not in list_columns - ): - parents_path = valid_path.resolve().parent - while ( - not (parents_path / "labels.tsv").is_file() - and ((parents_path / "kfold.json").is_file()) - or (parents_path / "split.json").is_file() - ): - parents_path = parents_path.parent - try: - labels_df = pd.read_csv(parents_path / "labels.tsv", sep="\t") - valid_df = pd.merge( - valid_df, - labels_df, - how="inner", - on=["participant_id", "session_id"], - ) - except Exception: - pass - train_df = train_df[ - 
train_df.diagnosis.isin(cohort_diagnoses) - ] # TO MODIFY with train - valid_df = valid_df[valid_df.diagnosis.isin(cohort_diagnoses)] - - train_df.reset_index(inplace=True, drop=True) - valid_df.reset_index(inplace=True, drop=True) - - return train_df, valid_df - - @abc.abstractmethod - def _get_tsv_paths(self, cohort_path, *args): - """ - Computes the paths to the TSV files needed depending on the split structure. - - Args: - cohort_path (str): path to the split structure of a cohort. - split (int): Index of the split. - Returns: - train_path (str): path to the directory containing training data. - valid_path (str): path to the directory containing validation data. - """ - pass - - @abc.abstractmethod - def split_iterator(self): - """Returns an iterable to iterate on all splits wanted.""" - pass - - def _check_item(self, item): - if item not in self.allowed_splits_list: - raise IndexError( - f"Split index {item} out of allowed splits {self.allowed_splits_list}." - ) - - @staticmethod - def _create_caps_dict(caps_directory: Path, multi_cohort): - if multi_cohort: - if not caps_directory.suffix == ".tsv": - raise ClinicaDLArgumentError( - "If multi_cohort is given, the CAPS_DIRECTORY argument should be a path to a TSV file." - ) - else: - caps_df = pd.read_csv(caps_directory, sep="\t") - SplitManager._check_multi_cohort_tsv(caps_df, "CAPS") - caps_dict = dict() - for idx in range(len(caps_df)): - cohort = caps_df.loc[idx, "cohort"] - caps_path = caps_df.loc[idx, "path"] - check_caps_folder(caps_path) - caps_dict[cohort] = caps_path - else: - check_caps_folder(caps_directory) - caps_dict = {"single": caps_directory} - - return caps_dict - - @staticmethod - def _check_tsv_path(tsv_path, multi_cohort): - if multi_cohort: - if tsv_path.suffix != ".tsv": - raise ClinicaDLArgumentError( - "If multi_cohort is given, the TSV_DIRECTORY argument should be a path to a TSV file." - ) - else: - tsv_df = pd.read_csv(tsv_path, sep="\t") - SplitManager._check_multi_cohort_tsv(tsv_df, "labels") - else: - if tsv_path.suffix == ".tsv": - raise ClinicaDLConfigurationError( - f"You gave the path to a TSV file in tsv_path {tsv_path}. " - f"To use multi-cohort framework, please add 'multi_cohort=true' to the configuration file or the --multi_cohort flag." - ) - - @staticmethod - def _check_multi_cohort_tsv(tsv_df, purpose): - if purpose.upper() == "CAPS": - mandatory_col = {"cohort", "path"} - else: - mandatory_col = {"cohort", "path", "diagnoses"} - if not mandatory_col.issubset(tsv_df.columns.values): - raise ClinicaDLTSVError( - f"Columns of the TSV file used for {purpose} location must include {mandatory_col}." 
- ) diff --git a/clinicadl/validator/config.py b/clinicadl/validator/config.py index 165b36dd0..2f8c8a30a 100644 --- a/clinicadl/validator/config.py +++ b/clinicadl/validator/config.py @@ -18,7 +18,6 @@ class ValidatorConfig(BaseModel): maps_path: Path mode: str network_task: str - split_name: Optional[str] = None num_networks: Optional[int] = None fsdp: Optional[bool] = None amp: Optional[bool] = None diff --git a/clinicadl/validator/validator.py b/clinicadl/validator/validator.py index d55810299..c8f5e9451 100644 --- a/clinicadl/validator/validator.py +++ b/clinicadl/validator/validator.py @@ -249,7 +249,7 @@ def _test_loader( if cluster.master: log_dir = ( maps_manager.maps_path - / f"{maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / data_group ) @@ -339,7 +339,7 @@ def _test_loader_ssda( for selection_metric in selection_metrics: log_dir = ( maps_manager.maps_path - / f"{maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / data_group ) @@ -430,7 +430,7 @@ def _compute_output_tensors( tensor_path = ( maps_manager.maps_path - / f"{maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / data_group / "tensors" @@ -475,9 +475,7 @@ def _ensemble_prediction( """Computes the results on the image-level.""" if not selection_metrics: - selection_metrics = find_selection_metrics( - maps_manager.maps_path, maps_manager.split_name, split - ) + selection_metrics = find_selection_metrics(maps_manager.maps_path, split) for selection_metric in selection_metrics: ##################### diff --git a/tests/test_predict.py b/tests/test_predict.py index 2c26a4a3e..849f0e20d 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -119,14 +119,12 @@ def test_predict(cmdopt, tmp_path, test_name): for mode in modes: get_prediction( predict_manager.maps_manager.maps_path, - predict_manager.maps_manager.split_name, data_group="test-RANDOM", mode=mode, ) if use_labels: get_metrics( predict_manager.maps_manager.maps_path, - predict_manager.maps_manager.split_name, data_group="test-RANDOM", mode=mode, ) diff --git a/tests/test_resume.py b/tests/test_resume.py index 9fde97a45..1598267d8 100644 --- a/tests/test_resume.py +++ b/tests/test_resume.py @@ -7,6 +7,8 @@ import pytest from clinicadl.maps_manager.maps_manager import MapsManager +from clinicadl.splitter.config import SplitterConfig +from clinicadl.splitter.splitter import Splitter from .testing_tools import modify_maps @@ -48,7 +50,9 @@ def test_resume(cmdopt, tmp_path, test_name): assert flag_error maps_manager = MapsManager(maps_stopped) - split_manager = maps_manager._init_split_manager() + splitter_config = SplitterConfig(**maps_manager.parameters) + split_manager = Splitter(splitter_config) + for split in split_manager.split_iterator(): performances_flag = ( maps_stopped / f"split-{split}" / "best-loss" / "train" diff --git a/tests/unittests/random_search/test_random_search_config.py b/tests/unittests/random_search/test_random_search_config.py index 87c89dadf..f2e195309 100644 --- a/tests/unittests/random_search/test_random_search_config.py +++ b/tests/unittests/random_search/test_random_search_config.py @@ -32,7 +32,7 @@ def dummy_arguments(caps_example): args = { "caps_directory": caps_example, "preprocessing_json": "preprocessing.json", - "tsv_directory": "", + "tsv_path": "", "maps_dir": "", "gpu": False, } diff --git a/tests/unittests/train/tasks/classification/test_classification_config.py 
b/tests/unittests/train/tasks/classification/test_classification_config.py index 2bfbb9fa6..71e853872 100644 --- a/tests/unittests/train/tasks/classification/test_classification_config.py +++ b/tests/unittests/train/tasks/classification/test_classification_config.py @@ -35,7 +35,7 @@ def dummy_arguments(caps_example): args = { "caps_directory": caps_example, "preprocessing_json": "preprocessing.json", - "tsv_directory": "", + "tsv_path": "", "maps_dir": "", "gpu": False, } diff --git a/tests/unittests/train/tasks/reconstruction/test_reconstruction_config.py b/tests/unittests/train/tasks/reconstruction/test_reconstruction_config.py index 1ff6283c3..f013386c1 100644 --- a/tests/unittests/train/tasks/reconstruction/test_reconstruction_config.py +++ b/tests/unittests/train/tasks/reconstruction/test_reconstruction_config.py @@ -24,7 +24,7 @@ def dummy_arguments(caps_example): args = { "caps_directory": caps_example, "preprocessing_json": "preprocessing.json", - "tsv_directory": "", + "tsv_path": "", "maps_dir": "", "gpu": False, } diff --git a/tests/unittests/train/tasks/regression/test_regression_config.py b/tests/unittests/train/tasks/regression/test_regression_config.py index 3085755b0..4b01e1084 100644 --- a/tests/unittests/train/tasks/regression/test_regression_config.py +++ b/tests/unittests/train/tasks/regression/test_regression_config.py @@ -24,7 +24,7 @@ def dummy_arguments(caps_example): args = { "caps_directory": caps_example, "preprocessing_json": "preprocessing.json", - "tsv_directory": "", + "tsv_path": "", "maps_dir": "", "gpu": False, } diff --git a/tests/unittests/train/trainer/test_training_config.py b/tests/unittests/train/trainer/test_training_config.py index 158a6d6c2..503b88ddf 100644 --- a/tests/unittests/train/trainer/test_training_config.py +++ b/tests/unittests/train/trainer/test_training_config.py @@ -7,9 +7,9 @@ from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig from clinicadl.config.config.ssda import SSDAConfig from clinicadl.network.config import NetworkConfig +from clinicadl.splitter.validation import ValidationConfig from clinicadl.trainer.transfer_learning import TransferLearningConfig from clinicadl.transforms.config import TransformsConfig -from clinicadl.validation.cross_validation import CrossValidationConfig # Tests for customed validators # @@ -19,12 +19,12 @@ def caps_example(): return dir_ -def test_cross_validation_config(): - c = CrossValidationConfig( - split=[0], - tsv_directory="", - ) - assert c.split == (0,) +# def test_cross_validation_config(): +# c = ValidationConfig( +# split=[0], +# tsv_path="", +# ) +# assert c.split == (0,) def test_data_config(caps_example): @@ -113,7 +113,7 @@ def dummy_arguments(caps_example): args = { "caps_directory": caps_example, "preprocessing_json": "preprocessing.json", - "tsv_directory": "", + "tsv_path": "", "maps_dir": "", "gpu": False, "architecture": "", @@ -181,14 +181,14 @@ def test_fails_validations(bad_inputs, training_config): def test_passes_validations(good_inputs, training_config): c = training_config(**good_inputs) assert not c.computational.gpu - assert c.cross_validation.n_splits == 7 + assert c.split.n_splits == 7 assert c.optimizer.optimizer == "Adagrad" assert c.transforms.data_augmentation == ("Smoothing",) assert c.data.diagnoses == ("AD",) assert c.dataloader.batch_size == 1 assert c.transforms.size_reduction_factor == 5 assert c.optimizer.learning_rate == 1e-1 - assert c.cross_validation.split == (0,) + assert c.split.split == (0,) assert c.early_stopping.tolerance == 0.0
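
Notes for reviewers. The sketches below reconstruct the refactored API from the call sites in this patch; they are illustrations under stated assumptions, not code copied from the repository.

How the new split entry point composes, following the test_resume.py and Trainer hunks above. Only split_iterator() appears in the patch, so keeping SplitManager's old indexing contract is an assumption:

from pathlib import Path

from clinicadl.maps_manager.maps_manager import MapsManager
from clinicadl.splitter.config import SplitterConfig
from clinicadl.splitter.splitter import Splitter

maps_path = Path("maps")  # placeholder: an existing MAPS directory

# Rebuild the split configuration from a trained MAPS, as test_resume.py does.
maps_manager = MapsManager(maps_path)
splitter_config = SplitterConfig(**maps_manager.parameters)

# Optionally restrict iteration to given splits, as Trainer.train() does.
splitter = Splitter(splitter_config, split_list=[0, 2])

for split in splitter.split_iterator():
    # Assumed: indexing still returns {"train": df, "validation": df},
    # the contract of the deleted SplitManager.__getitem__ (not shown here).
    split_df_dict = splitter[split]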
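Every path the patch touches now hard-codes the split-<i> prefix in place of the removed {split_name}-<i>. Collected from the hunks above, the per-split MAPS layout is (training_logs folder names of the form network-<n> are inferred from the split("-")[1] parsing in init_first_network):

<maps_path>/split-<i>/tmp/             checkpoint.pth.tar and optimizer.pth.tar,
                                       removed by _erase_tmp when training ends
<maps_path>/split-<i>/training_logs/   one network-<n> folder per network,
                                       scanned on multi-network resume
<maps_path>/split-<i>/all_models/      model_epoch_<e>.pth.tar when
                                       save_all_models is enabled
<maps_path>/split-<i>/best-<metric>/   model.pth.tar (or network-<n>_model.pth.tar),
                                       plus <data_group>/ outputs at evaluation time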
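The deleted KFoldSplit and SingleSplit reduce to two small iterators, which the unified Splitter has to preserve. In pure Python, from the deleted split_iterator bodies:

from typing import Iterable, Optional, Sequence

def kfold_split_iterator(
    n_splits: int, split_list: Optional[Sequence[int]] = None
) -> Iterable[int]:
    # KFoldSplit.split_iterator: every fold, unless an explicit
    # split_list restricts the iteration.
    return split_list if split_list else range(n_splits)

def single_split_iterator() -> Iterable[int]:
    # SingleSplit.split_iterator: exactly one split, index 0.
    return range(1)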
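Likewise, the deleted concatenate_diagnoses fixed the TSV naming scheme inside each split folder. A condensed sketch of that rule (the labels.tsv merge and the multi-cohort branch are elided):

from pathlib import Path

import pandas as pd

def load_split_tsvs(
    split_dir: Path,
    diagnoses: list,
    baseline: bool = False,
    valid_longitudinal: bool = False,
) -> dict:
    # Training sessions: baseline-only, or the full longitudinal set.
    train_tsv = split_dir / ("train_baseline.tsv" if baseline else "train.tsv")
    # Validation sessions: note the inverted default (baseline-only unless
    # valid_longitudinal is requested).
    valid_tsv = split_dir / (
        "validation.tsv" if valid_longitudinal else "validation_baseline.tsv"
    )
    train_df = pd.read_csv(train_tsv, sep="\t")
    valid_df = pd.read_csv(valid_tsv, sep="\t")
    # Restrict both sets to the requested diagnoses, then reindex.
    train_df = train_df[train_df.diagnosis.isin(diagnoses)].reset_index(drop=True)
    valid_df = valid_df[valid_df.diagnosis.isin(diagnoses)].reset_index(drop=True)
    return {"train": train_df, "validation": valid_df}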
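Finally, the multi-cohort TSV contract enforced by the deleted _check_multi_cohort_tsv is unchanged in substance and will presumably resurface in the new splitter or data configs. The rule was:

import pandas as pd

from clinicadl.utils.exceptions import ClinicaDLTSVError

def check_multi_cohort_tsv(tsv_df: pd.DataFrame, purpose: str) -> None:
    # CAPS location files need (cohort, path); label files additionally
    # need a diagnoses column.
    if purpose.upper() == "CAPS":
        mandatory_col = {"cohort", "path"}
    else:
        mandatory_col = {"cohort", "path", "diagnoses"}
    if not mandatory_col.issubset(tsv_df.columns.values):
        raise ClinicaDLTSVError(
            f"Columns of the TSV file used for {purpose} location "
            f"must include {mandatory_col}."
        )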