diff --git a/clinicadl/config/config/__init__.py b/clinicadl/config/config/__init__.py index baf410a23..4f59e5e04 100644 --- a/clinicadl/config/config/__init__.py +++ b/clinicadl/config/config/__init__.py @@ -5,7 +5,6 @@ from .data import DataConfig from .dataloader import DataLoaderConfig from .early_stopping import EarlyStoppingConfig -from .interpret import InterpretConfig from .lr_scheduler import LRschedulerConfig from .maps_manager import MapsManagerConfig from .modality import ( @@ -17,7 +16,6 @@ from .model import ModelConfig from .optimization import OptimizationConfig from .optimizer import OptimizerConfig -from .predict import PredictConfig from .preprocessing import ( PreprocessingConfig, PreprocessingImageConfig, diff --git a/clinicadl/config/config/computational.py b/clinicadl/config/config/computational.py index 43f7c66e9..5112b48ef 100644 --- a/clinicadl/config/config/computational.py +++ b/clinicadl/config/config/computational.py @@ -1,6 +1,10 @@ from logging import getLogger -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator +from typing_extensions import Self + +from clinicadl.utils.cmdline_utils import check_gpu +from clinicadl.utils.exceptions import ClinicaDLArgumentError logger = getLogger("clinicadl.computational_config") @@ -13,3 +17,13 @@ class ComputationalConfig(BaseModel): gpu: bool = True # pydantic config model_config = ConfigDict(validate_assignment=True) + + @model_validator(mode="after") + def validator_gpu(self) -> Self: + if self.gpu: + check_gpu() + elif self.amp: + raise ClinicaDLArgumentError( + "AMP is designed to work with modern GPUs. Please add the --gpu flag." 
+ ) + return self diff --git a/clinicadl/config/config/cross_validation.py b/clinicadl/config/config/cross_validation.py index 1aa222e98..fd2b4cb40 100644 --- a/clinicadl/config/config/cross_validation.py +++ b/clinicadl/config/config/cross_validation.py @@ -5,6 +5,8 @@ from pydantic import BaseModel, ConfigDict, field_validator from pydantic.types import NonNegativeInt +from clinicadl.utils.maps_manager.maps_manager import MapsManager + logger = getLogger("clinicadl.cross_validation_config") @@ -19,7 +21,7 @@ class CrossValidationConfig( n_splits: NonNegativeInt = 0 split: Optional[Tuple[NonNegativeInt, ...]] = None - tsv_directory: Path + tsv_directory: Optional[Path] = None # not needed in predict ? # pydantic config model_config = ConfigDict(validate_assignment=True) @@ -28,3 +30,9 @@ def validator_split(cls, v): if isinstance(v, list): return tuple(v) return v # TODO : check that split exists (and check coherence with n_splits) + + def adapt_cross_val_with_maps_manager_info(self, maps_manager: MapsManager): + # TEMPORARY + if not self.split: + self.split = maps_manager._find_splits() + logger.debug(f"List of splits {self.split}") diff --git a/clinicadl/config/config/data.py b/clinicadl/config/config/data.py index f6216228a..c6004a7e5 100644 --- a/clinicadl/config/config/data.py +++ b/clinicadl/config/config/data.py @@ -6,6 +6,7 @@ from clinicadl.utils.caps_dataset.data import load_data_test from clinicadl.utils.enum import Mode +from clinicadl.utils.maps_manager.maps_manager import MapsManager from clinicadl.utils.preprocessing import read_preprocessing logger = getLogger("clinicadl.data_config") @@ -24,12 +25,17 @@ class DataConfig(BaseModel): # TODO : put in data module label: Optional[str] = None label_code: Dict[str, int] = {} multi_cohort: bool = False - preprocessing_json: Path + preprocessing_json: Optional[Path] = None data_tsv: Optional[Path] = None n_subjects: int = 300 # pydantic config model_config = ConfigDict(validate_assignment=True) + def 
adapt_data_with_maps_manager_info(self, maps_manager: MapsManager): + # TEMPORARY + if self.diagnoses is None or len(self.diagnoses) == 0: + self.diagnoses = maps_manager.diagnoses + def create_groupe_df(self): group_df = None if self.data_tsv is not None and self.data_tsv.is_file(): diff --git a/clinicadl/config/config/dataloader.py b/clinicadl/config/config/dataloader.py index cc01ba9a9..e2c02afa1 100644 --- a/clinicadl/config/config/dataloader.py +++ b/clinicadl/config/config/dataloader.py @@ -4,6 +4,7 @@ from pydantic.types import PositiveInt from clinicadl.utils.enum import Sampler +from clinicadl.utils.maps_manager.maps_manager import MapsManager logger = getLogger("clinicadl.dataloader_config") @@ -16,3 +17,11 @@ class DataLoaderConfig(BaseModel): # TODO : put in data/splitter module sampler: Sampler = Sampler.RANDOM # pydantic config model_config = ConfigDict(validate_assignment=True) + + def adapt_dataloader_with_maps_manager_info(self, maps_manager: MapsManager): + # TEMPORARY + if not self.batch_size: + self.batch_size = maps_manager.batch_size + + if not self.n_proc: + self.n_proc = maps_manager.n_proc diff --git a/clinicadl/config/config/maps_manager.py b/clinicadl/config/config/maps_manager.py index 105bdad90..7641bbbef 100644 --- a/clinicadl/config/config/maps_manager.py +++ b/clinicadl/config/config/maps_manager.py @@ -26,17 +26,3 @@ def check_output_saving_nifti(self, network_task: str) -> None: raise ClinicaDLArgumentError( "Cannot save nifti if the network task is not reconstruction. Please remove --save_nifti option." 
) - - def adapt_config_with_maps_manager_info(self, maps_manager: MapsManager): - if not self.split_list: - self.split_list = maps_manager._find_splits() - logger.debug(f"List of splits {self.split_list}") - - if self.diagnoses is None or len(self.diagnoses) == 0: - self.diagnoses = maps_manager.diagnoses - - if not self.batch_size: - self.batch_size = maps_manager.batch_size - - if not self.n_proc: - self.n_proc = maps_manager.n_proc diff --git a/clinicadl/config/config/pipelines/__init__.py b/clinicadl/config/config/pipelines/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/clinicadl/config/config/generate.py b/clinicadl/config/config/pipelines/generate.py similarity index 100% rename from clinicadl/config/config/generate.py rename to clinicadl/config/config/pipelines/generate.py diff --git a/clinicadl/config/config/interpret.py b/clinicadl/config/config/pipelines/interpret.py similarity index 63% rename from clinicadl/config/config/interpret.py rename to clinicadl/config/config/pipelines/interpret.py index 53db7da35..11f7afcd0 100644 --- a/clinicadl/config/config/interpret.py +++ b/clinicadl/config/config/pipelines/interpret.py @@ -5,18 +5,28 @@ from pydantic import BaseModel, field_validator +from clinicadl.config.config import ( + ComputationalConfig, + CrossValidationConfig, + DataLoaderConfig, + MapsManagerConfig, + ValidationConfig, +) +from clinicadl.config.config import DataConfig as DataBaseConfig from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp from clinicadl.utils.caps_dataset.data import ( load_data_test, ) from clinicadl.utils.enum import InterpretationMethod -from clinicadl.utils.exceptions import ClinicaDLArgumentError # type: ignore -from clinicadl.utils.maps_manager.maps_manager import MapsManager # type: ignore -logger = getLogger("clinicadl.predict_config") +logger = getLogger("clinicadl.interpret_config") + + +class DataConfig(DataBaseConfig): + caps_directory: Optional[Path] = None -class 
InterpretConfig(BaseModel): +class InterpretBaseConfig(BaseModel): name: str method: InterpretationMethod = InterpretationMethod.GRADIENTS target_node: int = 0 @@ -38,3 +48,15 @@ def get_method(self) -> Gradients: return GradCam else: raise ValueError(f"The method {self.method.value} is not implemented") + + +class InterpretConfig( + MapsManagerConfig, + InterpretBaseConfig, + DataConfig, + ValidationConfig, + CrossValidationConfig, + ComputationalConfig, + DataLoaderConfig, +): + """Config class to perform interpretation.""" diff --git a/clinicadl/config/config/predict.py b/clinicadl/config/config/pipelines/predict.py similarity index 50% rename from clinicadl/config/config/predict.py rename to clinicadl/config/config/pipelines/predict.py index ebe3cff87..a09931cd7 100644 --- a/clinicadl/config/config/predict.py +++ b/clinicadl/config/config/pipelines/predict.py @@ -2,15 +2,23 @@ from pydantic import BaseModel +from clinicadl.config.config.data import DataConfig as DataBaseConfig +from clinicadl.config.config.maps_manager import ( + MapsManagerConfig as MapsManagerBaseConfig, +) from clinicadl.utils.exceptions import ClinicaDLArgumentError  # type: ignore +from ..computational import ComputationalConfig +from ..cross_validation import CrossValidationConfig +from ..dataloader import DataLoaderConfig +from ..validation import ValidationConfig + logger = getLogger("clinicadl.predict_config") -class PredictConfig(BaseModel): +class MapsManagerConfig(MapsManagerBaseConfig): save_tensor: bool = False save_latent_tensor: bool = False - use_labels: bool = True def check_output_saving_tensor(self, network_task: str) -> None: # Check if task is reconstruction for "save_tensor" and "save_nifti" @@ -18,3 +26,18 @@ def check_output_saving_tensor(self, network_task: str) -> None: raise ClinicaDLArgumentError( "Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option." 
) + + +class DataConfig(DataBaseConfig): + use_labels: bool = True + + +class PredictConfig( + MapsManagerConfig, + DataConfig, + ValidationConfig, + CrossValidationConfig, + ComputationalConfig, + DataLoaderConfig, +): + """Config class to perform prediction.""" diff --git a/clinicadl/config/config/pipelines/task/__init__.py b/clinicadl/config/config/pipelines/task/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/clinicadl/train/tasks/classification/config.py b/clinicadl/config/config/pipelines/task/classification.py similarity index 95% rename from clinicadl/train/tasks/classification/config.py rename to clinicadl/config/config/pipelines/task/classification.py index b30a4e266..9b89cf40c 100644 --- a/clinicadl/train/tasks/classification/config.py +++ b/clinicadl/config/config/pipelines/task/classification.py @@ -6,7 +6,7 @@ from clinicadl.config.config import DataConfig as BaseDataConfig from clinicadl.config.config import ModelConfig as BaseModelConfig from clinicadl.config.config import ValidationConfig as BaseValidationConfig -from clinicadl.train.trainer.training_config import TrainingConfig +from clinicadl.config.config.pipelines.train import TrainConfig from clinicadl.utils.enum import ClassificationLoss, ClassificationMetric, Task logger = getLogger("clinicadl.classification_config") @@ -57,7 +57,7 @@ def list_to_tuples(cls, v): return v -class ClassificationConfig(TrainingConfig): +class ClassificationConfig(TrainConfig): """ Config class for the training of a classification model. 
diff --git a/clinicadl/train/tasks/reconstruction/config.py b/clinicadl/config/config/pipelines/task/reconstruction.py similarity index 94% rename from clinicadl/train/tasks/reconstruction/config.py rename to clinicadl/config/config/pipelines/task/reconstruction.py index 2492a6c49..b1f63b030 100644 --- a/clinicadl/train/tasks/reconstruction/config.py +++ b/clinicadl/config/config/pipelines/task/reconstruction.py @@ -6,7 +6,7 @@ from clinicadl.config.config import ModelConfig as BaseModelConfig from clinicadl.config.config import ValidationConfig as BaseValidationConfig -from clinicadl.train.trainer.training_config import TrainingConfig +from clinicadl.config.config.pipelines.train import TrainConfig from clinicadl.utils.enum import ( Normalization, ReconstructionLoss, @@ -47,7 +47,7 @@ def list_to_tuples(cls, v): return v -class ReconstructionConfig(TrainingConfig): +class ReconstructionConfig(TrainConfig): """ Config class for the training of a reconstruction model. diff --git a/clinicadl/train/tasks/regression/config.py b/clinicadl/config/config/pipelines/task/regression.py similarity index 94% rename from clinicadl/train/tasks/regression/config.py rename to clinicadl/config/config/pipelines/task/regression.py index 39cb59f03..1b6b49018 100644 --- a/clinicadl/train/tasks/regression/config.py +++ b/clinicadl/config/config/pipelines/task/regression.py @@ -7,7 +7,7 @@ from clinicadl.config.config import DataConfig as BaseDataConfig from clinicadl.config.config import ModelConfig as BaseModelConfig from clinicadl.config.config import ValidationConfig as BaseValidationConfig -from clinicadl.train.trainer.training_config import TrainingConfig +from clinicadl.config.config.pipelines.train import TrainConfig from clinicadl.utils.enum import RegressionLoss, RegressionMetric, Task logger = getLogger("clinicadl.reconstruction_config") @@ -47,7 +47,7 @@ def list_to_tuples(cls, v): return v -class RegressionConfig(TrainingConfig): +class RegressionConfig(TrainConfig): """ 
Config class for the training of a regression model. diff --git a/clinicadl/train/trainer/training_config.py b/clinicadl/config/config/pipelines/train.py similarity index 98% rename from clinicadl/train/trainer/training_config.py rename to clinicadl/config/config/pipelines/train.py index 5ae2a3ec6..0cd23773b 100644 --- a/clinicadl/train/trainer/training_config.py +++ b/clinicadl/config/config/pipelines/train.py @@ -30,7 +30,7 @@ logger = getLogger("clinicadl.training_config") -class TrainingConfig(BaseModel, ABC): +class TrainConfig(BaseModel, ABC): """ Abstract config class for the training pipeline. diff --git a/clinicadl/config/options/__init__.py b/clinicadl/config/options/__init__.py index 610ac7a99..d7d50b584 100644 --- a/clinicadl/config/options/__init__.py +++ b/clinicadl/config/options/__init__.py @@ -1 +1 @@ -from .task import classification, reconstruction, regression +# from .task import classification, reconstruction, regression diff --git a/clinicadl/config/options/callbacks.py b/clinicadl/config/options/callbacks.py index 2e40c0d0b..86f84c6b8 100644 --- a/clinicadl/config/options/callbacks.py +++ b/clinicadl/config/options/callbacks.py @@ -1,21 +1,20 @@ import click -import clinicadl.train.trainer.training_config as config -from clinicadl.config import config +from clinicadl.config.config.callbacks import CallbacksConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type emissions_calculator = click.option( "--calculate_emissions/--dont_calculate_emissions", - default=get_default("emissions_calculator", config.CallbacksConfig), + default=get_default("emissions_calculator", CallbacksConfig), help="Flag to allow calculate the carbon emissions during training.", show_default=True, ) track_exp = click.option( "--track_exp", "-te", - type=get_type("track_exp", config.CallbacksConfig), - default=get_default("track_exp", 
config.CallbacksConfig), + type=get_type("track_exp", CallbacksConfig), + default=get_default("track_exp", CallbacksConfig), help="Use `--track_exp` to enable wandb/mlflow to track the metric (loss, accuracy, etc...) during the training.", show_default=True, ) diff --git a/clinicadl/config/options/computational.py b/clinicadl/config/options/computational.py index 17a0b9ac8..ba8c2adc6 100644 --- a/clinicadl/config/options/computational.py +++ b/clinicadl/config/options/computational.py @@ -1,13 +1,13 @@ import click -from clinicadl.config import config +from clinicadl.config.config.computational import ComputationalConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type # Computational amp = click.option( "--amp/--no-amp", - default=get_default("amp", config.ComputationalConfig), + default=get_default("amp", ComputationalConfig), help="Enables automatic mixed precision during training and inference.", show_default=True, ) @@ -21,7 +21,7 @@ ) gpu = click.option( "--gpu/--no-gpu", - default=get_default("gpu", config.ComputationalConfig), + default=get_default("gpu", ComputationalConfig), help="Use GPU by default. 
Please specify `--no-gpu` to force using CPU.", show_default=True, ) diff --git a/clinicadl/config/options/cross_validation.py b/clinicadl/config/options/cross_validation.py index b15b470fc..ec762a25f 100644 --- a/clinicadl/config/options/cross_validation.py +++ b/clinicadl/config/options/cross_validation.py @@ -1,15 +1,14 @@ import click -import clinicadl.train.trainer.training_config as config -from clinicadl.config import config +from clinicadl.config.config.cross_validation import CrossValidationConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type # Cross Validation n_splits = click.option( "--n_splits", - type=get_type("n_splits", config.CrossValidationConfig), - default=get_default("n_splits", config.CrossValidationConfig), + type=get_type("n_splits", CrossValidationConfig), + default=get_default("n_splits", CrossValidationConfig), help="If a value is given for k will load data of a k-fold CV. " "Default value (0) will load a single split.", show_default=True, @@ -18,7 +17,7 @@ "--split", "-s", type=int, # get_type("split", config.CrossValidationConfig), - default=get_default("split", config.CrossValidationConfig), + default=get_default("split", CrossValidationConfig), multiple=True, help="Train the list of given splits. 
By default, all the splits are trained.", show_default=True, diff --git a/clinicadl/config/options/data.py b/clinicadl/config/options/data.py index 783eb5032..ae9d95fb9 100644 --- a/clinicadl/config/options/data.py +++ b/clinicadl/config/options/data.py @@ -1,49 +1,56 @@ import click -import clinicadl.train.trainer.training_config as config -from clinicadl.config import config +from clinicadl.config.config.data import DataConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type # Data baseline = click.option( "--baseline/--longitudinal", - default=get_default("baseline", config.DataConfig), + default=get_default("baseline", DataConfig), help="If provided, only the baseline sessions are used for training.", show_default=True, ) diagnoses = click.option( "--diagnoses", "-d", - type=get_type("diagnoses", config.DataConfig), - default=get_default("diagnoses", config.DataConfig), + type=get_type("diagnoses", DataConfig), + default=get_default("diagnoses", DataConfig), multiple=True, help="List of diagnoses used for training.", show_default=True, ) multi_cohort = click.option( "--multi_cohort/--single_cohort", - default=get_default("multi_cohort", config.DataConfig), + default=get_default("multi_cohort", DataConfig), help="Performs multi-cohort training. 
In this case, caps_dir and tsv_path must be paths to TSV files.", show_default=True, ) participants_tsv = click.option( "--participants_tsv", - type=get_type("data_tsv", config.DataConfig), - default=get_default("data_tsv", config.DataConfig), + type=get_type("data_tsv", DataConfig), + default=get_default("data_tsv", DataConfig), help="Path to a TSV file including a list of participants/sessions.", show_default=True, ) n_subjects = click.option( "--n_subjects", - type=get_type("n_subjects", config.DataConfig), - default=get_default("n_subjects", config.DataConfig), + type=get_type("n_subjects", DataConfig), + default=get_default("n_subjects", DataConfig), help="Number of subjects in each class of the synthetic dataset.", ) caps_directory = click.option( "--caps_directory", - type=get_type("caps_directory", config.DataConfig), - default=get_default("caps_directory", config.DataConfig), + type=get_type("caps_directory", DataConfig), + default=get_default("caps_directory", DataConfig), help="Data using CAPS structure, if different from the one used during network training.", show_default=True, ) +label = click.option( + "--label", + type=get_type("label", DataConfig), + default=get_default("label", DataConfig), + show_default=True, + help="Target label used for training (if NETWORK_TASK in [`regression`, `classification`]). 
" + "Default will reuse the same label as during the training task.", +) diff --git a/clinicadl/config/options/dataloader.py b/clinicadl/config/options/dataloader.py index f1a596892..7392f6c0d 100644 --- a/clinicadl/config/options/dataloader.py +++ b/clinicadl/config/options/dataloader.py @@ -1,31 +1,30 @@ import click -import clinicadl.train.trainer.training_config as config -from clinicadl.config import config +from clinicadl.config.config.dataloader import DataLoaderConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type # DataLoader batch_size = click.option( "--batch_size", - type=get_type("batch_size", config.DataLoaderConfig), - default=get_default("batch_size", config.DataLoaderConfig), + type=get_type("batch_size", DataLoaderConfig), + default=get_default("batch_size", DataLoaderConfig), help="Batch size for data loading.", show_default=True, ) n_proc = click.option( "-np", "--n_proc", - type=get_type("n_proc", config.DataLoaderConfig), - default=get_default("n_proc", config.DataLoaderConfig), + type=get_type("n_proc", DataLoaderConfig), + default=get_default("n_proc", DataLoaderConfig), help="Number of cores used during the task.", show_default=True, ) sampler = click.option( "--sampler", "-s", - type=get_type("sampler", config.DataLoaderConfig), - default=get_default("sampler", config.DataLoaderConfig), + type=get_type("sampler", DataLoaderConfig), + default=get_default("sampler", DataLoaderConfig), help="Sampler used to load the training data set.", show_default=True, ) diff --git a/clinicadl/config/options/early_stopping.py b/clinicadl/config/options/early_stopping.py index e85cb7ba5..f4e7771e3 100644 --- a/clinicadl/config/options/early_stopping.py +++ b/clinicadl/config/options/early_stopping.py @@ -1,22 +1,21 @@ import click -import clinicadl.train.trainer.training_config as config -from clinicadl.config import config +from 
clinicadl.config.config.early_stopping import EarlyStoppingConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type # Early Stopping patience = click.option( "--patience", - type=get_type("patience", config.EarlyStoppingConfig), - default=get_default("patience", config.EarlyStoppingConfig), + type=get_type("patience", EarlyStoppingConfig), + default=get_default("patience", EarlyStoppingConfig), help="Number of epochs for early stopping patience.", show_default=True, ) tolerance = click.option( "--tolerance", - type=get_type("tolerance", config.EarlyStoppingConfig), - default=get_default("tolerance", config.EarlyStoppingConfig), + type=get_type("tolerance", EarlyStoppingConfig), + default=get_default("tolerance", EarlyStoppingConfig), help="Value for early stopping tolerance.", show_default=True, ) diff --git a/clinicadl/config/options/interpret.py b/clinicadl/config/options/interpret.py index 4ad7fcc88..1c34035ab 100644 --- a/clinicadl/config/options/interpret.py +++ b/clinicadl/config/options/interpret.py @@ -1,30 +1,29 @@ import click -import clinicadl.train.trainer.training_config as config -from clinicadl.config import config +from clinicadl.config.config.pipelines.interpret import InterpretConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type # interpret specific name = click.argument( "name", - type=get_type("name", config.InterpretConfig), + type=get_type("name", InterpretConfig), ) method = click.argument( "method", - type=get_type("method", config.InterpretConfig), # ["gradients", "grad-cam"] + type=get_type("method", InterpretConfig), # ["gradients", "grad-cam"] ) level = click.option( "--level_grad_cam", - type=get_type("level", config.InterpretConfig), - default=get_default("level", config.InterpretConfig), + 
type=get_type("level", InterpretConfig), + default=get_default("level", InterpretConfig), help="level of the feature map (after the layer corresponding to the number) chosen for grad-cam.", show_default=True, ) target_node = click.option( "--target_node", - type=get_type("target_node", config.InterpretConfig), - default=get_default("target_node", config.InterpretConfig), + type=get_type("target_node", InterpretConfig), + default=get_default("target_node", InterpretConfig), help="Which target node the gradients explain. Default takes the first output node.", show_default=True, ) diff --git a/clinicadl/config/options/maps_manager.py b/clinicadl/config/options/maps_manager.py index 7a2bd2100..6bc5fb198 100644 --- a/clinicadl/config/options/maps_manager.py +++ b/clinicadl/config/options/maps_manager.py @@ -1,14 +1,10 @@ import click -from clinicadl.config import config +from clinicadl.config.config.maps_manager import MapsManagerConfig from clinicadl.utils.config_utils import get_type_from_config_class as get_type -maps_dir = click.argument( - "maps_dir", type=get_type("maps_dir", config.MapsManagerConfig) -) -data_group = click.option( - "data_group", type=get_type("data_group", config.MapsManagerConfig) -) +maps_dir = click.argument("maps_dir", type=get_type("maps_dir", MapsManagerConfig)) +data_group = click.option("data_group", type=get_type("data_group", MapsManagerConfig)) overwrite = click.option( diff --git a/clinicadl/config/options/modality.py b/clinicadl/config/options/modality.py index e755acf85..fc2118db6 100644 --- a/clinicadl/config/options/modality.py +++ b/clinicadl/config/options/modality.py @@ -1,14 +1,17 @@ import click -import clinicadl.train.trainer.training_config as config -from clinicadl.config import config +from clinicadl.config.config.modality import ( + CustomModalityConfig, + DTIModalityConfig, + PETModalityConfig, +) from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils 
import get_type_from_config_class as get_type tracer = click.option( "--tracer", - default=get_default("tracer", config.PETModalityConfig), - type=get_type("tracer", config.PETModalityConfig), + default=get_default("tracer", PETModalityConfig), + type=get_type("tracer", PETModalityConfig), help=( "Acquisition label if MODALITY is `pet-linear`. " "Name of the tracer used for the PET acquisition (trc-). " @@ -18,8 +21,8 @@ suvr_reference_region = click.option( "-suvr", "--suvr_reference_region", - default=get_default("suvr_reference_region", config.PETModalityConfig), - type=get_type("suvr_reference_region", config.PETModalityConfig), + default=get_default("suvr_reference_region", PETModalityConfig), + type=get_type("suvr_reference_region", PETModalityConfig), help=( "Regions used for normalization if MODALITY is `pet-linear`. " "Intensity normalization using the average PET uptake in reference regions resulting in a standardized uptake " @@ -30,8 +33,8 @@ custom_suffix = click.option( "-cn", "--custom_suffix", - default=get_default("custom_suffix", config.CustomModalityConfig), - type=get_type("custom_suffix", config.CustomModalityConfig), + default=get_default("custom_suffix", CustomModalityConfig), + type=get_type("custom_suffix", CustomModalityConfig), help=( "Suffix of output files if MODALITY is `custom`. 
" "Suffix to append to filenames, for instance " @@ -42,14 +45,14 @@ dti_measure = click.option( "--dti_measure", "-dm", - type=get_type("dti_measure", config.DTIModalityConfig), + type=get_type("dti_measure", DTIModalityConfig), help="Possible DTI measures.", - default=get_default("dti_measure", config.DTIModalityConfig), + default=get_default("dti_measure", DTIModalityConfig), ) dti_space = click.option( "--dti_space", "-ds", - type=get_type("dti_space", config.DTIModalityConfig), + type=get_type("dti_space", DTIModalityConfig), help="Possible DTI space.", - default=get_default("dti_space", config.DTIModalityConfig), + default=get_default("dti_space", DTIModalityConfig), ) diff --git a/clinicadl/config/options/model.py b/clinicadl/config/options/model.py index ecb6271a3..0d5ab4083 100644 --- a/clinicadl/config/options/model.py +++ b/clinicadl/config/options/model.py @@ -1,21 +1,20 @@ import click -import clinicadl.train.trainer.training_config as config -from clinicadl.config import config +from clinicadl.config.config.model import ModelConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type # Model multi_network = click.option( "--multi_network/--single_network", - default=get_default("multi_network", config.ModelConfig), + default=get_default("multi_network", ModelConfig), help="If provided uses a multi-network framework.", show_default=True, ) dropout = click.option( "--dropout", - type=get_type("dropout", config.ModelConfig), - default=get_default("dropout", config.ModelConfig), + type=get_type("dropout", ModelConfig), + default=get_default("dropout", ModelConfig), help="Rate value applied to dropout layers in a CNN architecture.", show_default=True, ) diff --git a/clinicadl/config/options/optimization.py b/clinicadl/config/options/optimization.py index 80a4d61a8..82fec05ba 100644 --- a/clinicadl/config/options/optimization.py +++ 
b/clinicadl/config/options/optimization.py @@ -1,7 +1,6 @@ import click -import clinicadl.train.trainer.training_config as config -from clinicadl.config import config +from clinicadl.config.config.optimization import OptimizationConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type @@ -9,22 +8,22 @@ accumulation_steps = click.option( "--accumulation_steps", "-asteps", - type=get_type("accumulation_steps", config.OptimizationConfig), - default=get_default("accumulation_steps", config.OptimizationConfig), + type=get_type("accumulation_steps", OptimizationConfig), + default=get_default("accumulation_steps", OptimizationConfig), help="Accumulates gradients during the given number of iterations before performing the weight update " "in order to virtually increase the size of the batch.", show_default=True, ) epochs = click.option( "--epochs", - type=get_type("epochs", config.OptimizationConfig), - default=get_default("epochs", config.OptimizationConfig), + type=get_type("epochs", OptimizationConfig), + default=get_default("epochs", OptimizationConfig), help="Maximum number of epochs.", show_default=True, ) profiler = click.option( "--profiler/--no-profiler", - default=get_default("profiler", config.OptimizationConfig), + default=get_default("profiler", OptimizationConfig), help="Use `--profiler` to enable Pytorch profiler for the first 30 steps after a short warmup. 
" "It will make an execution trace and some statistics about the CPU and GPU usage.", show_default=True, diff --git a/clinicadl/config/options/optimizer.py b/clinicadl/config/options/optimizer.py index fde8f2762..6143ac0de 100644 --- a/clinicadl/config/options/optimizer.py +++ b/clinicadl/config/options/optimizer.py @@ -1,7 +1,6 @@ import click -import clinicadl.train.trainer.training_config as config -from clinicadl.config import config +from clinicadl.config.config.optimizer import OptimizerConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type @@ -9,23 +8,23 @@ learning_rate = click.option( "--learning_rate", "-lr", - type=get_type("learning_rate", config.OptimizerConfig), - default=get_default("learning_rate", config.OptimizerConfig), + type=get_type("learning_rate", OptimizerConfig), + default=get_default("learning_rate", OptimizerConfig), help="Learning rate of the optimization.", show_default=True, ) optimizer = click.option( "--optimizer", - type=get_type("optimizer", config.OptimizerConfig), - default=get_default("optimizer", config.OptimizerConfig), + type=get_type("optimizer", OptimizerConfig), + default=get_default("optimizer", OptimizerConfig), help="Optimizer used to train the network.", show_default=True, ) weight_decay = click.option( "--weight_decay", "-wd", - type=get_type("weight_decay", config.OptimizerConfig), - default=get_default("weight_decay", config.OptimizerConfig), + type=get_type("weight_decay", OptimizerConfig), + default=get_default("weight_decay", OptimizerConfig), help="Weight decay value used in optimization.", show_default=True, ) diff --git a/clinicadl/config/options/predict.py b/clinicadl/config/options/predict.py index 6c6ba04f3..15ab7a6d8 100644 --- a/clinicadl/config/options/predict.py +++ b/clinicadl/config/options/predict.py @@ -1,7 +1,6 @@ import click -import clinicadl.train.trainer.training_config as 
config -from clinicadl.config import config +from clinicadl.config.config.pipelines.predict import PredictConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type @@ -9,17 +8,9 @@ use_labels = click.option( "--use_labels/--no_labels", show_default=True, - default=get_default("use_labels", config.PredictConfig), + default=get_default("use_labels", PredictConfig), help="Set this option to --no_labels if your dataset does not contain ground truth labels.", ) -label = click.option( - "--label", - type=get_type("label", config.PredictConfig), - default=get_default("label", config.PredictConfig), - show_default=True, - help="Target label used for training (if NETWORK_TASK in [`regression`, `classification`]). " - "Default will reuse the same label as during the training task.", -) save_tensor = click.option( "--save_tensor", is_flag=True, diff --git a/clinicadl/config/options/preprocessing.py b/clinicadl/config/options/preprocessing.py index 0cbe3d99f..470f650f7 100644 --- a/clinicadl/config/options/preprocessing.py +++ b/clinicadl/config/options/preprocessing.py @@ -1,14 +1,14 @@ import click -from clinicadl.config import config +from clinicadl.config.config.preprocessing import PreprocessingConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type extract_json = click.option( "-ej", "--extract_json", - type=get_type("extract_json", config.PreprocessingConfig), - default=get_default("extract_json", config.PreprocessingConfig), + type=get_type("extract_json", PreprocessingConfig), + default=get_default("extract_json", PreprocessingConfig), help="Name of the JSON file created to describe the tensor extraction. 
" "Default will use format extract_{time_stamp}.json", ) @@ -31,8 +31,8 @@ preprocessing = click.option( "--preprocessing", - type=get_type("preprocessing", config.PreprocessingConfig), - default=get_default("preprocessing", config.PreprocessingConfig), + type=get_type("preprocessing", PreprocessingConfig), + default=get_default("preprocessing", PreprocessingConfig), required=True, help="Preprocessing used to generate synthetic data.", show_default=True, @@ -42,16 +42,16 @@ patch_size = click.option( "-ps", "--patch_size", - type=get_type("patch_size", config.PreprocessingPatchConfig), - default=get_default("patch_size", config.PreprocessingPatchConfig), + type=get_type("patch_size", PreprocessingPatchConfig), + default=get_default("patch_size", PreprocessingPatchConfig), show_default=True, help="Patch size.", ) stride_size = click.option( "-ss", "--stride_size", - type=get_type("stride_size", config.PreprocessingPatchConfig), - default=get_default("stride_size", config.PreprocessingPatchConfig), + type=get_type("stride_size", PreprocessingPatchConfig), + default=get_default("stride_size", PreprocessingPatchConfig), show_default=True, help="Stride size.", ) @@ -60,16 +60,16 @@ slice_direction = click.option( "-sd", "--slice_direction", - type=get_type("slice_direction", config.PreprocessingSliceConfig), - default=get_default("slice_direction", config.PreprocessingSliceConfig), + type=get_type("slice_direction", PreprocessingSliceConfig), + default=get_default("slice_direction", PreprocessingSliceConfig), show_default=True, help="Slice direction. 
0: Sagittal plane, 1: Coronal plane, 2: Axial plane.", ) slice_mode = click.option( "-sm", "--slice_mode", - type=get_type("slice_mode", config.PreprocessingSliceConfig), - default=get_default("slice_mode", config.PreprocessingSliceConfig), + type=get_type("slice_mode", PreprocessingSliceConfig), + default=get_default("slice_mode", PreprocessingSliceConfig), show_default=True, help=( "rgb: Save the slice in three identical channels, " @@ -79,8 +79,8 @@ discarded_slices = click.option( "-ds", "--discarded_slices", - type=get_type("discarded_slices", config.PreprocessingSliceConfig), - default=get_default("discarded_slices", config.PreprocessingSliceConfig), + type=get_type("discarded_slices", PreprocessingSliceConfig), + default=get_default("discarded_slices", PreprocessingSliceConfig), multiple=2, help="""Number of slices discarded from respectively the beginning and the end of the MRI volume. If only one argument is given, it will be @@ -90,16 +90,16 @@ roi_list = click.option( "--roi_list", - type=get_type("roi_list", config.PreprocessingROIConfig), - default=get_default("roi_list", config.PreprocessingROIConfig), + type=get_type("roi_list", PreprocessingROIConfig), + default=get_default("roi_list", PreprocessingROIConfig), required=True, multiple=True, help="List of regions to be extracted", ) roi_uncrop_output = click.option( "--roi_uncrop_output", - type=get_type("roi_uncrop_output", config.PreprocessingROIConfig), - default=get_default("roi_uncrop_output", config.PreprocessingROIConfig), + type=get_type("roi_uncrop_output", PreprocessingROIConfig), + default=get_default("roi_uncrop_output", PreprocessingROIConfig), is_flag=True, help="Disable cropping option so the output tensors " "have the same size than the whole image.", @@ -107,16 +107,16 @@ roi_custom_template = click.option( "--roi_custom_template", "-ct", - type=get_type("roi_custom_template", config.PreprocessingROIConfig), - default=get_default("roi_custom_template", config.PreprocessingROIConfig), 
+ type=get_type("roi_custom_template", PreprocessingROIConfig), + default=get_default("roi_custom_template", PreprocessingROIConfig), help="""Template name if MODALITY is `custom`. Name of the template used for registration during the preprocessing procedure.""", ) roi_custom_mask_pattern = click.option( "--roi_custom_mask_pattern", "-cmp", - type=get_type("roi_custom_mask_pattern", config.PreprocessingROIConfig), - default=get_default("roi_custom_mask_pattern", config.PreprocessingROIConfig), + type=get_type("roi_custom_mask_pattern", PreprocessingROIConfig), + default=get_default("roi_custom_mask_pattern", PreprocessingROIConfig), help="""Mask pattern if MODALITY is `custom`. If given will select only the masks containing the string given. The mask with the shortest name is taken.""", diff --git a/clinicadl/config/options/reproducibility.py b/clinicadl/config/options/reproducibility.py index 02c7bd597..2e897607f 100644 --- a/clinicadl/config/options/reproducibility.py +++ b/clinicadl/config/options/reproducibility.py @@ -1,35 +1,35 @@ import click -from clinicadl.config import config +from clinicadl.config.config.reproducibility import ReproducibilityConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type # Reproducibility compensation = click.option( "--compensation", - type=get_type("compensation", config.ReproducibilityConfig), - default=get_default("compensation", config.ReproducibilityConfig), + type=get_type("compensation", ReproducibilityConfig), + default=get_default("compensation", ReproducibilityConfig), help="Allow the user to choose how CUDA will compensate the deterministic behaviour.", show_default=True, ) deterministic = click.option( "--deterministic/--nondeterministic", - default=get_default("deterministic", config.ReproducibilityConfig), + default=get_default("deterministic", ReproducibilityConfig), help="Forces Pytorch to be 
deterministic even when using a GPU. " "Will raise a RuntimeError if a non-deterministic function is encountered.", show_default=True, ) save_all_models = click.option( "--save_all_models/--save_only_best_model", - type=get_type("save_all_models", config.ReproducibilityConfig), - default=get_default("save_all_models", config.ReproducibilityConfig), + type=get_type("save_all_models", ReproducibilityConfig), + default=get_default("save_all_models", ReproducibilityConfig), help="If provided, enables the saving of models weights for each epochs.", show_default=True, ) seed = click.option( "--seed", - type=get_type("seed", config.ReproducibilityConfig), - default=get_default("seed", config.ReproducibilityConfig), + type=get_type("seed", ReproducibilityConfig), + default=get_default("seed", ReproducibilityConfig), help="Value to set the seed for all random operations." "Default will sample a random value for the seed.", show_default=True, @@ -37,7 +37,7 @@ config_file = click.option( "--config_file", "-c", - type=get_type("config_file", config.ReproducibilityConfig), - default=get_default("config_file", config.ReproducibilityConfig), + type=get_type("config_file", ReproducibilityConfig), + default=get_default("config_file", ReproducibilityConfig), help="Path to the TOML or JSON file containing the values of the options needed for training.", ) diff --git a/clinicadl/config/options/ssda.py b/clinicadl/config/options/ssda.py index 5f3db0953..8d5865311 100644 --- a/clinicadl/config/options/ssda.py +++ b/clinicadl/config/options/ssda.py @@ -1,7 +1,6 @@ import click -import clinicadl.train.trainer.training_config as config -from clinicadl.config import config +from clinicadl.config.config.ssda import SSDAConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type @@ -9,38 +8,38 @@ caps_target = click.option( "--caps_target", "-d", - type=get_type("caps_target", 
config.SSDAConfig), - default=get_default("caps_target", config.SSDAConfig), + type=get_type("caps_target", SSDAConfig), + default=get_default("caps_target", SSDAConfig), help="CAPS of target data.", show_default=True, ) preprocessing_json_target = click.option( "--preprocessing_json_target", "-d", - type=get_type("preprocessing_json_target", config.SSDAConfig), - default=get_default("preprocessing_json_target", config.SSDAConfig), + type=get_type("preprocessing_json_target", SSDAConfig), + default=get_default("preprocessing_json_target", SSDAConfig), help="Path to json target.", show_default=True, ) ssda_network = click.option( "--ssda_network/--single_network", - default=get_default("ssda_network", config.SSDAConfig), + default=get_default("ssda_network", SSDAConfig), help="If provided uses a ssda-network framework.", show_default=True, ) tsv_target_lab = click.option( "--tsv_target_lab", "-d", - type=get_type("tsv_target_lab", config.SSDAConfig), - default=get_default("tsv_target_lab", config.SSDAConfig), + type=get_type("tsv_target_lab", SSDAConfig), + default=get_default("tsv_target_lab", SSDAConfig), help="TSV of labeled target data.", show_default=True, ) tsv_target_unlab = click.option( "--tsv_target_unlab", "-d", - type=get_type("tsv_target_unlab", config.SSDAConfig), - default=get_default("tsv_target_unlab", config.SSDAConfig), + type=get_type("tsv_target_unlab", SSDAConfig), + default=get_default("tsv_target_unlab", SSDAConfig), help="TSV of unllabeled target data.", show_default=True, ) diff --git a/clinicadl/config/options/task/__init__.py b/clinicadl/config/options/task/__init__.py index e69de29bb..c6653a77c 100644 --- a/clinicadl/config/options/task/__init__.py +++ b/clinicadl/config/options/task/__init__.py @@ -0,0 +1 @@ +from . 
import classification, reconstruction, regression diff --git a/clinicadl/config/options/task/classification.py b/clinicadl/config/options/task/classification.py index 8ee8289a3..f46e6d521 100644 --- a/clinicadl/config/options/task/classification.py +++ b/clinicadl/config/options/task/classification.py @@ -1,6 +1,6 @@ import click -from clinicadl.train.tasks.classification.config import ( +from clinicadl.config.config.pipelines.task.classification import ( DataConfig, ModelConfig, ValidationConfig, diff --git a/clinicadl/config/options/task/reconstruction.py b/clinicadl/config/options/task/reconstruction.py index 37146389d..7270c4fe9 100644 --- a/clinicadl/config/options/task/reconstruction.py +++ b/clinicadl/config/options/task/reconstruction.py @@ -1,6 +1,9 @@ import click -from clinicadl.train.tasks.reconstruction.config import ModelConfig, ValidationConfig +from clinicadl.config.config.pipelines.task.reconstruction import ( + ModelConfig, + ValidationConfig, +) from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type diff --git a/clinicadl/config/options/task/regression.py b/clinicadl/config/options/task/regression.py index 147fcfafd..af8a14c50 100644 --- a/clinicadl/config/options/task/regression.py +++ b/clinicadl/config/options/task/regression.py @@ -1,6 +1,6 @@ import click -from clinicadl.train.tasks.regression.config import ( +from clinicadl.config.config.pipelines.task.regression import ( DataConfig, ModelConfig, ValidationConfig, diff --git a/clinicadl/config/options/transfer_learning.py b/clinicadl/config/options/transfer_learning.py index 88a4c9de7..89817838e 100644 --- a/clinicadl/config/options/transfer_learning.py +++ b/clinicadl/config/options/transfer_learning.py @@ -1,30 +1,30 @@ import click -from clinicadl.config import config +from clinicadl.config.config.transfer_learning import TransferLearningConfig from clinicadl.utils.config_utils 
import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type nb_unfrozen_layer = click.option( "-nul", "--nb_unfrozen_layer", - type=get_type("nb_unfrozen_layer", config.TransferLearningConfig), - default=get_default("nb_unfrozen_layer", config.TransferLearningConfig), + type=get_type("nb_unfrozen_layer", TransferLearningConfig), + default=get_default("nb_unfrozen_layer", TransferLearningConfig), help="Number of layer that will be retrain during training. For example, if it is 2, the last two layers of the model will not be freezed.", show_default=True, ) transfer_path = click.option( "-tp", "--transfer_path", - type=get_type("transfer_path", config.TransferLearningConfig), - default=get_default("transfer_path", config.TransferLearningConfig), + type=get_type("transfer_path", TransferLearningConfig), + default=get_default("transfer_path", TransferLearningConfig), help="Path of to a MAPS used for transfer learning.", show_default=True, ) transfer_selection_metric = click.option( "-tsm", "--transfer_selection_metric", - type=get_type("transfer_selection_metric", config.TransferLearningConfig), - default=get_default("transfer_selection_metric", config.TransferLearningConfig), + type=get_type("transfer_selection_metric", TransferLearningConfig), + default=get_default("transfer_selection_metric", TransferLearningConfig), help="Metric used to select the model for transfer learning in the MAPS defined by transfer_path.", show_default=True, ) diff --git a/clinicadl/config/options/transforms.py b/clinicadl/config/options/transforms.py index 15793f4eb..8e886d7e0 100644 --- a/clinicadl/config/options/transforms.py +++ b/clinicadl/config/options/transforms.py @@ -1,7 +1,6 @@ import click -import clinicadl.train.trainer.training_config as config -from clinicadl.config import config +from clinicadl.config.config.transforms import TransformsConfig from clinicadl.utils.config_utils import 
get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type @@ -9,15 +8,15 @@ data_augmentation = click.option( "--data_augmentation", "-da", - type=get_type("data_augmentation", config.TransformsConfig), - default=get_default("data_augmentation", config.TransformsConfig), + type=get_type("data_augmentation", TransformsConfig), + default=get_default("data_augmentation", TransformsConfig), multiple=True, help="Randomly applies transforms on the training set.", show_default=True, ) normalize = click.option( "--normalize/--unnormalize", - default=get_default("normalize", config.TransformsConfig), + default=get_default("normalize", TransformsConfig), help="Disable default MinMaxNormalization.", show_default=True, ) diff --git a/clinicadl/config/options/validation.py b/clinicadl/config/options/validation.py index 235eca52a..0fc64abd8 100644 --- a/clinicadl/config/options/validation.py +++ b/clinicadl/config/options/validation.py @@ -1,22 +1,21 @@ import click -import clinicadl.train.trainer.training_config as config -from clinicadl.config import config +from clinicadl.config.config.validation import ValidationConfig from clinicadl.utils.config_utils import get_default_from_config_class as get_default from clinicadl.utils.config_utils import get_type_from_config_class as get_type # Validation valid_longitudinal = click.option( "--valid_longitudinal/--valid_baseline", - default=get_default("valid_longitudinal", config.ValidationConfig), + default=get_default("valid_longitudinal", ValidationConfig), help="If provided, not only the baseline sessions are used for validation (careful with this bad habit).", show_default=True, ) evaluation_steps = click.option( "--evaluation_steps", "-esteps", - type=get_type("evaluation_steps", config.ValidationConfig), - default=get_default("evaluation_steps", config.ValidationConfig), + type=get_type("evaluation_steps", ValidationConfig), + 
default=get_default("evaluation_steps", ValidationConfig), help="Fix the number of iterations to perform before computing an evaluation. Default will only " "perform one evaluation at the end of each epoch.", show_default=True, @@ -25,8 +24,8 @@ selection_metrics = click.option( "--selection_metrics", "-sm", - type=get_type("selection_metrics", config.ValidationConfig), # str list ? - default=get_default("selection_metrics", config.ValidationConfig), # ["loss"] + type=get_type("selection_metrics", ValidationConfig), # str list ? + default=get_default("selection_metrics", ValidationConfig), # ["loss"] multiple=True, help="""Allow to select a list of models based on their selection metric. Default will only infer the result of the best model selected on loss.""", diff --git a/clinicadl/interpret/interpret_cli.py b/clinicadl/interpret/interpret_cli.py index 6f137224f..fce61533a 100644 --- a/clinicadl/interpret/interpret_cli.py +++ b/clinicadl/interpret/interpret_cli.py @@ -1,63 +1,47 @@ import click -from clinicadl.interpret import interpret_param -from clinicadl.predict.predict_config import InterpretConfig +from clinicadl.config import arguments +from clinicadl.config.config.pipelines.interpret import InterpretConfig +from clinicadl.config.options import ( + computational, + data, + dataloader, + interpret, + maps_manager, + validation, +) from clinicadl.predict.predict_manager import PredictManager -from clinicadl.utils.exceptions import ClinicaDLArgumentError - -config = InterpretConfig.model_fields @click.command("interpret", no_args_is_help=True) -@interpret_param.input_maps -@interpret_param.data_group -@interpret_param.name -@interpret_param.method -@interpret_param.level -@interpret_param.selection_metrics -@interpret_param.participants_list -@interpret_param.caps_directory -@interpret_param.multi_cohort -@interpret_param.diagnoses -@interpret_param.target_node -@interpret_param.save_individual -@interpret_param.n_proc -@interpret_param.gpu 
-@interpret_param.amp -@interpret_param.batch_size -@interpret_param.overwrite -@interpret_param.overwrite_name -@interpret_param.save_nifti -def cli(input_maps_directory, data_group, name, method, **kwargs): +@arguments.input_maps +@arguments.data_group +@maps_manager.overwrite +@maps_manager.save_nifti +@interpret.name +@interpret.method +@interpret.level +@interpret.target_node +@interpret.save_individual +@interpret.overwrite_name +@data.participants_tsv +@data.caps_directory +@data.multi_cohort +@data.diagnoses +@dataloader.n_proc +@dataloader.batch_size +@computational.gpu +@computational.amp +@validation.selection_metrics +def cli(**kwargs): """Interpretation of trained models using saliency map method. - INPUT_MAPS_DIRECTORY is the MAPS folder from where the model to interpret will be loaded. - DATA_GROUP is the name of the subjects and sessions list used for the interpretation. - NAME is the name of the saliency map task. - METHOD is the method used to extract an attribution map. """ - from clinicadl.utils.cmdline_utils import check_gpu - - if kwargs["gpu"]: - check_gpu() - elif kwargs["amp"]: - raise ClinicaDLArgumentError( - "AMP is designed to work with modern GPUs. Please add the --gpu flag." 
- ) - - interpret_config = InterpretConfig( - maps_dir=input_maps_directory, - data_group=data_group, - name=name, - method_cls=method, - tsv_path=kwargs["participants_tsv"], - level=kwargs["level_grad_cam"], - **kwargs, - ) + interpret_config = InterpretConfig(**kwargs) predict_manager = PredictManager(interpret_config) predict_manager.interpret() diff --git a/clinicadl/interpret/interpret_param.py b/clinicadl/interpret/interpret_param.py deleted file mode 100644 index 8d932c03b..000000000 --- a/clinicadl/interpret/interpret_param.py +++ /dev/null @@ -1,125 +0,0 @@ -from pathlib import Path -from typing import get_args - -import click - -from clinicadl.predict.predict_config import InterpretationMethod, InterpretConfig - -config = InterpretConfig.model_fields - -input_maps = click.argument("input_maps_directory", type=config["maps_dir"].annotation) -data_group = click.argument("data_group", type=config["data_group"].annotation) -selection_metrics = click.option( - "--selection_metrics", - "-sm", - type=get_args(config["selection_metrics"].annotation)[0], # str list ? - default=config["selection_metrics"].default, # ["loss"] - multiple=True, - help="""Allow to select a list of models based on their selection metric. Default will - only infer the result of the best model selected on loss.""", - show_default=True, -) -participants_list = click.option( - "--participants_tsv", - type=get_args(config["tsv_path"].annotation)[0], # Path - default=config["tsv_path"].default, # None - help="""Path to the file with subjects/sessions to process, if different from the one used during network training. - If it includes the filename will load the TSV file directly. 
- Else will load the baseline TSV files of wanted diagnoses produced by `tsvtool split`.""", - show_default=True, -) -caps_directory = click.option( - "--caps_directory", - type=get_args(config["caps_directory"].annotation)[0], # Path - default=config["caps_directory"].default, # None - help="Data using CAPS structure, if different from the one used during network training.", - show_default=True, -) -multi_cohort = click.option( - "--multi_cohort", - is_flag=True, - help="Performs multi-cohort interpretation. In this case, caps_directory and tsv_path must be paths to TSV files.", -) -diagnoses = click.option( - "--diagnoses", - "-d", - type=get_args(config["diagnoses"].annotation)[0], # str list ? - default=config["diagnoses"].default, # ?? - multiple=True, - help="List of diagnoses used for inference. Is used only if PARTICIPANTS_TSV leads to a folder.", - show_default=True, -) -n_proc = click.option( - "-np", - "--n_proc", - type=config["n_proc"].annotation, - default=config["n_proc"].default, - show_default=True, - help="Number of cores used during the task.", -) -gpu = click.option( - "--gpu/--no-gpu", - show_default=True, - default=config["gpu"].default, - help="Use GPU by default. Please specify `--no-gpu` to force using CPU.", -) -batch_size = click.option( - "--batch_size", - type=config["batch_size"].annotation, # int - default=config["batch_size"].default, # 8 - show_default=True, - help="Batch size for data loading.", -) -amp = click.option( - "--amp/--no-amp", - default=config["amp"].default, # false - help="Enables automatic mixed precision during training and inference.", - show_default=True, -) -overwrite = click.option( - "--overwrite", - "-o", - is_flag=True, - help="Will overwrite data group if existing. 
Please give caps_directory and participants_tsv to" - " define new data group.", -) -save_nifti = click.option( - "--save_nifti", - is_flag=True, - help="Save the output map(s) in the MAPS in NIfTI format.", -) - -# interpret specific -name = click.argument( - "name", - type=config["name"].annotation, -) -method = click.argument( - "method", - type=click.Choice(InterpretationMethod), # ["gradients", "grad-cam"] -) -level = click.option( - "--level_grad_cam", - type=get_args(config["level"].annotation)[0], - default=config["level"].default, - help="level of the feature map (after the layer corresponding to the number) chosen for grad-cam.", - show_default=True, -) -target_node = click.option( - "--target_node", - type=config["target_node"].annotation, # int - default=config["target_node"].default, # 0 - help="Which target node the gradients explain. Default takes the first output node.", - show_default=True, -) -save_individual = click.option( - "--save_individual", - is_flag=True, - help="Save individual saliency maps in addition to the mean saliency map.", -) -overwrite_name = click.option( - "--overwrite_name", - "-on", - is_flag=True, - help="Overwrite the name if it already exists.", -) diff --git a/clinicadl/predict/predict_cli.py b/clinicadl/predict/predict_cli.py index c9a455ec4..411455803 100644 --- a/clinicadl/predict/predict_cli.py +++ b/clinicadl/predict/predict_cli.py @@ -1,38 +1,42 @@ import click -from clinicadl.predict import predict_param -from clinicadl.predict.predict_config import PredictConfig +from clinicadl.config import arguments +from clinicadl.config.config.pipelines.predict import PredictConfig +from clinicadl.config.options import ( + computational, + cross_validation, + data, + dataloader, + maps_manager, + predict, + validation, +) from clinicadl.predict.predict_manager import PredictManager -from clinicadl.utils.cmdline_utils import check_gpu -from clinicadl.utils.exceptions import ClinicaDLArgumentError - -config = 
PredictConfig.model_fields @click.command(name="predict", no_args_is_help=True) -@predict_param.input_maps -@predict_param.data_group -@predict_param.caps_directory -@predict_param.participants_list -@predict_param.use_labels -@predict_param.multi_cohort -@predict_param.diagnoses -@predict_param.label -@predict_param.save_tensor -@predict_param.save_nifti -@predict_param.save_latent_tensor -@predict_param.skip_leak_check -@predict_param.split -@predict_param.selection_metrics -@predict_param.gpu -@predict_param.amp -@predict_param.n_proc -@predict_param.batch_size -@predict_param.overwrite +@arguments.input_maps +@arguments.data_group +@maps_manager.save_nifti +@maps_manager.overwrite +@predict.use_labels +@data.label +@predict.save_tensor +@predict.save_latent_tensor +@data.caps_directory +@data.participants_tsv +@data.multi_cohort +@data.diagnoses +@validation.skip_leak_check +@validation.selection_metrics +@cross_validation.split +@computational.gpu +@computational.amp +@dataloader.n_proc +@dataloader.batch_size def cli(input_maps_directory, data_group, **kwargs): """This function loads a MAPS and predicts the global metrics and individual values for all the models selected using a metric in selection_metrics. - Args: maps_dir: path to the MAPS. data_group: name of the data group tested. @@ -51,29 +55,12 @@ def cli(input_maps_directory, data_group, **kwargs): overwrite: If True former definition of data group is erased save_tensor: For reconstruction task only, if True it will save the reconstruction as .pt file in the MAPS. save_nifti: For reconstruction task only, if True it will save the reconstruction as NIfTI file in the MAPS. - Infer the outputs of a trained model on a test set. - INPUT_MAPS_DIRECTORY is the MAPS folder from where the model used for prediction will be loaded. - DATA_GROUP is the name of the subjects and sessions list used for the interpretation. 
""" - if kwargs["gpu"]: - check_gpu() - elif kwargs["amp"]: - raise ClinicaDLArgumentError( - "AMP is designed to work with modern GPUs. Please add the --gpu flag." - ) - - predict_config = PredictConfig( - maps_dir=input_maps_directory, - data_group=data_group, - tsv_path=kwargs["participants_tsv"], - split_list=kwargs["split"], - **kwargs, - ) - + predict_config = PredictConfig(**kwargs) predict_manager = PredictManager(predict_config) predict_manager.predict() diff --git a/clinicadl/predict/predict_config.py b/clinicadl/predict/predict_config.py deleted file mode 100644 index 800b932cb..000000000 --- a/clinicadl/predict/predict_config.py +++ /dev/null @@ -1,126 +0,0 @@ -from enum import Enum -from logging import getLogger -from pathlib import Path -from typing import Dict, Optional, Union - -from pydantic import BaseModel, field_validator - -from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp -from clinicadl.utils.caps_dataset.data import ( - load_data_test, -) -from clinicadl.utils.enum import InterpretationMethod -from clinicadl.utils.exceptions import ClinicaDLArgumentError # type: ignore -from clinicadl.utils.maps_manager.maps_manager import MapsManager # type: ignore - -logger = getLogger("clinicadl.predict_config") - - -class PredictInterpretConfig(BaseModel): - maps_dir: Path - data_group: str - caps_directory: Optional[Path] = None - tsv_path: Optional[Path] = None - selection_metrics: list[str] = ["loss"] - split_list: list[int] = [] - diagnoses: list[str] = ["AD", "CN"] - multi_cohort: bool = False - batch_size: int = 8 - n_proc: int = 1 - gpu: bool = True - amp: bool = False - overwrite: bool = False - save_nifti: bool = False - skip_leak_check: bool = False - - @field_validator("selection_metrics", "split_list", "diagnoses", mode="before") - def list_to_tuples(cls, v): - if isinstance(v, list): - return tuple(v) - return v - - def adapt_config_with_maps_manager_info(self, maps_manager: MapsManager): - if not self.split_list: 
- self.split_list = maps_manager._find_splits() - logger.debug(f"List of splits {self.split_list}") - - if self.diagnoses is None or len(self.diagnoses) == 0: - self.diagnoses = maps_manager.diagnoses - - if not self.batch_size: - self.batch_size = maps_manager.batch_size - - if not self.n_proc: - self.n_proc = maps_manager.n_proc - - def create_groupe_df(self): - group_df = None - if self.tsv_path is not None and self.tsv_path.is_file(): - group_df = load_data_test( - self.tsv_path, - self.diagnoses, - multi_cohort=self.multi_cohort, - ) - return group_df - - -class InterpretConfig(PredictInterpretConfig): - name: str - method_cls: InterpretationMethod = InterpretationMethod.GRADIENTS - target_node: int = 0 - save_individual: bool = False - overwrite_name: bool = False - level: Optional[int] = 1 - - @field_validator("level", mode="before") - def chek_level(cls, v): - if v < 1: - raise ValueError( - f"You must set the level to a number bigger than 1. ({v} < 1)" - ) - - @property - def method(self) -> InterpretationMethod: - return self.method_cls - - @method.setter - def method(self, value: Union[str, InterpretationMethod]): - self.method_cls = InterpretationMethod(value) - - def get_method(self) -> Gradients: - if self.method == InterpretationMethod.GRADIENTS: - return VanillaBackProp - elif self.method == InterpretationMethod.GRAD_CAM: - return GradCam - else: - raise ValueError(f"The method {self.method.value} is not implemented") - - -class PredictConfig(PredictInterpretConfig): - label: str = "" - save_tensor: bool = False - save_latent_tensor: bool = False - use_labels: bool = True - - def check_output_saving(self, network_task: str) -> None: - # Check if task is reconstruction for "save_tensor" and "save_nifti" - if self.save_tensor and network_task != "reconstruction": - raise ClinicaDLArgumentError( - "Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option." 
- ) - if self.save_nifti and network_task != "reconstruction": - raise ClinicaDLArgumentError( - "Cannot save nifti if the network task is not reconstruction. Please remove --save_nifti option." - ) - - def is_given_label_code(self, _label: str, _label_code: Union[str, Dict[str, int]]): - return ( - self.label is not None - and self.label != "" - and self.label != _label - and _label_code == "default" - ) - - def check_label(self, _label: str): - if not self.label: - self.label = _label diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predict/predict_manager.py index 866d9a5dc..2e570d058 100644 --- a/clinicadl/predict/predict_manager.py +++ b/clinicadl/predict/predict_manager.py @@ -11,10 +11,8 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from clinicadl.predict.predict_config import ( - InterpretConfig, - PredictConfig, -) +from clinicadl.config.config.pipelines.interpret import InterpretConfig +from clinicadl.config.config.pipelines.predict import PredictConfig from clinicadl.utils.caps_dataset.data import ( return_dataset, ) @@ -41,7 +39,6 @@ def predict( label_code: Union[str, dict[str, int]] = "default", ): """Performs the prediction task on a subset of caps_directory defined in a TSV file. - Parameters ---------- data_group : str @@ -87,7 +84,6 @@ def predict( If true, save the tensor from the latent space for reconstruction task. skip_leak_check : bool (optional, default=False) If true, skip the leak check (not recommended). 
- Examples -------- >>> _input_ @@ -96,8 +92,11 @@ def predict( assert isinstance(self._config, PredictConfig) - self._config.check_output_saving(self.maps_manager.network_task) - self._config.adapt_config_with_maps_manager_info(self.maps_manager) + self._config.check_output_saving_nifti(self.maps_manager.network_task) + self._config.adapt_data_with_maps_manager_info(self.maps_manager) + self._config.adapt_dataloader_with_maps_manager_info(self.maps_manager) + self._config.adapt_cross_val_with_maps_manager_info(self.maps_manager) + self._config.check_output_saving_tensor(self.maps_manager.network_task) _, all_transforms = get_transforms( normalize=self.maps_manager.normalize, @@ -105,18 +104,15 @@ def predict( size_reduction=self.maps_manager.size_reduction, size_reduction_factor=self.maps_manager.size_reduction_factor, ) - group_df = self._config.create_groupe_df() self._check_data_group(group_df) criterion = self.maps_manager.task_manager.get_criterion(self.maps_manager.loss) self._check_data_group(df=group_df) - assert ( - self._config.split_list - ) # don't know if needed ? try to raise an exception ? + assert self._config.split # don't know if needed ? try to raise an exception ? 
# assert self._config.label - for split in self._config.split_list: + for split in self._config.split: logger.info(f"Prediction of split {split}") group_df, group_parameters = self.get_group_info( self._config.data_group, split @@ -126,7 +122,6 @@ def predict( self.maps_manager.task_manager.generate_label_code( group_df, self._config.label ) - # Erase previous TSV files on master process if not self._config.selection_metrics: split_selection_metrics = self.maps_manager._find_selection_metrics( @@ -141,14 +136,10 @@ def predict( / f"best-{selection}" / self._config.data_group ) - tsv_pattern = f"{self._config.data_group}*.tsv" - for tsv_file in tsv_dir.glob(tsv_pattern): tsv_file.unlink() - self._config.check_label(self.maps_manager.label) - if self.maps_manager.multi_network: self._predict_multi( group_parameters, @@ -159,7 +150,6 @@ def predict( split, split_selection_metrics, ) - else: self._predict_single( group_parameters, @@ -170,7 +160,6 @@ def predict( split, split_selection_metrics, ) - if cluster.master: self.maps_manager._ensemble_prediction( self._config.data_group, @@ -191,7 +180,6 @@ def _predict_multi( split_selection_metrics, ): """_summary_ - Parameters ---------- group_parameters : _type_ @@ -230,16 +218,13 @@ def _predict_multi( _description_ selection_metrics : _type_ _description_ - Examples -------- >>> _input_ _output_ - Notes ----- _notes_ - See Also -------- - _related_ @@ -326,7 +311,6 @@ def _predict_single( split_selection_metrics, ): """_summary_ - Parameters ---------- group_parameters : _type_ @@ -365,16 +349,13 @@ def _predict_single( _description_ selection_metrics : _type_ _description_ - Examples -------- >>> _input_ _output_ - Notes ----- _notes_ - See Also -------- - _related_ @@ -395,7 +376,6 @@ def _predict_single( self.maps_manager.label_code if label_code == "default" else label_code ), ) - test_loader = DataLoader( data_test, batch_size=( @@ -453,7 +433,6 @@ def _compute_latent_tensors( ): """ Compute the output tensors and 
saves them in the MAPS. - Parameters ---------- dataset : _type_ @@ -487,7 +466,6 @@ def _compute_latent_tensors( amp=self.maps_manager.amp, ) model.eval() - tensor_path = ( self.maps_manager.maps_path / f"{self.maps_manager.split_name}-{split}" @@ -498,12 +476,10 @@ def _compute_latent_tensors( if cluster.master: tensor_path.mkdir(parents=True, exist_ok=True) dist.barrier() - if nb_images is None: # Compute outputs for the whole data set nb_modes = len(dataset) else: nb_modes = nb_images * dataset.elem_per_image - for i in [ *range(cluster.rank, nb_modes, cluster.world_size), *range(int(nb_modes % cluster.world_size <= cluster.rank)), @@ -530,7 +506,6 @@ def _compute_output_nifti( network: Optional[int] = None, ): """Computes the output nifti images and saves them in the MAPS. - Parameters ---------- dataset : _type_ @@ -545,11 +520,9 @@ def _compute_output_nifti( If given, a new value for the device of the model will be computed. network : int (optional, default=None) Index of the network tested (only used in multi-network setting). - Raises -------- ClinicaDLException if not an image - """ import nibabel as nib from numpy import eye @@ -570,7 +543,6 @@ def _compute_output_nifti( amp=self.maps_manager.amp, ) model.eval() - nifti_path = ( self.maps_manager.maps_path / f"{self.maps_manager.split_name}-{split}" @@ -581,7 +553,6 @@ def _compute_output_nifti( if cluster.master: nifti_path.mkdir(parents=True, exist_ok=True) dist.barrier() - nb_imgs = len(dataset) for i in [ *range(cluster.rank, nb_imgs, cluster.world_size), @@ -607,7 +578,6 @@ def _compute_output_nifti( def interpret(self): """Performs the interpretation task on a subset of caps_directory defined in a TSV file. The mean interpretation is always saved, to save the individual interpretations set save_individual to True. - Parameters ---------- data_group : str @@ -653,7 +623,6 @@ def interpret(self): Layer number in the convolutional part after which the feature map is chosen. 
save_nifti : bool (optional, default=False) If True, save the interpretation map in nifti format. - Raises ------ NotImplementedError @@ -662,34 +631,32 @@ def interpret(self): If the interpretaion of multi network is asked MAPSError If the interpretation has already been determined. - """ assert isinstance(self._config, InterpretConfig) - self._config.adapt_config_with_maps_manager_info(self.maps_manager) + self._config.adapt_data_with_maps_manager_info(self.maps_manager) + self._config.adapt_dataloader_with_maps_manager_info(self.maps_manager) + self._config.adapt_cross_val_with_maps_manager_info(self.maps_manager) if self.maps_manager.multi_network: raise NotImplementedError( "The interpretation of multi-network framework is not implemented." ) - _, all_transforms = get_transforms( normalize=self.maps_manager.normalize, data_augmentation=self.maps_manager.data_augmentation, size_reduction=self.maps_manager.size_reduction, size_reduction_factor=self.maps_manager.size_reduction_factor, ) - group_df = self._config.create_groupe_df() self._check_data_group(group_df) - assert self._config.split_list - for split in self._config.split_list: + assert self._config.split + for split in self._config.split: logger.info(f"Interpretation of split {split}") df_group, parameters_group = self.get_group_info( self._config.data_group, split ) - data_test = return_dataset( parameters_group["caps_directory"], df_group, @@ -700,19 +667,16 @@ def interpret(self): label_code=self.maps_manager.label_code, label=self.maps_manager.label, ) - test_loader = DataLoader( data_test, batch_size=self._config.batch_size, shuffle=False, num_workers=self._config.n_proc, ) - if not self._config.selection_metrics: self._config.selection_metrics = ( self.maps_manager._find_selection_metrics(split) ) - for selection_metric in self._config.selection_metrics: logger.info(f"Interpretation of metric {selection_metric}") results_path = ( @@ -722,7 +686,6 @@ def interpret(self): / self._config.data_group / 
f"interpret-{self._config.name}" ) - if (results_path).is_dir(): if self._config.overwrite_name: shutil.rmtree(results_path) @@ -732,20 +695,16 @@ def interpret(self): f"Please choose another name or set overwrite_name to True." ) results_path.mkdir(parents=True) - model, _ = self.maps_manager._init_model( transfer_path=self.maps_manager.maps_path, split=split, transfer_selection=selection_metric, gpu=self._config.gpu, ) - interpreter = self._config.get_method()(model) - cum_maps = [0] * data_test.elem_per_image for data in test_loader: images = data["image"].to(model.device) - map_pt = interpreter.generate_gradients( images, self._config.target_node, @@ -769,13 +728,10 @@ def interpret(self): results_path / f"{data['participant_id'][i]}_{data['session_id'][i]}_{self.maps_manager.mode}-{data[f'{self.maps_manager.mode}_id'][i]}_map.nii.gz" ) - output_nii = nib.Nifti1Image(map_pt[i].numpy(), eye(4)) nib.save(output_nii, single_nifti_path) - for i, mode_map in enumerate(cum_maps): mode_map /= len(data_test) - torch.save( mode_map, results_path / f"mean_{self.maps_manager.mode}-{i}_map.pt", @@ -836,8 +792,8 @@ def _check_data_group( else: # if not split_list: # split_list = self.maps_manager._find_splits() - assert self._config.split_list - for split in self._config.split_list: + assert self._config.split + for split in self._config.split: selection_metrics = self.maps_manager._find_selection_metrics( split ) diff --git a/clinicadl/predict/predict_param.py b/clinicadl/predict/predict_param.py deleted file mode 100644 index 15fc07b90..000000000 --- a/clinicadl/predict/predict_param.py +++ /dev/null @@ -1,132 +0,0 @@ -from pathlib import Path -from typing import get_args - -import click - -from clinicadl import MapsManager -from clinicadl.predict.predict_config import PredictConfig - -config = PredictConfig.model_fields - -input_maps = click.argument("input_maps_directory", type=config["maps_dir"].annotation) -data_group = click.argument("data_group", 
type=config["data_group"].annotation) -participants_list = click.option( - "--participants_tsv", - type=get_args(config["tsv_path"].annotation)[0], # Path - default=config["tsv_path"].default, # None - help="""Path to the file with subjects/sessions to process, if different from the one used during network training. - If it includes the filename will load the TSV file directly. - Else will load the baseline TSV files of wanted diagnoses produced by `tsvtool split`.""", - show_default=True, -) -caps_directory = click.option( - "--caps_directory", - type=get_args(config["caps_directory"].annotation)[0], # Path - default=config["caps_directory"].default, # None - help="Data using CAPS structure, if different from the one used during network training.", - show_default=True, -) -multi_cohort = click.option( - "--multi_cohort", - is_flag=True, - help="Performs multi-cohort interpretation. In this case, caps_directory and tsv_path must be paths to TSV files.", -) -diagnoses = click.option( - "--diagnoses", - "-d", - type=get_args(config["diagnoses"].annotation)[0], # str list ? - default=config["diagnoses"].default, # ?? - multiple=True, - help="List of diagnoses used for inference. Is used only if PARTICIPANTS_TSV leads to a folder.", - show_default=True, -) -save_nifti = click.option( - "--save_nifti", - is_flag=True, - help="Save the output map(s) in the MAPS in NIfTI format.", -) -selection_metrics = click.option( - "--selection_metrics", - "-sm", - type=get_args(config["selection_metrics"].annotation)[0], # str list ? - default=config["selection_metrics"].default, # ["loss"] - multiple=True, - help="""Allow to select a list of models based on their selection metric. 
Default will - only infer the result of the best model selected on loss.""", - show_default=True, -) -n_proc = click.option( - "-np", - "--n_proc", - type=config["n_proc"].annotation, - default=config["n_proc"].default, - show_default=True, - help="Number of cores used during the task.", -) -gpu = click.option( - "--gpu/--no-gpu", - show_default=True, - default=config["gpu"].default, - help="Use GPU by default. Please specify `--no-gpu` to force using CPU.", -) -batch_size = click.option( - "--batch_size", - type=config["batch_size"].annotation, # int - default=config["batch_size"].default, # 8 - show_default=True, - help="Batch size for data loading.", -) -amp = click.option( - "--amp/--no-amp", - default=config["amp"].default, # false - help="Enables automatic mixed precision during training and inference.", - show_default=True, -) -overwrite = click.option( - "--overwrite", - "-o", - is_flag=True, - help="Will overwrite data group if existing. Please give caps_directory and participants_tsv to" - " define new data group.", -) - - -# predict specific -use_labels = click.option( - "--use_labels/--no_labels", - show_default=True, - default=config["use_labels"].default, # false - help="Set this option to --no_labels if your dataset does not contain ground truth labels.", -) -label = click.option( - "--label", - type=config["label"].annotation, # str - default=config["label"].default, # None - show_default=True, - help="Target label used for training (if NETWORK_TASK in [`regression`, `classification`]). 
" - "Default will reuse the same label as during the training task.", -) -save_tensor = click.option( - "--save_tensor", - is_flag=True, - help="Save the reconstruction output in the MAPS in Pytorch tensor format.", -) -save_latent_tensor = click.option( - "--save_latent_tensor", - is_flag=True, - help="""Save the latent representation of the image.""", -) -skip_leak_check = click.option( - "--skip_leak_check", - is_flag=True, - help="Skip the data leakage check.", -) -split = click.option( - "--split", - "-s", - type=get_args(config["split_list"].annotation)[0], # list[int] - default=config["split_list"].default, # [] ? - multiple=True, - show_default=True, - help="Make inference on the list of given splits. By default, inference is done on all the splits.", -) diff --git a/clinicadl/random_search/random_search_config.py b/clinicadl/random_search/random_search_config.py index 0c11cfca4..cbf304c64 100644 --- a/clinicadl/random_search/random_search_config.py +++ b/clinicadl/random_search/random_search_config.py @@ -4,17 +4,17 @@ from pydantic import BaseModel, ConfigDict, PositiveInt, field_validator -from clinicadl.train.tasks.classification.config import ( +from clinicadl.config.config.pipelines.task.classification import ( ClassificationConfig as BaseClassificationConfig, ) -from clinicadl.train.tasks.regression.config import ( +from clinicadl.config.config.pipelines.task.regression import ( RegressionConfig as BaseRegressionConfig, ) from clinicadl.utils.config_utils import get_type_from_config_class as get_type from clinicadl.utils.enum import Normalization, Pooling, Task if TYPE_CHECKING: - from clinicadl.train.trainer import TrainingConfig + from clinicadl.train.trainer import TrainConfig class RandomSearchConfig( @@ -76,7 +76,7 @@ def architecture_validator(cls, v): v == "RandomArchitecture" ), "Only RandomArchitecture can be used in Random Search." 
- class TrainingConfig(base_training_config): + class TrainConfig(base_training_config): """ Config class for the training of a random model. @@ -91,7 +91,7 @@ class TrainingConfig(base_training_config): model: ModelConfig - return TrainingConfig + return TrainConfig @training_config_for_random_models @@ -104,7 +104,7 @@ class RegressionConfig(BaseRegressionConfig): pass -def create_training_config(task: Union[str, Task]) -> Type[TrainingConfig]: +def create_training_config(task: Union[str, Task]) -> Type[TrainConfig]: """ A factory function to create a Training Config class suited for the task, in Random Search mode. @@ -115,7 +115,7 @@ def create_training_config(task: Union[str, Task]) -> Type[TrainingConfig]: Returns ------- - Type[TrainingConfig] + Type[TrainConfig] The Config class. """ task = Task(task) diff --git a/clinicadl/train/tasks/classification/classification_cli.py b/clinicadl/train/tasks/classification_cli.py similarity index 97% rename from clinicadl/train/tasks/classification/classification_cli.py rename to clinicadl/train/tasks/classification_cli.py index 1f69daa7e..be21f690f 100644 --- a/clinicadl/train/tasks/classification/classification_cli.py +++ b/clinicadl/train/tasks/classification_cli.py @@ -1,6 +1,7 @@ import click from clinicadl.config import arguments +from clinicadl.config.config.pipelines.task.classification import ClassificationConfig from clinicadl.config.options import ( callbacks, computational, @@ -19,7 +20,6 @@ transforms, validation, ) -from clinicadl.train.tasks.classification.config import ClassificationConfig from clinicadl.train.trainer import Trainer from clinicadl.train.utils import merge_cli_and_config_file_options from clinicadl.utils.enum import Task diff --git a/clinicadl/train/tasks/reconstruction/reconstruction_cli.py b/clinicadl/train/tasks/reconstruction_cli.py similarity index 97% rename from clinicadl/train/tasks/reconstruction/reconstruction_cli.py rename to clinicadl/train/tasks/reconstruction_cli.py index 
edaaa4510..e2b37b10e 100644 --- a/clinicadl/train/tasks/reconstruction/reconstruction_cli.py +++ b/clinicadl/train/tasks/reconstruction_cli.py @@ -1,6 +1,7 @@ import click from clinicadl.config import arguments +from clinicadl.config.config.pipelines.task.reconstruction import ReconstructionConfig from clinicadl.config.options import ( callbacks, computational, @@ -19,7 +20,6 @@ transforms, validation, ) -from clinicadl.train.tasks.reconstruction.config import ReconstructionConfig from clinicadl.train.trainer import Trainer from clinicadl.train.utils import merge_cli_and_config_file_options from clinicadl.utils.enum import Task diff --git a/clinicadl/train/tasks/regression/regression_cli.py b/clinicadl/train/tasks/regression_cli.py similarity index 97% rename from clinicadl/train/tasks/regression/regression_cli.py rename to clinicadl/train/tasks/regression_cli.py index cc398d062..a3252424b 100644 --- a/clinicadl/train/tasks/regression/regression_cli.py +++ b/clinicadl/train/tasks/regression_cli.py @@ -1,6 +1,7 @@ import click from clinicadl.config import arguments +from clinicadl.config.config.pipelines.task.regression import RegressionConfig from clinicadl.config.options import ( callbacks, computational, @@ -19,7 +20,6 @@ transforms, validation, ) -from clinicadl.train.tasks.regression.config import RegressionConfig from clinicadl.train.trainer import Trainer from clinicadl.train.utils import merge_cli_and_config_file_options from clinicadl.utils.enum import Task diff --git a/clinicadl/train/tasks/tasks_utils.py b/clinicadl/train/tasks/tasks_utils.py index 40d15bfd0..f05c3d5e5 100644 --- a/clinicadl/train/tasks/tasks_utils.py +++ b/clinicadl/train/tasks/tasks_utils.py @@ -1,10 +1,10 @@ from typing import Type, Union -from clinicadl.train.trainer import TrainingConfig +from clinicadl.config.config.pipelines.train import TrainConfig from clinicadl.utils.enum import Task -def create_training_config(task: Union[str, Task]) -> Type[TrainingConfig]: +def 
create_training_config(task: Union[str, Task]) -> Type[TrainConfig]: """ A factory function to create a Training Config class suited for the task. Parameters @@ -15,13 +15,15 @@ def create_training_config(task: Union[str, Task]) -> Type[TrainingConfig]: """ task = Task(task) if task == Task.CLASSIFICATION: - from clinicadl.train.tasks.classification.config import ( + from clinicadl.config.config.pipelines.task.classification import ( ClassificationConfig as Config, ) elif task == Task.REGRESSION: - from clinicadl.train.tasks.regression.config import RegressionConfig as Config + from clinicadl.config.config.pipelines.task.regression import ( + RegressionConfig as Config, + ) elif task == Task.RECONSTRUCTION: - from clinicadl.train.tasks.reconstruction.config import ( + from clinicadl.config.config.pipelines.task.reconstruction import ( ReconstructionConfig as Config, ) return Config diff --git a/clinicadl/train/tasks/train_task_cli_options.py b/clinicadl/train/tasks/train_task_cli_options.py deleted file mode 100644 index d0dfd0d51..000000000 --- a/clinicadl/train/tasks/train_task_cli_options.py +++ /dev/null @@ -1,334 +0,0 @@ -import click - -import clinicadl.train.trainer.training_config as config -from clinicadl.utils import cli_param -from clinicadl.utils.config_utils import get_default_from_config_class as get_default -from clinicadl.utils.config_utils import get_type_from_config_class as get_type - -# Arguments -caps_directory = cli_param.argument.caps_directory -preprocessing_json = cli_param.argument.preprocessing_json -tsv_directory = click.argument( - "tsv_directory", - type=click.Path(exists=True), -) -output_maps = cli_param.argument.output_maps - -# Config file -config_file = click.option( - "--config_file", - "-c", - type=click.Path(exists=True), - help="Path to the TOML or JSON file containing the values of the options needed for training.", -) - -# Callbacks -emissions_calculator = cli_param.option_group.informations_group.option( - 
"--calculate_emissions/--dont_calculate_emissions", - default=get_default("emissions_calculator", config.CallbacksConfig), - help="Flag to allow calculate the carbon emissions during training.", - show_default=True, -) -track_exp = cli_param.option_group.optimization_group.option( - "--track_exp", - "-te", - type=click.Choice(get_type("track_exp", config.CallbacksConfig)), - default=get_default("track_exp", config.CallbacksConfig), - help="Use `--track_exp` to enable wandb/mlflow to track the metric (loss, accuracy, etc...) during the training.", - show_default=True, -) -# Computational -amp = cli_param.option_group.computational_group.option( - "--amp/--no-amp", - default=get_default("amp", config.ComputationalConfig), - help="Enables automatic mixed precision during training and inference.", - show_default=True, -) -fully_sharded_data_parallel = cli_param.option_group.computational_group.option( - "--fully_sharded_data_parallel", - "-fsdp", - is_flag=True, - help="Enables Fully Sharded Data Parallel with Pytorch to save memory at the cost of communications. " - "Currently this only enables ZeRO Stage 1 but will be entirely replaced by FSDP in a later patch, " - "this flag is already set to FSDP to that the zero flag is never actually removed.", -) -gpu = cli_param.option_group.computational_group.option( - "--gpu/--no-gpu", - default=get_default("gpu", config.ComputationalConfig), - help="Use GPU by default. Please specify `--no-gpu` to force using CPU.", - show_default=True, -) -# Cross Validation -n_splits = cli_param.option_group.cross_validation.option( - "--n_splits", - type=get_type("n_splits", config.CrossValidationConfig), - default=get_default("n_splits", config.CrossValidationConfig), - help="If a value is given for k will load data of a k-fold CV. 
" - "Default value (0) will load a single split.", - show_default=True, -) -split = cli_param.option_group.cross_validation.option( - "--split", - "-s", - type=get_type("split", config.CrossValidationConfig), - default=get_default("split", config.CrossValidationConfig), - multiple=True, - help="Train the list of given splits. By default, all the splits are trained.", - show_default=True, -) -# Data -baseline = cli_param.option_group.data_group.option( - "--baseline/--longitudinal", - default=get_default("baseline", config.DataConfig), - help="If provided, only the baseline sessions are used for training.", - show_default=True, -) -diagnoses = cli_param.option_group.data_group.option( - "--diagnoses", - "-d", - type=get_type("diagnoses", config.DataConfig), - default=get_default("diagnoses", config.DataConfig), - multiple=True, - help="List of diagnoses used for training.", - show_default=True, -) -multi_cohort = cli_param.option_group.data_group.option( - "--multi_cohort/--single_cohort", - default=get_default("multi_cohort", config.DataConfig), - help="Performs multi-cohort training. 
In this case, caps_dir and tsv_path must be paths to TSV files.", - show_default=True, -) -# DataLoader -batch_size = cli_param.option_group.computational_group.option( - "--batch_size", - type=get_type("batch_size", config.DataLoaderConfig), - default=get_default("batch_size", config.DataLoaderConfig), - help="Batch size for data loading.", - show_default=True, -) -n_proc = cli_param.option_group.computational_group.option( - "-np", - "--n_proc", - type=get_type("n_proc", config.DataLoaderConfig), - default=get_default("n_proc", config.DataLoaderConfig), - help="Number of cores used during the task.", - show_default=True, -) -sampler = cli_param.option_group.data_group.option( - "--sampler", - "-s", - type=click.Choice(get_type("sampler", config.DataLoaderConfig)), - default=get_default("sampler", config.DataLoaderConfig), - help="Sampler used to load the training data set.", - show_default=True, -) -# Early Stopping -patience = cli_param.option_group.optimization_group.option( - "--patience", - type=get_type("patience", config.EarlyStoppingConfig), - default=get_default("patience", config.EarlyStoppingConfig), - help="Number of epochs for early stopping patience.", - show_default=True, -) -tolerance = cli_param.option_group.optimization_group.option( - "--tolerance", - type=get_type("tolerance", config.EarlyStoppingConfig), - default=get_default("tolerance", config.EarlyStoppingConfig), - help="Value for early stopping tolerance.", - show_default=True, -) -# LR scheduler -adaptive_learning_rate = cli_param.option_group.optimization_group.option( - "--adaptive_learning_rate", - "-alr", - is_flag=True, - help="Whether to diminish the learning rate", -) -# Model -multi_network = cli_param.option_group.model_group.option( - "--multi_network/--single_network", - default=get_default("multi_network", config.ModelConfig), - help="If provided uses a multi-network framework.", - show_default=True, -) -dropout = cli_param.option_group.optimization_group.option( - 
"--dropout", - type=get_type("dropout", config.ModelConfig), - default=get_default("dropout", config.ModelConfig), - help="Rate value applied to dropout layers in a CNN architecture.", - show_default=True, -) -# Optimization -accumulation_steps = cli_param.option_group.optimization_group.option( - "--accumulation_steps", - "-asteps", - type=get_type("accumulation_steps", config.OptimizationConfig), - default=get_default("accumulation_steps", config.OptimizationConfig), - help="Accumulates gradients during the given number of iterations before performing the weight update " - "in order to virtually increase the size of the batch.", - show_default=True, -) -epochs = cli_param.option_group.optimization_group.option( - "--epochs", - type=get_type("epochs", config.OptimizationConfig), - default=get_default("epochs", config.OptimizationConfig), - help="Maximum number of epochs.", - show_default=True, -) -profiler = cli_param.option_group.optimization_group.option( - "--profiler/--no-profiler", - default=get_default("profiler", config.OptimizationConfig), - help="Use `--profiler` to enable Pytorch profiler for the first 30 steps after a short warmup. 
" - "It will make an execution trace and some statistics about the CPU and GPU usage.", - show_default=True, -) -# Optimizer -learning_rate = cli_param.option_group.optimization_group.option( - "--learning_rate", - "-lr", - type=get_type("learning_rate", config.OptimizerConfig), - default=get_default("learning_rate", config.OptimizerConfig), - help="Learning rate of the optimization.", - show_default=True, -) -optimizer = cli_param.option_group.optimization_group.option( - "--optimizer", - type=click.Choice(get_type("optimizer", config.OptimizerConfig)), - default=get_default("optimizer", config.OptimizerConfig), - help="Optimizer used to train the network.", - show_default=True, -) -weight_decay = cli_param.option_group.optimization_group.option( - "--weight_decay", - "-wd", - type=get_type("weight_decay", config.OptimizerConfig), - default=get_default("weight_decay", config.OptimizerConfig), - help="Weight decay value used in optimization.", - show_default=True, -) -# Reproducibility -compensation = cli_param.option_group.reproducibility_group.option( - "--compensation", - type=click.Choice(get_type("compensation", config.ReproducibilityConfig)), - default=get_default("compensation", config.ReproducibilityConfig), - help="Allow the user to choose how CUDA will compensate the deterministic behaviour.", - show_default=True, -) -deterministic = cli_param.option_group.reproducibility_group.option( - "--deterministic/--nondeterministic", - default=get_default("deterministic", config.ReproducibilityConfig), - help="Forces Pytorch to be deterministic even when using a GPU. 
" - "Will raise a RuntimeError if a non-deterministic function is encountered.", - show_default=True, -) -save_all_models = cli_param.option_group.reproducibility_group.option( - "--save_all_models/--save_only_best_model", - type=get_type("save_all_models", config.ReproducibilityConfig), - default=get_default("save_all_models", config.ReproducibilityConfig), - help="If provided, enables the saving of models weights for each epochs.", - show_default=True, -) -seed = cli_param.option_group.reproducibility_group.option( - "--seed", - type=get_type("seed", config.ReproducibilityConfig), - default=get_default("seed", config.ReproducibilityConfig), - help="Value to set the seed for all random operations." - "Default will sample a random value for the seed.", - show_default=True, -) -# SSDA -caps_target = cli_param.option_group.data_group.option( - "--caps_target", - "-d", - type=get_type("caps_target", config.SSDAConfig), - default=get_default("caps_target", config.SSDAConfig), - help="CAPS of target data.", - show_default=True, -) -preprocessing_json_target = cli_param.option_group.data_group.option( - "--preprocessing_json_target", - "-d", - type=get_type("preprocessing_json_target", config.SSDAConfig), - default=get_default("preprocessing_json_target", config.SSDAConfig), - help="Path to json target.", - show_default=True, -) -ssda_network = cli_param.option_group.model_group.option( - "--ssda_network/--single_network", - default=get_default("ssda_network", config.SSDAConfig), - help="If provided uses a ssda-network framework.", - show_default=True, -) -tsv_target_lab = cli_param.option_group.data_group.option( - "--tsv_target_lab", - "-d", - type=get_type("tsv_target_lab", config.SSDAConfig), - default=get_default("tsv_target_lab", config.SSDAConfig), - help="TSV of labeled target data.", - show_default=True, -) -tsv_target_unlab = cli_param.option_group.data_group.option( - "--tsv_target_unlab", - "-d", - type=get_type("tsv_target_unlab", config.SSDAConfig), - 
default=get_default("tsv_target_unlab", config.SSDAConfig), - help="TSV of unllabeled target data.", - show_default=True, -) -# Transfer Learning -nb_unfrozen_layer = cli_param.option_group.transfer_learning_group.option( - "-nul", - "--nb_unfrozen_layer", - type=get_type("nb_unfrozen_layer", config.TransferLearningConfig), - default=get_default("nb_unfrozen_layer", config.TransferLearningConfig), - help="Number of layer that will be retrain during training. For example, if it is 2, the last two layers of the model will not be freezed.", - show_default=True, -) -transfer_path = cli_param.option_group.transfer_learning_group.option( - "-tp", - "--transfer_path", - type=get_type("transfer_path", config.TransferLearningConfig), - default=get_default("transfer_path", config.TransferLearningConfig), - help="Path of to a MAPS used for transfer learning.", - show_default=True, -) -transfer_selection_metric = cli_param.option_group.transfer_learning_group.option( - "-tsm", - "--transfer_selection_metric", - type=get_type("transfer_selection_metric", config.TransferLearningConfig), - default=get_default("transfer_selection_metric", config.TransferLearningConfig), - help="Metric used to select the model for transfer learning in the MAPS defined by transfer_path.", - show_default=True, -) -# Transform -data_augmentation = cli_param.option_group.data_group.option( - "--data_augmentation", - "-da", - type=click.Choice(get_type("data_augmentation", config.TransformsConfig)), - default=get_default("data_augmentation", config.TransformsConfig), - multiple=True, - help="Randomly applies transforms on the training set.", - show_default=True, -) -normalize = cli_param.option_group.data_group.option( - "--normalize/--unnormalize", - default=get_default("normalize", config.TransformsConfig), - help="Disable default MinMaxNormalization.", - show_default=True, -) -# Validation -valid_longitudinal = cli_param.option_group.data_group.option( - "--valid_longitudinal/--valid_baseline", - 
default=get_default("valid_longitudinal", config.ValidationConfig), - help="If provided, not only the baseline sessions are used for validation (careful with this bad habit).", - show_default=True, -) -evaluation_steps = cli_param.option_group.computational_group.option( - "--evaluation_steps", - "-esteps", - type=get_type("evaluation_steps", config.ValidationConfig), - default=get_default("evaluation_steps", config.ValidationConfig), - help="Fix the number of iterations to perform before computing an evaluation. Default will only " - "perform one evaluation at the end of each epoch.", - show_default=True, -) diff --git a/clinicadl/train/train_cli.py b/clinicadl/train/train_cli.py index 2eaa3d42d..a773ac076 100644 --- a/clinicadl/train/train_cli.py +++ b/clinicadl/train/train_cli.py @@ -3,9 +3,9 @@ from .from_json import cli as from_json_cli from .list_models import cli as list_models_cli from .resume import cli as resume_cli -from .tasks.classification.classification_cli import cli as classification_cli -from .tasks.reconstruction.reconstruction_cli import cli as reconstruction_cli -from .tasks.regression.regression_cli import cli as regression_cli +from .tasks.classification_cli import cli as classification_cli +from .tasks.reconstruction_cli import cli as reconstruction_cli +from .tasks.regression_cli import cli as regression_cli @click.group(name="train", no_args_is_help=True) diff --git a/clinicadl/train/trainer/__init__.py b/clinicadl/train/trainer/__init__.py index bd83a7d3d..260e4c8d6 100644 --- a/clinicadl/train/trainer/__init__.py +++ b/clinicadl/train/trainer/__init__.py @@ -1,21 +1 @@ from .trainer import Trainer -from .training_config import ( - CallbacksConfig, - ComputationalConfig, - CrossValidationConfig, - DataConfig, - DataLoaderConfig, - EarlyStoppingConfig, - LRschedulerConfig, - MapsManagerConfig, - ModelConfig, - OptimizationConfig, - OptimizerConfig, - ReproducibilityConfig, - SSDAConfig, - Task, - TrainingConfig, - TransferLearningConfig, - 
TransformsConfig, - ValidationConfig, -) diff --git a/clinicadl/train/trainer/trainer.py b/clinicadl/train/trainer/trainer.py index d7a7367e7..1fbcf9534 100644 --- a/clinicadl/train/trainer/trainer.py +++ b/clinicadl/train/trainer/trainer.py @@ -25,14 +25,13 @@ from clinicadl.utils.maps_manager import MapsManager from clinicadl.utils.seed import get_seed -from .training_config import Task +from clinicadl.utils.enum import Task from .trainer_utils import create_parameters_dict if TYPE_CHECKING: + from clinicadl.config.config.pipelines.train import TrainConfig from clinicadl.utils.callbacks.callbacks import Callback - from .training_config import TrainingConfig - logger = getLogger("clinicadl.trainer") @@ -41,7 +40,7 @@ class Trainer: def __init__( self, - config: TrainingConfig, + config: TrainConfig, maps_manager: Optional[MapsManager] = None, ) -> None: """ diff --git a/clinicadl/train/utils.py b/clinicadl/train/utils.py index d641c4df0..48713ecf9 100644 --- a/clinicadl/train/utils.py +++ b/clinicadl/train/utils.py @@ -5,12 +5,11 @@ import toml from click.core import ParameterSource +from clinicadl.utils.enum import Task from clinicadl.utils.exceptions import ClinicaDLConfigurationError from clinicadl.utils.maps_manager.maps_manager_utils import remove_unused_tasks from clinicadl.utils.preprocessing import path_decoder -from .trainer import Task - def extract_config_from_toml_file(config_file: Path, task: Task) -> Dict[str, Any]: """ diff --git a/clinicadl/utils/config_utils.py b/clinicadl/utils/config_utils.py index a15cb6163..53928e2fd 100644 --- a/clinicadl/utils/config_utils.py +++ b/clinicadl/utils/config_utils.py @@ -137,4 +137,3 @@ def get_type_from_config_class(arg: str, config: BaseModel) -> Any: return click.Choice(list([option.value for option in type_])) else: return type_ - # raise TypeError(f"the type {type_} is not supported for the argument {arg}.") diff --git a/tests/test_interpret.py b/tests/test_interpret.py index 631900657..e021f53eb 100644 --- 
a/tests/test_interpret.py +++ b/tests/test_interpret.py @@ -6,7 +6,7 @@ import pytest -from clinicadl.predict.predict_config import InterpretConfig +from clinicadl.config.config.pipelines.interpret import InterpretConfig from clinicadl.predict.predict_manager import PredictManager @@ -64,7 +64,7 @@ def test_interpret(cmdopt, tmp_path, test_name): def run_interpret(cnn_input, tmp_out_dir, ref_dir): - from clinicadl.predict.predict_config import InterpretationMethod + from clinicadl.utils.enum import InterpretationMethod maps_path = tmp_out_dir / "maps" if maps_path.is_dir(): diff --git a/tests/test_predict.py b/tests/test_predict.py index e3ef19fef..b6055e6b5 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -73,7 +73,7 @@ def test_predict(cmdopt, tmp_path, test_name): with open(json_path, "w") as f: f.write(json_data) - from clinicadl.predict.predict_config import PredictConfig + from clinicadl.config.config.pipelines.predict import PredictConfig predict_config = PredictConfig( maps_dir=model_folder, diff --git a/tests/unittests/train/tasks/classification/test_classification_config.py b/tests/unittests/train/tasks/classification/test_classification_config.py index 08636741a..8f38b4825 100644 --- a/tests/unittests/train/tasks/classification/test_classification_config.py +++ b/tests/unittests/train/tasks/classification/test_classification_config.py @@ -3,13 +3,13 @@ import pytest from pydantic import ValidationError -import clinicadl.train.tasks.classification.config as config +import clinicadl.config.config.pipelines.task.classification as classification # Tests for customed validators # def test_model_config(): with pytest.raises(ValidationError): - config.ModelConfig( + classification.ModelConfig( **{ "architecture": "", "loss": "", @@ -19,7 +19,7 @@ def test_model_config(): def test_validation_config(): - c = config.ValidationConfig(selection_metrics=["accuracy"]) + c = classification.ValidationConfig(selection_metrics=["accuracy"]) assert 
c.selection_metrics == ("accuracy",) @@ -64,11 +64,11 @@ def good_inputs(dummy_arguments): def test_fails_validations(bad_inputs): with pytest.raises(ValidationError): - config.ClassificationConfig(**bad_inputs) + classification.ClassificationConfig(**bad_inputs) def test_passes_validations(good_inputs): - c = config.ClassificationConfig(**good_inputs) + c = classification.ClassificationConfig(**good_inputs) assert c.model.loss == "MultiMarginLoss" assert c.validation.selection_metrics == ("F1_score",) assert c.model.selection_threshold == 0.5 diff --git a/tests/unittests/train/tasks/reconstruction/test_reconstruction_config.py b/tests/unittests/train/tasks/reconstruction/test_reconstruction_config.py index 33d2b3f8d..d1e3855dc 100644 --- a/tests/unittests/train/tasks/reconstruction/test_reconstruction_config.py +++ b/tests/unittests/train/tasks/reconstruction/test_reconstruction_config.py @@ -3,12 +3,12 @@ import pytest from pydantic import ValidationError -import clinicadl.train.tasks.reconstruction.config as config +import clinicadl.config.config.pipelines.task.reconstruction as reconstruction # Tests for customed validators # def test_validation_config(): - c = config.ValidationConfig(selection_metrics=["MAE"]) + c = reconstruction.ValidationConfig(selection_metrics=["MAE"]) assert c.selection_metrics == ("MAE",) @@ -53,11 +53,11 @@ def good_inputs(dummy_arguments): def test_fails_validations(bad_inputs): with pytest.raises(ValidationError): - config.ReconstructionConfig(**bad_inputs) + reconstruction.ReconstructionConfig(**bad_inputs) def test_passes_validations(good_inputs): - c = config.ReconstructionConfig(**good_inputs) + c = reconstruction.ReconstructionConfig(**good_inputs) assert c.model.loss == "HuberLoss" assert c.validation.selection_metrics == ("PSNR",) assert c.model.normalization == "batch" diff --git a/tests/unittests/train/tasks/regression/test_regression_config.py b/tests/unittests/train/tasks/regression/test_regression_config.py index 
1c28e7a8e..b159791ee 100644 --- a/tests/unittests/train/tasks/regression/test_regression_config.py +++ b/tests/unittests/train/tasks/regression/test_regression_config.py @@ -3,12 +3,12 @@ import pytest from pydantic import ValidationError -import clinicadl.train.tasks.regression.config as config +import clinicadl.config.config.pipelines.task.regression as regression # Tests for customed validators # def test_validation_config(): - c = config.ValidationConfig(selection_metrics=["R2_score"]) + c = regression.ValidationConfig(selection_metrics=["R2_score"]) assert c.selection_metrics == ("R2_score",) @@ -52,11 +52,11 @@ def good_inputs(dummy_arguments): def test_fails_validations(bad_inputs): with pytest.raises(ValidationError): - config.RegressionConfig(**bad_inputs) + regression.RegressionConfig(**bad_inputs) def test_passes_validations(good_inputs): - c = config.RegressionConfig(**good_inputs) + c = regression.RegressionConfig(**good_inputs) assert c.model.loss == "KLDivLoss" assert c.validation.selection_metrics == ("R2_score",) assert c.network_task == "regression" diff --git a/tests/unittests/train/test_utils.py b/tests/unittests/train/test_utils.py index d71e6f980..a4b054817 100644 --- a/tests/unittests/train/test_utils.py +++ b/tests/unittests/train/test_utils.py @@ -2,7 +2,7 @@ import pytest -from clinicadl.train.trainer import Task +from clinicadl.utils.enum import Task expected_classification = { "architecture": "default", diff --git a/tests/unittests/train/trainer/test_training_config.py b/tests/unittests/train/trainer/test_training_config.py index 1191e76bc..5f9e0acca 100644 --- a/tests/unittests/train/trainer/test_training_config.py +++ b/tests/unittests/train/trainer/test_training_config.py @@ -3,7 +3,7 @@ import pytest from pydantic import ValidationError -import clinicadl.train.trainer.training_config as config +import clinicadl.config.config as config # Tests for customed validators # @@ -120,7 +120,9 @@ def dummy_arguments(caps_example): def 
training_config(): from pydantic import computed_field - class TrainingConfig(config.TrainingConfig): + from clinicadl.config.config.pipelines.train import TrainConfig + + class TrainingConfig(TrainConfig): @computed_field @property def network_task(self) -> str: