From 61c00d28de734cd0a34f047305fec7887ff8c3b1 Mon Sep 17 00:00:00 2001
From: thibaultdvx <154365476+thibaultdvx@users.noreply.github.com>
Date: Wed, 24 Apr 2024 17:15:34 +0200
Subject: [PATCH] Trainer config (#561)

* get trainer out of mapsmanager folder

* base training config class

* task specific config classes

* unit test for config classes

* changes in cli to have default values from task config objects

* rename and simplify build_train_dict

* unit test for train_utils

* small modification in training config toml template

* rename build_train_dict in the other parts of the project

* modify task_launcher to use config objects

* Bump sqlparse from 0.4.4 to 0.5.0 (#558)

Bumps [sqlparse](https://github.com/andialbrecht/sqlparse) from 0.4.4 to 0.5.0.
- [Changelog](https://github.com/andialbrecht/sqlparse/blob/master/CHANGELOG)
- [Commits](https://github.com/andialbrecht/sqlparse/compare/0.4.4...0.5.0)

---
updated-dependencies:
- dependency-name: sqlparse
  dependency-type: indirect
...

Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* typo

* change _network_task attribute

* omissions

* patches to match CLI data

* small modifications

* correction

* correction reconstruction default loss

* add architecture command specific to the task

* add use_extracted_features parameter

* add VAE parameters in reconstruction

* add condition on whether cli arg is default or from user

* correct wrong import in resume

* validators on assignment

* reformat

* replace literal with enum

* review on CLI options

* convert enum to str for train function

* correct track exp issue

* test for ci

---------

Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 .../random_search/random_search_utils.py      |   4 +-
 clinicadl/resources/config/train_config.toml  |   2 +-
 clinicadl/train/resume.py                     |   2 +-
 clinicadl/train/tasks/base_training_config.py | 194 ++++++++++
 clinicadl/train/tasks/classification_cli.py   |  33 +-
 .../train/tasks/classification_config.py      |  55 +++
 clinicadl/train/tasks/reconstruction_cli.py   |  22 +-
 .../train/tasks/reconstruction_config.py      |  69 ++++
 clinicadl/train/tasks/regression_cli.py       |  24 +-
 clinicadl/train/tasks/regression_config.py    |  50 +++
 clinicadl/train/tasks/task_utils.py           | 160 ++++-----
 clinicadl/train/train.py                      |   2 +-
 clinicadl/train/train_utils.py                | 120 +++----
 clinicadl/utils/cli_param/train_option.py     | 338 ++++++++++--------
 clinicadl/utils/maps_manager/maps_manager.py  |   6 +
 .../{maps_manager => }/trainer/__init__.py    |   0
 .../{maps_manager => }/trainer/trainer.py     |   2 +-
 clinicadl/utils/trainer/trainer_utils.py      |   0
 clinicadl/utils/trainer/training_config.py    |   0
 .../train/tasks/test_base_training_config.py  |  80 +++++
 .../train/tasks/test_classification_config.py |  62 ++++
 .../train/tasks/test_reconstruction_config.py |  62 ++++
 .../train/tasks/test_regression_config.py     |  54 +++
 tests/unittests/train/test_train_utils.py     | 206 +++++++++++
 24 files changed, 1216 insertions(+), 331 deletions(-)
 create mode 100644 clinicadl/train/tasks/base_training_config.py
 create mode 100644 clinicadl/train/tasks/classification_config.py
 create mode 100644 clinicadl/train/tasks/reconstruction_config.py
 create mode 100644 clinicadl/train/tasks/regression_config.py
 rename clinicadl/utils/{maps_manager => }/trainer/__init__.py (100%)
 rename clinicadl/utils/{maps_manager => }/trainer/trainer.py (99%)
 create mode 100644 clinicadl/utils/trainer/trainer_utils.py
 create mode 100644 clinicadl/utils/trainer/training_config.py
 create mode 100644 tests/unittests/train/tasks/test_base_training_config.py
 create mode 100644 tests/unittests/train/tasks/test_classification_config.py
 create mode 100644 tests/unittests/train/tasks/test_reconstruction_config.py
 create mode 100644 tests/unittests/train/tasks/test_regression_config.py
 create mode 100644 tests/unittests/train/test_train_utils.py
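Before the file-by-file diff, a sketch of the flow this patch introduces, written with placeholder paths and a hand-set override (the real entry points are the per-task CLIs changed below):

from pathlib import Path

from clinicadl.train.tasks.classification_config import ClassificationConfig
from clinicadl.train.tasks.task_utils import task_launcher
from clinicadl.train.train_utils import extract_config_from_toml_file

# 1. Options may come from a user TOML file, checked against the packaged template.
options = extract_config_from_toml_file(Path("my_train_config.toml"), "classification")
# 2. Options typed on the command line override the file (see the CLI modules below).
options["learning_rate"] = 1e-3
# 3. Everything is validated once, up front, by the pydantic config object...
config = ClassificationConfig(
    caps_directory=Path("caps"),
    preprocessing_json=Path("extract.json"),
    tsv_directory=Path("labels"),
    output_maps_directory=Path("maps"),
    **options,
)
# 4. ...and the launcher derives the remaining attributes and starts training.
task_launcher(config)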
diff --git a/clinicadl/random_search/random_search_utils.py b/clinicadl/random_search/random_search_utils.py
index 1d878913c..ea8337c86 100644
--- a/clinicadl/random_search/random_search_utils.py
+++ b/clinicadl/random_search/random_search_utils.py
@@ -4,7 +4,7 @@
 import toml
 
-from clinicadl.train.train_utils import build_train_dict
+from clinicadl.train.train_utils import extract_config_from_toml_file
 from clinicadl.utils.exceptions import ClinicaDLConfigurationError
 from clinicadl.utils.preprocessing import path_decoder, read_preprocessing
 
@@ -49,7 +49,7 @@ def get_space_dict(launch_directory: Path) -> Dict[str, Any]:
     space_dict.setdefault("n_conv", 1)
     space_dict.setdefault("wd_bool", True)
 
-    train_default = build_train_dict(toml_path, space_dict["network_task"])
+    train_default = extract_config_from_toml_file(toml_path, space_dict["network_task"])
 
     # Mode and preprocessing
     preprocessing_json = (
diff --git a/clinicadl/resources/config/train_config.toml b/clinicadl/resources/config/train_config.toml
index e850a14d9..f4f2afe30 100644
--- a/clinicadl/resources/config/train_config.toml
+++ b/clinicadl/resources/config/train_config.toml
@@ -64,7 +64,7 @@ diagnoses = ["AD", "CN"]
 baseline = false
 valid_longitudinal = false
 normalize = true
-data_augmentation = false
+data_augmentation = []
 sampler = "random"
 size_reduction=false
 size_reduction_factor=2
diff --git a/clinicadl/train/resume.py b/clinicadl/train/resume.py
index bfa9a16b7..af2a806c1 100644
--- a/clinicadl/train/resume.py
+++ b/clinicadl/train/resume.py
@@ -7,7 +7,7 @@
 from pathlib import Path
 
 from clinicadl import MapsManager
-from clinicadl.utils.maps_manager.trainer import Trainer
+from clinicadl.utils.trainer import Trainer
 
 
 def replace_arg(options, key_name, value):
diff --git a/clinicadl/train/tasks/base_training_config.py b/clinicadl/train/tasks/base_training_config.py
new file mode 100644
index 000000000..522aaf1a5
--- /dev/null
+++ b/clinicadl/train/tasks/base_training_config.py
@@ -0,0 +1,194 @@
+from enum import Enum
+from logging import getLogger
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator
+
+logger = getLogger("clinicadl.base_training_config")
+
+
+class Compensation(str, Enum):
+    """Available compensations in ClinicaDL."""
+
+    MEMORY = "memory"
+    TIME = "time"
+
+
+class SizeReductionFactor(int, Enum):
+    """Available size reduction factors in ClinicaDL."""
+
+    TWO = 2
+    THREE = 3
+    FOUR = 4
+    FIVE = 5
+
+
+class ExperimentTracking(str, Enum):
+    """Available tools for experiment tracking in ClinicaDL."""
+
+    MLFLOW = "mlflow"
+    WANDB = "wandb"
+
+
+class Sampler(str, Enum):
+    """Available samplers in ClinicaDL."""
+
+    RANDOM = "random"
+    WEIGHTED = "weighted"
+
+
+class Mode(str, Enum):
+    """Available modes in ClinicaDL."""
+
+    IMAGE = "image"
+    PATCH = "patch"
+    ROI = "roi"
+    SLICE = "slice"
+
+
+class BaseTaskConfig(BaseModel):
+    """
+    Base class to handle parameters of the training pipeline.
+    """
+
+    caps_directory: Path
+    preprocessing_json: Path
+    tsv_directory: Path
+    output_maps_directory: Path
+    # Computational
+    gpu: bool = True
+    n_proc: int = 2
+    batch_size: int = 8
+    evaluation_steps: int = 0
+    fully_sharded_data_parallel: bool = False
+    amp: bool = False
+    # Reproducibility
+    seed: int = 0
+    deterministic: bool = False
+    compensation: Compensation = Compensation.MEMORY
+    save_all_models: bool = False
+    track_exp: Optional[ExperimentTracking] = None
+    # Model
+    multi_network: bool = False
+    ssda_network: bool = False
+    # Data
+    multi_cohort: bool = False
+    diagnoses: Tuple[str, ...] = ("AD", "CN")
+    baseline: bool = False
+    valid_longitudinal: bool = False
+    normalize: bool = True
+    data_augmentation: Tuple[str, ...] = ()
+    sampler: Sampler = Sampler.RANDOM
+    size_reduction: bool = False
+    size_reduction_factor: SizeReductionFactor = (
+        SizeReductionFactor.TWO
+    )  # TODO : change to optional and remove size_reduction parameter
+    caps_target: Path = Path("")
+    tsv_target_lab: Path = Path("")
+    tsv_target_unlab: Path = Path("")
+    preprocessing_dict_target: Path = Path(
+        ""
+    )  ## TODO : change name in commandline. preprocessing_json_target?
+    # Cross validation
+    n_splits: int = 0
+    split: Tuple[int, ...] = ()
+    # Optimization
+    optimizer: str = "Adam"
+    epochs: int = 20
+    learning_rate: float = 1e-4
+    adaptive_learning_rate: bool = False
+    weight_decay: float = 1e-4
+    dropout: float = 0.0
+    patience: int = 0
+    tolerance: float = 0.0
+    accumulation_steps: int = 1
+    profiler: bool = False
+    # Transfer Learning
+    transfer_path: Optional[Path] = None
+    transfer_selection_metric: str = "loss"
+    nb_unfrozen_layer: int = 0
+    # Information
+    emissions_calculator: bool = False
+    # Mode
+    use_extracted_features: bool = False  # unused. TODO : remove
+    # Private
+    _preprocessing_dict: Dict[str, Any] = PrivateAttr()
+    _preprocessing_dict_target: Dict[str, Any] = PrivateAttr()
+    _mode: Mode = PrivateAttr()
+
+    # NB: must be a `model_config` attribute; an inner `class ConfigDict` would be ignored by pydantic v2
+    model_config = ConfigDict(validate_assignment=True)
+
+    @field_validator("diagnoses", "split", "data_augmentation", mode="before")
+    def list_to_tuples(cls, v):
+        if isinstance(v, list):
+            return tuple(v)
+        return v
+
+    @field_validator("transfer_path", mode="before")
+    def false_to_none(cls, v):
+        if v is False:
+            return None
+        return v
+
+    @classmethod
+    def get_available_optimizers(cls) -> List[str]:
+        """To get the list of available optimizers."""
+        available_optimizers = [  # TODO : connect to PyTorch to have available optimizers
+            "Adadelta",
+            "Adagrad",
+            "Adam",
+            "AdamW",
+            "Adamax",
+            "ASGD",
+            "NAdam",
+            "RAdam",
+            "RMSprop",
+            "SGD",
+        ]
+        return available_optimizers
+
+    @field_validator("optimizer")
+    def validator_optimizer(cls, v):
+        available_optimizers = cls.get_available_optimizers()
+        assert (
+            v in available_optimizers
+        ), f"Optimizer '{v}' not supported. Please choose among: {available_optimizers}"
+        return v
+
+    @classmethod
+    def get_available_transforms(cls) -> List[str]:
+        """To get the list of available transforms."""
+        available_transforms = [  # TODO : connect to transforms module
+            "Noise",
+            "Erasing",
+            "CropPad",
+            "Smoothing",
+            "Motion",
+            "Ghosting",
+            "Spike",
+            "BiasField",
+            "RandomBlur",
+            "RandomSwap",
+        ]
+        return available_transforms
+
+    @field_validator("data_augmentation", mode="before")
+    def validator_data_augmentation(cls, v):
+        if v is False:
+            return ()
+
+        available_transforms = cls.get_available_transforms()
+        for transform in v:
+            assert (
+                transform in available_transforms
+            ), f"Transform '{transform}' not supported. Please pick among: {available_transforms}"
+        return v
+
+    @field_validator("dropout")
+    def validator_dropout(cls, v):
+        assert (
+            0 <= v <= 1
+        ), f"dropout must be between 0 and 1 but it has been set to {v}."
+        return v
diff --git a/clinicadl/train/tasks/classification_cli.py b/clinicadl/train/tasks/classification_cli.py
index 2f470fd02..b345b3ee7 100644
--- a/clinicadl/train/tasks/classification_cli.py
+++ b/clinicadl/train/tasks/classification_cli.py
@@ -1,7 +1,12 @@
+from pathlib import Path
+
 import click
+from click.core import ParameterSource
 
+from clinicadl.train.train_utils import extract_config_from_toml_file
 from clinicadl.utils.cli_param import train_option
 
+from .classification_config import ClassificationConfig
 from .task_utils import task_launcher
 
 
@@ -26,7 +31,7 @@
 @train_option.compensation
 @train_option.save_all_models
 # Model
-@train_option.architecture
+@train_option.classification_architecture
 @train_option.multi_network
 @train_option.ssda_network
 # Data
@@ -61,8 +66,8 @@
 @train_option.transfer_selection_metric
 @train_option.nb_unfrozen_layer
 # Task-related
-@train_option.label
-@train_option.selection_metrics
+@train_option.classification_label
+@train_option.classification_selection_metrics
 @train_option.selection_threshold
 @train_option.classification_loss
 # information
@@ -84,14 +89,14 @@ def cli(**kwargs):
     configuration file in TOML format. For more details, please visit the documentation:
     https://clinicadl.readthedocs.io/en/stable/Train/Introduction/#configuration-file
     """
-    task_specific_options = [
-        "label",
-        "selection_metrics",
-        "selection_threshold",
-        "loss",
-    ]
-    task_launcher("classification", task_specific_options, **kwargs)
-
-
-if __name__ == "__main__":
-    cli()
+    options = {}
+    if kwargs["config_file"]:
+        options = extract_config_from_toml_file(
+            Path(kwargs["config_file"]),
+            "classification",
+        )
+    for arg in kwargs:
+        if click.get_current_context().get_parameter_source(arg) == ParameterSource.COMMANDLINE:
+            options[arg] = kwargs[arg]
+    config = ClassificationConfig(**options)
+    task_launcher(config)
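The new body of cli() is the changelog entry "add condition on whether cli arg is default or from user" in action: a value from the TOML file is only overridden when click reports the parameter as coming from the command line. The mechanism in isolation (hypothetical standalone command, not part of ClinicaDL):

import click
from click.core import ParameterSource


@click.command()
@click.option("--learning_rate", default=1e-4)
def train(**kwargs):
    options = {"learning_rate": 1e-3}  # pretend this was read from a config file
    ctx = click.get_current_context()
    for arg in kwargs:
        # Click defaults lose against the config file; explicit CLI values win.
        if ctx.get_parameter_source(arg) == ParameterSource.COMMANDLINE:
            options[arg] = kwargs[arg]
    click.echo(options)


if __name__ == "__main__":
    train()

Run without arguments, the sketch keeps 1e-3 from the "file"; run with --learning_rate 1e-5, the command-line value wins. The same three-way priority (CLI > config file > config-class default) holds for the two CLIs below.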
= ("loss",) + # private + _network_task: str = PrivateAttr(default="classification") + + @field_validator("selection_metrics", mode="before") + def list_to_tuples(cls, v): + if isinstance(v, list): + return tuple(v) + return v + + @classmethod + def get_compatible_losses(cls) -> List[str]: + """To get the list of losses implemented and compatible with this task.""" + compatible_losses = [ # TODO : connect to the Loss module + "CrossEntropyLoss", + "MultiMarginLoss", + ] + return compatible_losses + + @field_validator("loss") + def validator_loss(cls, v): + compatible_losses = cls.get_compatible_losses() + assert ( + v in compatible_losses + ), f"Loss '{v}' can't be used for this task. Please choose among: {compatible_losses}" + return v + + @field_validator("selection_threshold") + def validator_threshold(cls, v): + assert ( + 0 <= v <= 1 + ), f"selection_threshold must be between 0 and 1 but it has been set to {v}." + return v + + @field_validator("architecture") + def validator_architecture(cls, v): + return v # TODO : connect to network module to have list of available architectures diff --git a/clinicadl/train/tasks/reconstruction_cli.py b/clinicadl/train/tasks/reconstruction_cli.py index 95816a116..4bce83e04 100644 --- a/clinicadl/train/tasks/reconstruction_cli.py +++ b/clinicadl/train/tasks/reconstruction_cli.py @@ -1,7 +1,12 @@ +from pathlib import Path + import click +from click.core import ParameterSource +from clinicadl.train.train_utils import extract_config_from_toml_file from clinicadl.utils.cli_param import train_option +from .reconstruction_config import ReconstructionConfig from .task_utils import task_launcher @@ -26,7 +31,7 @@ @train_option.compensation @train_option.save_all_models # Model -@train_option.architecture +@train_option.reconstruction_architecture @train_option.multi_network @train_option.ssda_network # Data @@ -61,7 +66,7 @@ @train_option.transfer_selection_metric @train_option.nb_unfrozen_layer # Task-related -@train_option.selection_metrics +@train_option.reconstruction_selection_metrics @train_option.reconstruction_loss # information @train_option.emissions_calculator @@ -82,5 +87,14 @@ def cli(**kwargs): configuration file in TOML format. 
For more details, please visit the documentation: https://clinicadl.readthedocs.io/en/stable/Train/Introduction/#configuration-file """ - task_specific_options = ["selection_metrics", "loss"] - task_launcher("reconstruction", task_specific_options, **kwargs) + options = {} + if kwargs["config_file"]: + options = extract_config_from_toml_file( + Path(kwargs["config_file"]), + "reconstruction", + ) + for arg in kwargs: + if click.get_current_context().get_parameter_source(arg) == ParameterSource.COMMANDLINE: + options[arg] = kwargs[arg] + config = ReconstructionConfig(**options) + task_launcher(config) diff --git a/clinicadl/train/tasks/reconstruction_config.py b/clinicadl/train/tasks/reconstruction_config.py new file mode 100644 index 000000000..6442a59f5 --- /dev/null +++ b/clinicadl/train/tasks/reconstruction_config.py @@ -0,0 +1,69 @@ +from enum import Enum +from logging import getLogger +from typing import List, Tuple + +from pydantic import PrivateAttr, field_validator + +from .base_training_config import BaseTaskConfig + +logger = getLogger("clinicadl.reconstruction_config") + + +class Normalization(str, Enum): + """Available normalization layers in ClinicaDL.""" + + BATCH = "batch" + GROUP = "group" + INSTANCE = "instance" + + +class ReconstructionConfig(BaseTaskConfig): + """Config class to handle parameters of the reconstruction task.""" + + loss: str = "MSELoss" + selection_metrics: Tuple[str, ...] = ("loss",) + # model + architecture: str = "AE_Conv5_FC3" + latent_space_size: int = 128 + feature_size: int = 1024 + n_conv: int = 4 + io_layer_channels: int = 8 + recons_weight: int = 1 + kl_weight: int = 1 + normalization: Normalization = Normalization.BATCH + # private + _network_task: str = PrivateAttr(default="reconstruction") + + @field_validator("selection_metrics", mode="before") + def list_to_tuples(cls, v): + if isinstance(v, list): + return tuple(v) + return v + + @classmethod + def get_compatible_losses(cls) -> List[str]: + """To get the list of losses implemented and compatible with this task.""" + compatible_losses = [ # TODO : connect to the Loss module + "L1Loss", + "MSELoss", + "KLDivLoss", + "BCEWithLogitsLoss", + "HuberLoss", + "SmoothL1Loss", + "VAEGaussianLoss", + "VAEBernoulliLoss", + "VAEContinuousBernoulliLoss", + ] + return compatible_losses + + @field_validator("loss") + def validator_loss(cls, v): + compatible_losses = cls.get_compatible_losses() + assert ( + v in compatible_losses + ), f"Loss '{v}' can't be used for this task. 
Please choose among: {compatible_losses}" + return v + + @field_validator("architecture") + def validator_architecture(cls, v): + return v # TODO : connect to network module to have list of available architectures diff --git a/clinicadl/train/tasks/regression_cli.py b/clinicadl/train/tasks/regression_cli.py index 2533db406..c1ede7b1b 100644 --- a/clinicadl/train/tasks/regression_cli.py +++ b/clinicadl/train/tasks/regression_cli.py @@ -1,7 +1,12 @@ +from pathlib import Path + import click +from click.core import ParameterSource +from clinicadl.train.train_utils import extract_config_from_toml_file from clinicadl.utils.cli_param import train_option +from .regression_config import RegressionConfig from .task_utils import task_launcher @@ -26,7 +31,7 @@ @train_option.compensation @train_option.save_all_models # Model -@train_option.architecture +@train_option.regression_architecture @train_option.multi_network @train_option.ssda_network # Data @@ -61,8 +66,8 @@ @train_option.transfer_selection_metric @train_option.nb_unfrozen_layer # Task-related -@train_option.label -@train_option.selection_metrics +@train_option.regression_label +@train_option.regression_selection_metrics @train_option.regression_loss # information @train_option.emissions_calculator @@ -83,5 +88,14 @@ def cli(**kwargs): configuration file in TOML format. For more details, please visit the documentation: https://clinicadl.readthedocs.io/en/stable/Train/Introduction/#configuration-file """ - task_specific_options = ["label", "selection_metrics", "loss"] - task_launcher("regression", task_specific_options, **kwargs) + options = {} + if kwargs["config_file"]: + options = extract_config_from_toml_file( + Path(kwargs["config_file"]), + "regression", + ) + for arg in kwargs: + if click.get_current_context().get_parameter_source(arg) == ParameterSource.COMMANDLINE: + options[arg] = kwargs[arg] + config = RegressionConfig(**options) + task_launcher(config) diff --git a/clinicadl/train/tasks/regression_config.py b/clinicadl/train/tasks/regression_config.py new file mode 100644 index 000000000..3002372a9 --- /dev/null +++ b/clinicadl/train/tasks/regression_config.py @@ -0,0 +1,50 @@ +from logging import getLogger +from typing import List, Tuple + +from pydantic import PrivateAttr, field_validator + +from .base_training_config import BaseTaskConfig + +logger = getLogger("clinicadl.regression_config") + + +class RegressionConfig(BaseTaskConfig): + """Config class to handle parameters of the regression task.""" + + architecture: str = "Conv5_FC3" + loss: str = "MSELoss" + label: str = "age" + selection_metrics: Tuple[str, ...] = ("loss",) + # private + _network_task: str = PrivateAttr(default="regression") + + @field_validator("selection_metrics", mode="before") + def list_to_tuples(cls, v): + if isinstance(v, list): + return tuple(v) + return v + + @classmethod + def get_compatible_losses(cls) -> List[str]: + """To get the list of losses implemented and compatible with this task.""" + compatible_losses = [ # TODO : connect to the Loss module + "L1Loss", + "MSELoss", + "KLDivLoss", + "BCEWithLogitsLoss", + "HuberLoss", + "SmoothL1Loss", + ] + return compatible_losses + + @field_validator("loss") + def validator_loss(cls, v): + compatible_losses = cls.get_compatible_losses() + assert ( + v in compatible_losses + ), f"Loss '{v}' can't be used for this task. 
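With the three task configs in place, a bad option now fails fast at construction time instead of deep inside training. For instance, based on the validators defined above (paths are placeholders):

from pydantic import ValidationError

from clinicadl.train.tasks.regression_config import RegressionConfig

args = dict(
    caps_directory="caps",
    preprocessing_json="extract.json",
    tsv_directory="labels",
    output_maps_directory="maps",
)
config = RegressionConfig(**args)  # defaults: Conv5_FC3, MSELoss, label "age"
try:
    RegressionConfig(loss="CrossEntropyLoss", **args)  # a classification-only loss
except ValidationError as error:
    print(error)  # "Loss 'CrossEntropyLoss' can't be used for this task. ..."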
diff --git a/clinicadl/train/tasks/task_utils.py b/clinicadl/train/tasks/task_utils.py
index 0331468df..1c806d14d 100644
--- a/clinicadl/train/tasks/task_utils.py
+++ b/clinicadl/train/tasks/task_utils.py
@@ -1,107 +1,51 @@
 from logging import getLogger
-from typing import List
 
+from clinicadl.train.train import train
 from clinicadl.utils.caps_dataset.data import CapsDataset
 from clinicadl.utils.preprocessing import read_preprocessing
 
+from .base_training_config import BaseTaskConfig
 
-def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
-    """
-    Common training framework for all tasks
-
-    Args:
-        network_task: task learnt by the network.
-        task_options_list: list of options specific to the task.
-        kwargs: other arguments and options for network training.
-    """
-    from pathlib import Path
-
-    from clinicadl.train.train import train
-    from clinicadl.train.train_utils import build_train_dict
+logger = getLogger("clinicadl.task_manager")
 
-    logger = getLogger("clinicadl.task_manager")
 
-    config_file_name = None
-    if kwargs["config_file"]:
-        config_file_name = Path(kwargs["config_file"])
-    train_dict = build_train_dict(config_file_name, network_task)
+def task_launcher(config: BaseTaskConfig) -> None:
+    """
+    Common training framework for all tasks.
 
-    # Add arguments
-    train_dict["network_task"] = network_task
-    train_dict["caps_directory"] = Path(kwargs["caps_directory"])
-    train_dict["tsv_path"] = Path(kwargs["tsv_directory"])
+    Adds private attributes to the Config object and launches training.
 
-    # Change value in train dict depending on user provided options
-    standard_options_list = [
-        "accumulation_steps",
-        "adaptive_learning_rate",
-        "amp",
-        "architecture",
-        "baseline",
-        "batch_size",
-        "compensation",
-        "data_augmentation",
-        "deterministic",
-        "diagnoses",
-        "dropout",
-        "epochs",
-        "evaluation_steps",
-        "fully_sharded_data_parallel",
-        "gpu",
-        "learning_rate",
-        "multi_cohort",
-        "multi_network",
-        "ssda_network",
-        "n_proc",
-        "n_splits",
-        "nb_unfrozen_layer",
-        "normalize",
-        "optimizer",
-        "patience",
-        "profiler",
-        "tolerance",
-        "track_exp",
-        "transfer_path",
-        "transfer_selection_metric",
-        "valid_longitudinal",
-        "weight_decay",
-        "sampler",
-        "save_all_models",
-        "seed",
-        "split",
-        "caps_target",
-        "tsv_target_lab",
-        "tsv_target_unlab",
-        "preprocessing_dict_target",
-    ]
-    all_options_list = standard_options_list + task_options_list
+    Parameters
+    ----------
+    config : BaseTaskConfig
+        Configuration object with all the parameters.
 
-    for option in all_options_list:
-        if (kwargs[option] is not None and not isinstance(kwargs[option], tuple)) or (
-            isinstance(kwargs[option], tuple) and len(kwargs[option]) != 0
-        ):
-            train_dict[option] = kwargs[option]
-    if not train_dict["multi_cohort"]:
+    Raises
+    ------
+    ValueError
+        If preprocessing_json does not match any existing file.
+    ValueError
+        If preprocessing_dict_target does not match any existing file.
+ """ + if not config.multi_cohort: preprocessing_json = ( - train_dict["caps_directory"] - / "tensor_extraction" - / kwargs["preprocessing_json"] + config.caps_directory / "tensor_extraction" / config.preprocessing_json ) - if train_dict["ssda_network"]: + if config.ssda_network: preprocessing_json_target = ( - Path(kwargs["caps_target"]) + config.caps_target / "tensor_extraction" - / kwargs["preprocessing_dict_target"] + / config.preprocessing_dict_target ) else: caps_dict = CapsDataset.create_caps_dict( - train_dict["caps_directory"], train_dict["multi_cohort"] + config.caps_directory, config.multi_cohort ) json_found = False for caps_name, caps_path in caps_dict.items(): preprocessing_json = ( - caps_path / "tensor_extraction" / kwargs["preprocessing_json"] + caps_path / "tensor_extraction" / config.preprocessing_json ) if preprocessing_json.is_file(): logger.info( @@ -110,14 +54,14 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs): json_found = True if not json_found: raise ValueError( - f"Preprocessing JSON {kwargs['preprocessing_json']} was not found for any CAPS " + f"Preprocessing JSON {config.preprocessing_json} was not found for any CAPS " f"in {caps_dict}." ) # To CHECK AND CHANGE - if train_dict["ssda_network"]: - caps_target = Path(kwargs["caps_target"]) + if config.ssda_network: + caps_target = config.caps_target preprocessing_json_target = ( - caps_target / "tensor_extraction" / kwargs["preprocessing_dict_target"] + caps_target / "tensor_extraction" / config.preprocessing_dict_target ) if preprocessing_json_target.is_file(): @@ -127,24 +71,54 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs): json_found = True if not json_found: raise ValueError( - f"Preprocessing JSON {kwargs['preprocessing_json_target']} was not found for any CAPS " + f"Preprocessing JSON {preprocessing_json_target} was not found for any CAPS " f"in {caps_target}." 
) # Mode and preprocessing preprocessing_dict = read_preprocessing(preprocessing_json) - train_dict["preprocessing_dict"] = preprocessing_dict - train_dict["mode"] = preprocessing_dict["mode"] + config._preprocessing_dict = preprocessing_dict + config._mode = preprocessing_dict["mode"] - if train_dict["ssda_network"]: - preprocessing_dict_target = read_preprocessing(preprocessing_json_target) - train_dict["preprocessing_dict_target"] = preprocessing_dict_target + if config.ssda_network: + config._preprocessing_dict_target = read_preprocessing( + preprocessing_json_target + ) # Add default values if missing if ( preprocessing_dict["mode"] == "roi" and "roi_background_value" not in preprocessing_dict ): - preprocessing_dict["roi_background_value"] = 0 + config._preprocessing_dict["roi_background_value"] = 0 + + # temporary # TODO : change train function to give it a config object + maps_dir = config.output_maps_directory + train_dict = config.model_dump( + exclude=["output_maps_directory", "preprocessing_json", "tsv_directory"] + ) + train_dict["tsv_path"] = config.tsv_directory + train_dict[ + "preprocessing_dict" + ] = config._preprocessing_dict # private attributes are not dumped + train_dict["mode"] = config._mode + if config.ssda_network: + train_dict["preprocessing_dict_target"] = config._preprocessing_dict_target + train_dict["network_task"] = config._network_task + if train_dict["transfer_path"] is None: + train_dict["transfer_path"] = False + if train_dict["data_augmentation"] == (): + train_dict["data_augmentation"] = False + split_list = train_dict.pop("split") + train_dict["compensation"] = config.compensation.value + train_dict["size_reduction_factor"] = config.size_reduction_factor.value + if train_dict["track_exp"]: + train_dict["track_exp"] = config.track_exp.value + else: + train_dict["track_exp"] = "" + train_dict["sampler"] = config.sampler.value + if train_dict["network_task"] == "reconstruction": + train_dict["normalization"] = config.normalization.value + ############# - train(Path(kwargs["output_maps_directory"]), train_dict, train_dict.pop("split")) + train(maps_dir, train_dict, split_list) diff --git a/clinicadl/train/train.py b/clinicadl/train/train.py index c25ccc649..0296eb50c 100644 --- a/clinicadl/train/train.py +++ b/clinicadl/train/train.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List from clinicadl import MapsManager -from clinicadl.utils.maps_manager.trainer import Trainer +from clinicadl.utils.trainer import Trainer def train( diff --git a/clinicadl/train/train_utils.py b/clinicadl/train/train_utils.py index cc4b63fb9..cf3c7bd62 100644 --- a/clinicadl/train/train_utils.py +++ b/clinicadl/train/train_utils.py @@ -14,72 +14,70 @@ from clinicadl.utils.preprocessing import path_decoder -def build_train_dict(config_file: Path, task: str) -> Dict[str, Any]: +def extract_config_from_toml_file(config_file: Path, task: str) -> Dict[str, Any]: """ Read the configuration file given by the user. - If it is a TOML file, ensures that the format corresponds to the one in resources. - Args: - config_file: path to a configuration file (JSON of TOML). - task: task learnt by the network (example: classification, regression, reconstruction...). 
- Returns: - dictionary of values ready to use for the MapsManager - """ - if config_file is None: - # read default values - clinicadl_root_dir = Path(__file__).parents[1] - config_path = clinicadl_root_dir / "resources" / "config" / "train_config.toml" - config_dict = toml.load(config_path) - config_dict = remove_unused_tasks(config_dict, task) - config_dict = path_decoder(config_dict) - train_dict = dict() - # Fill train_dict from TOML files arguments - for config_section in config_dict: - for key in config_dict[config_section]: - train_dict[key] = config_dict[config_section][key] - - elif config_file.suffix == ".toml": - user_dict = toml.load(config_file) - if "Random_Search" in user_dict: - del user_dict["Random_Search"] - - # read default values - clinicadl_root_dir = Path(__file__).parents[1] - config_path = clinicadl_root_dir / "resources" / "config" / "train_config.toml" - config_dict = toml.load(config_path) - # Check that TOML file has the same format as the one in clinicadl/resources/config/train_config.toml - if user_dict is not None: - user_dict = path_decoder(user_dict) - for section_name in user_dict: - if section_name not in config_dict: - raise ClinicaDLConfigurationError( - f"{section_name} section is not valid in TOML configuration file. " - f"Please see the documentation to see the list of option in TOML configuration file." - ) - for key in user_dict[section_name]: - if key not in config_dict[section_name]: - raise ClinicaDLConfigurationError( - f"{key} option in {section_name} is not valid in TOML configuration file. " - f"Please see the documentation to see the list of option in TOML configuration file." - ) - config_dict[section_name][key] = user_dict[section_name][key] - - train_dict = dict() - - # task dependent - config_dict = remove_unused_tasks(config_dict, task) - - # Fill train_dict from TOML files arguments - for config_section in config_dict: - for key in config_dict[config_section]: - train_dict[key] = config_dict[config_section][key] - - elif config_file.suffix == ".json": - train_dict = read_json(config_file) - else: + Ensures that the format corresponds to the TOML file template. + + Parameters + ---------- + config_file : Path + Path to a configuration file (JSON of TOML). + task : str + Task performed by the network (e.g. classification). + + Returns + ------- + Dict[str, Any] + Config dictionary with the training parameters extracted from the config file. + + Raises + ------ + ClinicaDLConfigurationError + If configuration file is not a TOML file. + ClinicaDLConfigurationError + If a section in the TOML file is not valid. + ClinicaDLConfigurationError + If an option in the TOML file is not valid. + """ + if config_file.suffix != ".toml": raise ClinicaDLConfigurationError( - f"config_file {config_file} should be a TOML or a JSON file." + f"Config file {config_file} should be a TOML file." ) + + user_dict = toml.load(config_file) + if "Random_Search" in user_dict: + del user_dict["Random_Search"] + + # get the template + clinicadl_root_dir = Path(__file__).parents[1] + config_path = clinicadl_root_dir / "resources" / "config" / "train_config.toml" + config_dict = toml.load(config_path) + # Check that TOML file has the same format as the one in clinicadl/resources/config/train_config.toml + user_dict = path_decoder(user_dict) + for section_name in user_dict: + if section_name not in config_dict: + raise ClinicaDLConfigurationError( + f"{section_name} section is not valid in TOML configuration file. 
" + f"Please see the documentation to see the list of option in TOML configuration file." + ) + for key in user_dict[section_name]: + if key not in config_dict[section_name]: + raise ClinicaDLConfigurationError( + f"{key} option in {section_name} is not valid in TOML configuration file. " + f"Please see the documentation to see the list of option in TOML configuration file." + ) + + # task dependent + user_dict = remove_unused_tasks(user_dict, task) + + train_dict = dict() + # Fill train_dict from TOML files arguments + for config_section in user_dict: + for key in user_dict[config_section]: + train_dict[key] = user_dict[config_section][key] + return train_dict diff --git a/clinicadl/utils/cli_param/train_option.py b/clinicadl/utils/cli_param/train_option.py index c053277a5..f6a799eb6 100644 --- a/clinicadl/utils/cli_param/train_option.py +++ b/clinicadl/utils/cli_param/train_option.py @@ -1,7 +1,14 @@ +from typing import get_args + import click +from clinicadl.train.tasks.base_training_config import BaseTaskConfig +from clinicadl.train.tasks.classification_config import ClassificationConfig +from clinicadl.train.tasks.reconstruction_config import ReconstructionConfig +from clinicadl.train.tasks.regression_config import RegressionConfig from clinicadl.utils import cli_param +# Arguments caps_directory = cli_param.argument.caps_directory preprocessing_json = cli_param.argument.preprocessing_json tsv_directory = click.argument( @@ -9,382 +16,417 @@ type=click.Path(exists=True), ) output_maps = cli_param.argument.output_maps -# train option +# Config file config_file = click.option( "--config_file", "-c", type=click.Path(exists=True), help="Path to the TOML or JSON file containing the values of the options needed for training.", ) + +# Options # +base_config = BaseTaskConfig.model_fields +classification_config = ClassificationConfig.model_fields +regression_config = RegressionConfig.model_fields +reconstruction_config = ReconstructionConfig.model_fields + # Computational gpu = cli_param.option_group.computational_group.option( "--gpu/--no-gpu", - type=bool, - default=None, + default=base_config["gpu"].default, help="Use GPU by default. Please specify `--no-gpu` to force using CPU.", + show_default=True, ) n_proc = cli_param.option_group.computational_group.option( "-np", "--n_proc", - type=int, - # default=2, + type=base_config["n_proc"].annotation, + default=base_config["n_proc"].default, help="Number of cores used during the task.", + show_default=True, ) batch_size = cli_param.option_group.computational_group.option( "--batch_size", - type=int, - # default=2, + type=base_config["batch_size"].annotation, + default=base_config["batch_size"].default, help="Batch size for data loading.", + show_default=True, ) evaluation_steps = cli_param.option_group.computational_group.option( "--evaluation_steps", "-esteps", - type=int, - # default=0, + type=base_config["evaluation_steps"].annotation, + default=base_config["evaluation_steps"].default, help="Fix the number of iterations to perform before computing an evaluation. Default will only " "perform one evaluation at the end of each epoch.", + show_default=True, ) fully_sharded_data_parallel = cli_param.option_group.computational_group.option( "--fully_sharded_data_parallel", "-fsdp", - type=bool, is_flag=True, help="Enables Fully Sharded Data Parallel with Pytorch to save memory at the cost of communications. 
" "Currently this only enables ZeRO Stage 1 but will be entirely replaced by FSDP in a later patch, " "this flag is already set to FSDP to that the zero flag is never actually removed.", - default=False, ) - amp = cli_param.option_group.computational_group.option( "--amp/--no-amp", - type=bool, + default=base_config["amp"].default, help="Enables automatic mixed precision during training and inference.", + show_default=True, ) # Reproducibility seed = cli_param.option_group.reproducibility_group.option( "--seed", + type=base_config["seed"].annotation, + default=base_config["seed"].default, help="Value to set the seed for all random operations." "Default will sample a random value for the seed.", - # default=None, - type=int, + show_default=True, ) deterministic = cli_param.option_group.reproducibility_group.option( "--deterministic/--nondeterministic", - type=bool, - default=None, + default=base_config["deterministic"].default, help="Forces Pytorch to be deterministic even when using a GPU. " "Will raise a RuntimeError if a non-deterministic function is encountered.", + show_default=True, ) compensation = cli_param.option_group.reproducibility_group.option( "--compensation", + type=click.Choice(list(base_config["compensation"].annotation)), + default=base_config["compensation"].default.value, help="Allow the user to choose how CUDA will compensate the deterministic behaviour.", - # default="memory", - type=click.Choice(["memory", "time"]), + show_default=True, ) save_all_models = cli_param.option_group.reproducibility_group.option( "--save_all_models/--save_only_best_model", - type=bool, + type=base_config["save_all_models"].annotation, + default=base_config["save_all_models"].default, help="If provided, enables the saving of models weights for each epochs.", + show_default=True, ) - # Model -architecture = cli_param.option_group.model_group.option( - "-a", - "--architecture", - type=str, - # default=0, - help="Architecture of the chosen model to train. A set of model is available in ClinicaDL, default architecture depends on the NETWORK_TASK (see the documentation for more information).", -) multi_network = cli_param.option_group.model_group.option( "--multi_network/--single_network", - type=bool, - default=None, + default=base_config["multi_network"].default, help="If provided uses a multi-network framework.", + show_default=True, ) ssda_network = cli_param.option_group.model_group.option( "--ssda_network/--single_network", - type=bool, - default=None, + default=base_config["ssda_network"].default, help="If provided uses a ssda-network framework.", + show_default=True, ) # Task -label = cli_param.option_group.task_group.option( +classification_architecture = cli_param.option_group.model_group.option( + "-a", + "--architecture", + type=classification_config["architecture"].annotation, + default=classification_config["architecture"].default, + help="Architecture of the chosen model to train. A set of model is available in ClinicaDL, default architecture depends on the NETWORK_TASK (see the documentation for more information).", +) +regression_architecture = cli_param.option_group.model_group.option( + "-a", + "--architecture", + type=regression_config["architecture"].annotation, + default=regression_config["architecture"].default, + help="Architecture of the chosen model to train. 
A set of model is available in ClinicaDL, default architecture depends on the NETWORK_TASK (see the documentation for more information).", +) +reconstruction_architecture = cli_param.option_group.model_group.option( + "-a", + "--architecture", + type=reconstruction_config["architecture"].annotation, + default=reconstruction_config["architecture"].default, + help="Architecture of the chosen model to train. A set of model is available in ClinicaDL, default architecture depends on the NETWORK_TASK (see the documentation for more information).", +) +classification_label = cli_param.option_group.task_group.option( + "--label", + type=classification_config["label"].annotation, + default=classification_config["label"].default, + help="Target label used for training.", + show_default=True, +) +regression_label = cli_param.option_group.task_group.option( "--label", - type=str, + type=regression_config["label"].annotation, + default=regression_config["label"].default, help="Target label used for training.", + show_default=True, ) -selection_metrics = cli_param.option_group.task_group.option( +classification_selection_metrics = cli_param.option_group.task_group.option( "--selection_metrics", "-sm", multiple=True, + type=get_args(classification_config["selection_metrics"].annotation)[0], + default=classification_config["selection_metrics"].default, help="""Allow to save a list of models based on their selection metric. Default will only save the best model selected on loss.""", + show_default=True, +) +regression_selection_metrics = cli_param.option_group.task_group.option( + "--selection_metrics", + "-sm", + multiple=True, + type=get_args(regression_config["selection_metrics"].annotation)[0], + default=regression_config["selection_metrics"].default, + help="""Allow to save a list of models based on their selection metric. Default will + only save the best model selected on loss.""", + show_default=True, +) +reconstruction_selection_metrics = cli_param.option_group.task_group.option( + "--selection_metrics", + "-sm", + multiple=True, + type=get_args(reconstruction_config["selection_metrics"].annotation)[0], + default=reconstruction_config["selection_metrics"].default, + help="""Allow to save a list of models based on their selection metric. Default will + only save the best model selected on loss.""", + show_default=True, ) selection_threshold = cli_param.option_group.task_group.option( "--selection_threshold", - type=float, - # default=0, + type=classification_config["selection_threshold"].annotation, + default=classification_config["selection_threshold"].default, help="""Selection threshold for soft-voting. 
Will only be used if num_networks > 1.""", + show_default=True, ) classification_loss = cli_param.option_group.task_group.option( "--loss", "-l", - type=click.Choice(["CrossEntropyLoss", "MultiMarginLoss"]), + type=click.Choice(ClassificationConfig.get_compatible_losses()), + default=classification_config["loss"].default, help="Loss used by the network to optimize its training task.", + show_default=True, ) regression_loss = cli_param.option_group.task_group.option( "--loss", "-l", - type=click.Choice( - [ - "L1Loss", - "MSELoss", - "KLDivLoss", - "BCEWithLogitsLoss", - "HuberLoss", - "SmoothL1Loss", - ] - ), + type=click.Choice(RegressionConfig.get_compatible_losses()), + default=regression_config["loss"].default, help="Loss used by the network to optimize its training task.", + show_default=True, ) reconstruction_loss = cli_param.option_group.task_group.option( "--loss", "-l", - type=click.Choice( - [ - "L1Loss", - "MSELoss", - "KLDivLoss", - "BCEWithLogitsLoss", - "HuberLoss", - "SmoothL1Loss", - ] - ), + type=click.Choice(ReconstructionConfig.get_compatible_losses()), + default=reconstruction_config["loss"].default, help="Loss used by the network to optimize its training task.", + show_default=True, ) # Data multi_cohort = cli_param.option_group.data_group.option( "--multi_cohort/--single_cohort", - type=bool, - default=None, + default=base_config["multi_cohort"].default, help="Performs multi-cohort training. In this case, caps_dir and tsv_path must be paths to TSV files.", + show_default=True, ) diagnoses = cli_param.option_group.data_group.option( "--diagnoses", "-d", - type=str, - # default=(), + type=get_args(base_config["diagnoses"].annotation)[0], + default=base_config["diagnoses"].default, multiple=True, help="List of diagnoses used for training.", + show_default=True, ) baseline = cli_param.option_group.data_group.option( "--baseline/--longitudinal", - type=bool, - default=None, + default=base_config["baseline"].default, help="If provided, only the baseline sessions are used for training.", + show_default=True, ) valid_longitudinal = cli_param.option_group.data_group.option( "--valid_longitudinal/--valid_baseline", - type=bool, - default=None, + default=base_config["valid_longitudinal"].default, help="If provided, not only the baseline sessions are used for validation (careful with this bad habit).", + show_default=True, ) normalize = cli_param.option_group.data_group.option( "--normalize/--unnormalize", - type=bool, - default=None, + default=base_config["normalize"].default, help="Disable default MinMaxNormalization.", + show_default=True, ) data_augmentation = cli_param.option_group.data_group.option( "--data_augmentation", "-da", - type=click.Choice( - [ - "None", - "Noise", - "Erasing", - "CropPad", - "Smoothing", - "Motion", - "Ghosting", - "Spike", - "BiasField", - "RandomBlur", - "RandomSwap", - ] - ), - # default=(), + type=click.Choice(BaseTaskConfig.get_available_transforms()), + default=list(base_config["data_augmentation"].default), multiple=True, help="Randomly applies transforms on the training set.", + show_default=True, ) sampler = cli_param.option_group.data_group.option( "--sampler", "-s", - type=click.Choice(["random", "weighted"]), - # default="random", + type=click.Choice(list(base_config["sampler"].annotation)), + default=base_config["sampler"].default.value, help="Sampler used to load the training data set.", + show_default=True, ) caps_target = cli_param.option_group.data_group.option( "--caps_target", "-d", - type=str, - default=None, + 
type=base_config["caps_target"].annotation, + default=base_config["caps_target"].default, help="CAPS of target data.", + show_default=True, ) tsv_target_lab = cli_param.option_group.data_group.option( "--tsv_target_lab", "-d", - type=str, - default=None, + type=base_config["tsv_target_lab"].annotation, + default=base_config["tsv_target_lab"].default, help="TSV of labeled target data.", + show_default=True, ) tsv_target_unlab = cli_param.option_group.data_group.option( "--tsv_target_unlab", "-d", - type=str, - default=None, + type=base_config["tsv_target_unlab"].annotation, + default=base_config["tsv_target_unlab"].default, help="TSV of unllabeled target data.", + show_default=True, ) -preprocessing_dict_target = cli_param.option_group.data_group.option( +preprocessing_dict_target = cli_param.option_group.data_group.option( # TODO : change that name, it is not a dict. "--preprocessing_dict_target", "-d", - type=str, - default=None, + type=base_config["preprocessing_dict_target"].annotation, + default=base_config["preprocessing_dict_target"].default, help="Path to json target.", + show_default=True, ) # Cross validation n_splits = cli_param.option_group.cross_validation.option( "--n_splits", - type=int, - # default=0, + type=base_config["n_splits"].annotation, + default=base_config["n_splits"].default, help="If a value is given for k will load data of a k-fold CV. " "Default value (0) will load a single split.", + show_default=True, ) split = cli_param.option_group.cross_validation.option( "--split", "-s", - type=int, - # default=(), + type=get_args(base_config["split"].annotation)[0], + default=base_config["split"].default, multiple=True, help="Train the list of given splits. By default, all the splits are trained.", + show_default=True, ) # Optimization optimizer = cli_param.option_group.optimization_group.option( "--optimizer", - type=click.Choice( - [ - "Adadelta", - "Adagrad", - "Adam", - "AdamW", - "Adamax", - "ASGD", - "NAdam", - "RAdam", - "RMSprop", - "SGD", - ] - ), + type=click.Choice(BaseTaskConfig.get_available_optimizers()), + default=base_config["optimizer"].default, help="Optimizer used to train the network.", + show_default=True, ) epochs = cli_param.option_group.optimization_group.option( "--epochs", - type=int, - # default=20, + type=base_config["epochs"].annotation, + default=base_config["epochs"].default, help="Maximum number of epochs.", + show_default=True, ) learning_rate = cli_param.option_group.optimization_group.option( "--learning_rate", "-lr", - type=float, - # default=1e-4, + type=base_config["learning_rate"].annotation, + default=base_config["learning_rate"].default, help="Learning rate of the optimization.", + show_default=True, ) adaptive_learning_rate = cli_param.option_group.optimization_group.option( "--adaptive_learning_rate", "-alr", - type=bool, - help="Whether to diminish the learning rate", is_flag=True, - default=False, + help="Whether to diminish the learning rate", ) weight_decay = cli_param.option_group.optimization_group.option( "--weight_decay", "-wd", - type=float, - # default=1e-4, + type=base_config["weight_decay"].annotation, + default=base_config["weight_decay"].default, help="Weight decay value used in optimization.", + show_default=True, ) dropout = cli_param.option_group.optimization_group.option( "--dropout", - type=float, - # default=0, + type=base_config["dropout"].annotation, + default=base_config["dropout"].default, help="Rate value applied to dropout layers in a CNN architecture.", + show_default=True, ) patience = 
cli_param.option_group.optimization_group.option( "--patience", - type=int, - # default=0, + type=base_config["patience"].annotation, + default=base_config["patience"].default, help="Number of epochs for early stopping patience.", + show_default=True, ) tolerance = cli_param.option_group.optimization_group.option( "--tolerance", - type=float, - # default=0.0, + type=base_config["tolerance"].annotation, + default=base_config["tolerance"].default, help="Value for early stopping tolerance.", + show_default=True, ) accumulation_steps = cli_param.option_group.optimization_group.option( "--accumulation_steps", "-asteps", - type=int, - # default=1, + type=base_config["accumulation_steps"].annotation, + default=base_config["accumulation_steps"].default, help="Accumulates gradients during the given number of iterations before performing the weight update " "in order to virtually increase the size of the batch.", + show_default=True, ) profiler = cli_param.option_group.optimization_group.option( "--profiler/--no-profiler", - type=bool, + default=base_config["profiler"].default, help="Use `--profiler` to enable Pytorch profiler for the first 30 steps after a short warmup. " "It will make an execution trace and some statistics about the CPU and GPU usage.", + show_default=True, ) track_exp = cli_param.option_group.optimization_group.option( "--track_exp", "-te", - type=click.Choice( - [ - "wandb", - "mlflow", - "", - ] - ), + type=click.Choice(list(get_args(base_config["track_exp"].annotation)[0])), + default=base_config["track_exp"].default, help="Use `--track_exp` to enable wandb/mlflow to track the metric (loss, accuracy, etc...) during the training.", + show_default=True, ) -# transfer learning +# Transfer Learning transfer_path = cli_param.option_group.transfer_learning_group.option( "-tp", "--transfer_path", - type=click.Path(), - # default=0.0, + type=get_args(base_config["transfer_path"].annotation)[0], + default=base_config["transfer_path"].default, help="Path of to a MAPS used for transfer learning.", + show_default=True, ) transfer_selection_metric = cli_param.option_group.transfer_learning_group.option( "-tsm", "--transfer_selection_metric", - type=str, - # default="loss", + type=base_config["transfer_selection_metric"].annotation, + default=base_config["transfer_selection_metric"].default, help="Metric used to select the model for transfer learning in the MAPS defined by transfer_path.", + show_default=True, ) nb_unfrozen_layer = cli_param.option_group.transfer_learning_group.option( "-nul", "--nb_unfrozen_layer", - type=int, - default=0, + type=base_config["nb_unfrozen_layer"].annotation, + default=base_config["nb_unfrozen_layer"].default, help="Number of layer that will be retrain during training. 
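Every option in this file now pulls its type and default from the pydantic classes through model_fields, so the CLI and the config objects cannot drift apart. The mechanism, reduced to one option (hypothetical names, outside ClinicaDL):

import click
from pydantic import BaseModel


class MiniConfig(BaseModel):
    batch_size: int = 8


fields = MiniConfig.model_fields  # mapping of field name -> FieldInfo


@click.command()
@click.option(
    "--batch_size",
    type=fields["batch_size"].annotation,  # int
    default=fields["batch_size"].default,  # 8
    show_default=True,
)
def train(batch_size):
    click.echo(batch_size)

Changing the default on MiniConfig changes the CLI default and its --help output in the same move, which is the point of the base_config/classification_config/... lookups above.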
diff --git a/clinicadl/utils/maps_manager/maps_manager.py b/clinicadl/utils/maps_manager/maps_manager.py
index 65351ee49..ee7f22065 100644
--- a/clinicadl/utils/maps_manager/maps_manager.py
+++ b/clinicadl/utils/maps_manager/maps_manager.py
@@ -529,6 +529,12 @@ def write_parameters(json_path: Path, parameters, verbose=True):
         if verbose:
             logger.info(f"Path of json file: {json_path}")
 
+        # temporary: to match CLI data. TODO : change CLI data
+        for parameter in parameters:
+            if parameters[parameter] == Path("."):
+                parameters[parameter] = ""
+        ###############################
+
         with json_path.open(mode="w") as json_file:
             json.dump(
                 parameters, json_file, skipkeys=True, indent=4, default=path_encoder
diff --git a/clinicadl/utils/maps_manager/trainer/__init__.py b/clinicadl/utils/trainer/__init__.py
similarity index 100%
rename from clinicadl/utils/maps_manager/trainer/__init__.py
rename to clinicadl/utils/trainer/__init__.py
diff --git a/clinicadl/utils/maps_manager/trainer/trainer.py b/clinicadl/utils/trainer/trainer.py
similarity index 99%
rename from clinicadl/utils/maps_manager/trainer/trainer.py
rename to clinicadl/utils/trainer/trainer.py
index 1949e94df..816453ec3 100644
--- a/clinicadl/utils/maps_manager/trainer/trainer.py
+++ b/clinicadl/utils/trainer/trainer.py
@@ -26,7 +26,7 @@
 from clinicadl.utils.callbacks.callbacks import Callback
 from clinicadl.utils.maps_manager import MapsManager
 
-logger = getLogger("clinicadl.maps_manager")
+logger = getLogger("clinicadl.trainer")
 
 
 class Trainer:
diff --git a/clinicadl/utils/trainer/trainer_utils.py b/clinicadl/utils/trainer/trainer_utils.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/clinicadl/utils/trainer/training_config.py b/clinicadl/utils/trainer/training_config.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unittests/train/tasks/test_base_training_config.py b/tests/unittests/train/tasks/test_base_training_config.py
new file mode 100644
index 000000000..bdb923625
--- /dev/null
+++ b/tests/unittests/train/tasks/test_base_training_config.py
@@ -0,0 +1,80 @@
+import pytest
+from pydantic import ValidationError
+
+
+@pytest.mark.parametrize(
+    "parameters",
+    [
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "dropout": 1.1,
+        },
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "optimizer": "abc",
+        },
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "data_augmentation": ("abc",),
+        },
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "diagnoses": "AD",
+        },
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "size_reduction_factor": 1,
+        },
+    ],
+)
+def test_fails_validations(parameters):
+    from clinicadl.train.tasks.base_training_config import BaseTaskConfig
+
+    with pytest.raises(ValidationError):
+        BaseTaskConfig(**parameters)
+
+
+@pytest.mark.parametrize(
+    "parameters",
+    [
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "diagnoses": ("AD", "CN"),
+            "optimizer": "Adam",
+            "dropout": 0.5,
+            "data_augmentation": ("Noise",),
+            "size_reduction_factor": 2,
+        },
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "diagnoses": ["AD", "CN"],
+            "data_augmentation": False,
+            "transfer_path": False,
+        },
+    ],
+)
+def test_passes_validations(parameters):
+    from clinicadl.train.tasks.base_training_config import BaseTaskConfig
+
+    BaseTaskConfig(**parameters)
diff --git a/tests/unittests/train/tasks/test_classification_config.py b/tests/unittests/train/tasks/test_classification_config.py
new file mode 100644
index 000000000..8bc28f1a2
--- /dev/null
+++ b/tests/unittests/train/tasks/test_classification_config.py
@@ -0,0 +1,62 @@
+import pytest
+from pydantic import ValidationError
+
+
+@pytest.mark.parametrize(
+    "parameters",
+    [
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "selection_threshold": 1.1,
+        },
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "loss": "abc",
+        },
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "selection_metrics": "loss",
+        },
+    ],
+)
+def test_fails_validations(parameters):
+    from clinicadl.train.tasks.classification_config import ClassificationConfig
+
+    with pytest.raises(ValidationError):
+        ClassificationConfig(**parameters)
+
+
+@pytest.mark.parametrize(
+    "parameters",
+    [
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "loss": "CrossEntropyLoss",
+            "selection_threshold": 0.5,
+            "selection_metrics": ("loss",),
+        },
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "selection_metrics": ["loss"],
+        },
+    ],
+)
+def test_passes_validations(parameters):
+    from clinicadl.train.tasks.classification_config import ClassificationConfig
+
+    ClassificationConfig(**parameters)
diff --git a/tests/unittests/train/tasks/test_reconstruction_config.py b/tests/unittests/train/tasks/test_reconstruction_config.py
new file mode 100644
index 000000000..57b063f32
--- /dev/null
+++ b/tests/unittests/train/tasks/test_reconstruction_config.py
@@ -0,0 +1,62 @@
+import pytest
+from pydantic import ValidationError
+
+
+@pytest.mark.parametrize(
+    "parameters",
+    [
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "loss": "abc",
+        },
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "selection_metrics": "loss",
+        },
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "normalization": "abc",
+        },
+    ],
+)
+def test_fails_validations(parameters):
+    from clinicadl.train.tasks.reconstruction_config import ReconstructionConfig
+
+    with pytest.raises(ValidationError):
+        ReconstructionConfig(**parameters)
+
+
+@pytest.mark.parametrize(
+    "parameters",
+    [
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "loss": "L1Loss",
+            "selection_metrics": ("loss",),
+            "normalization": "batch",
+        },
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "selection_metrics": ["loss"],
+        },
+    ],
+)
+def test_passes_validations(parameters):
+    from clinicadl.train.tasks.reconstruction_config import ReconstructionConfig
+
+    ReconstructionConfig(**parameters)
diff --git a/tests/unittests/train/tasks/test_regression_config.py b/tests/unittests/train/tasks/test_regression_config.py
new file mode 100644
index 000000000..0b6e971a3
--- /dev/null
+++ b/tests/unittests/train/tasks/test_regression_config.py
@@ -0,0 +1,54 @@
+import pytest
+from pydantic import ValidationError
+
+
+@pytest.mark.parametrize(
+    "parameters",
+    [
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "loss": "abc",
+        },
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "selection_metrics": "loss",
+        },
+    ],
+)
+def test_fails_validations(parameters):
+    from clinicadl.train.tasks.regression_config import RegressionConfig
+
+    with pytest.raises(ValidationError):
+        RegressionConfig(**parameters)
+
+
+@pytest.mark.parametrize(
+    "parameters",
+    [
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "loss": "MSELoss",
+            "selection_metrics": ("loss",),
+        },
+        {
+            "caps_directory": "",
+            "preprocessing_json": "",
+            "tsv_directory": "",
+            "output_maps_directory": "",
+            "selection_metrics": ["loss"],
+        },
+    ],
+)
+def test_passes_validations(parameters):
+    from clinicadl.train.tasks.regression_config import RegressionConfig
+
+    RegressionConfig(**parameters)
diff --git a/tests/unittests/train/test_train_utils.py b/tests/unittests/train/test_train_utils.py
new file mode 100644
index 000000000..fefcc52d3
--- /dev/null
+++ b/tests/unittests/train/test_train_utils.py
@@ -0,0 +1,206 @@
+from pathlib import Path
+
+import pytest
+
+expected_classification = {
+    "architecture": "default",
+    "multi_network": False,
+    "ssda_network": False,
+    "dropout": 0.0,
+    "latent_space_size": 128,
+    "feature_size": 1024,
+    "n_conv": 4,
+    "io_layer_channels": 8,
+    "recons_weight": 1,
+    "kl_weight": 1,
+    "normalization": "batch",
+    "selection_metrics": ["loss"],
+    "label": "diagnosis",
+    "label_code": {},
+    "selection_threshold": 0.0,
+    "loss": "CrossEntropyLoss",
+    "gpu": True,
+    "n_proc": 2,
+    "batch_size": 8,
+    "evaluation_steps": 0,
+    "fully_sharded_data_parallel": False,
+    "amp": False,
+    "seed": 0,
+    "deterministic": False,
+    "compensation": "memory",
+    "track_exp": "",
+    "transfer_path": False,
+    "transfer_selection_metric": "loss",
+    "nb_unfrozen_layer": 0,
+    "use_extracted_features": False,
+    "multi_cohort": False,
+    "diagnoses": ["AD", "CN"],
+    "baseline": False,
+    "valid_longitudinal": False,
+    "normalize": True,
+    "data_augmentation": [],
+    "sampler": "random",
+    "size_reduction": False,
+    "size_reduction_factor": 2,
+    "caps_target": "",
+    "tsv_target_lab": "",
+    "tsv_target_unlab": "",
+    "preprocessing_dict_target": "",
+    "n_splits": 0,
+    "split": [],
+    "optimizer": "Adam",
+    "epochs": 20,
+    "learning_rate": 1e-4,
+    "adaptive_learning_rate": False,
+    "weight_decay": 1e-4,
+    "patience": 0,
+    "tolerance": 0.0,
+    "accumulation_steps": 1,
+    "profiler": False,
+    "save_all_models": False,
+    "emissions_calculator": False,
+}
+expected_regression = {
+    "architecture": "default",
+    "multi_network": False,
+    "ssda_network": False,
+    "dropout": 0.0,
+    "latent_space_size": 128,
+    "feature_size": 1024,
+    "n_conv": 4,
+    "io_layer_channels": 8,
+    "recons_weight": 1,
+    "kl_weight": 1,
+    "normalization": "batch",
+    "selection_metrics": ["loss"],
+    "label": "age",
+    "loss": "MSELoss",
+    "gpu": True,
+    "n_proc": 2,
+    "batch_size": 8,
+    "evaluation_steps": 0,
+    "fully_sharded_data_parallel": False,
+    "amp": False,
+    "seed": 0,
+    "deterministic": False,
+    "compensation": "memory",
+    "track_exp": "",
+    "transfer_path": False,
+    "transfer_selection_metric": "loss",
+    "nb_unfrozen_layer": 0,
+    "use_extracted_features": False,
+    "multi_cohort": False,
+    "diagnoses": ["AD", "CN"],
+    "baseline": False,
+    "valid_longitudinal": False,
+    "normalize": True,
+    "data_augmentation": [],
+    "sampler": "random",
+    "size_reduction": False,
+    "size_reduction_factor": 2,
+    "caps_target": "",
+    "tsv_target_lab": "",
+    "tsv_target_unlab": "",
+    "preprocessing_dict_target": "",
+    "n_splits": 0,
+    "split": [],
+    "optimizer": "Adam",
+    "epochs": 20,
+    "learning_rate": 1e-4,
+    "adaptive_learning_rate": False,
+    "weight_decay": 1e-4,
+    "patience": 0,
+    "tolerance": 0.0,
+    "accumulation_steps": 1,
+    "profiler": False,
+    "save_all_models": False,
+    "emissions_calculator": False,
+}
+expected_reconstruction = {
+    "architecture": "default",
+    "multi_network": False,
+    "ssda_network": False,
+    "dropout": 0.0,
+    "latent_space_size": 128,
+    "feature_size": 1024,
+    "n_conv": 4,
+    "io_layer_channels": 8,
+    "recons_weight": 1,
+    "kl_weight": 1,
+    "normalization": "batch",
+    "selection_metrics": ["loss"],
+    "loss": "MSELoss",
+    "gpu": True,
+    "n_proc": 2,
+    "batch_size": 8,
+    "evaluation_steps": 0,
+    "fully_sharded_data_parallel": False,
+    "amp": False,
+    "seed": 0,
+    "deterministic": False,
+    "compensation": "memory",
+    "track_exp": "",
+    "transfer_path": False,
+    "transfer_selection_metric": "loss",
+    "nb_unfrozen_layer": 0,
+    "use_extracted_features": False,
+    "multi_cohort": False,
+    "diagnoses": ["AD", "CN"],
+    "baseline": False,
+    "valid_longitudinal": False,
+    "normalize": True,
+    "data_augmentation": [],
+    "sampler": "random",
+    "size_reduction": False,
+    "size_reduction_factor": 2,
+    "caps_target": "",
+    "tsv_target_lab": "",
+    "tsv_target_unlab": "",
+    "preprocessing_dict_target": "",
+    "n_splits": 0,
+    "split": [],
+    "optimizer": "Adam",
+    "epochs": 20,
+    "learning_rate": 1e-4,
+    "adaptive_learning_rate": False,
+    "weight_decay": 1e-4,
+    "patience": 0,
+    "tolerance": 0.0,
+    "accumulation_steps": 1,
+    "profiler": False,
+    "save_all_models": False,
+    "emissions_calculator": False,
+}
+clinicadl_root_dir = Path(__file__).parents[3] / "clinicadl"
+config_toml = clinicadl_root_dir / "resources" / "config" / "train_config.toml"
+
+
+@pytest.mark.parametrize(
+    "config_file,task,expected_output",
+    [
+        (config_toml, "classification", expected_classification),
+        (config_toml, "regression", expected_regression),
+        (config_toml, "reconstruction", expected_reconstruction),
+    ],
+)
+def test_extract_config_from_file(config_file, task, expected_output):
+    from clinicadl.train.train_utils import extract_config_from_toml_file
+
+    assert extract_config_from_toml_file(config_file, task) == expected_output
+
+
+@pytest.mark.parametrize(
+    "config_file,task,expected_output",
+    [
+        (config_toml, "classification", expected_classification),
+    ],
+)
+def test_extract_config_from_file_exceptions(config_file, task, expected_output):
+    from clinicadl.train.train_utils import extract_config_from_toml_file
+    from clinicadl.utils.exceptions import ClinicaDLConfigurationError
+
+    with pytest.raises(ClinicaDLConfigurationError):
+        extract_config_from_toml_file(
+            Path(str(config_file).replace(".toml", ".json")),
+            task,
+        )