diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py
index cb3cdf6d7..8b8b95554 100644
--- a/clinicadl/maps_manager/maps_manager.py
+++ b/clinicadl/maps_manager/maps_manager.py
@@ -1,4 +1,5 @@
 import json
+import shutil
 import subprocess
 from datetime import datetime
 from logging import getLogger
@@ -50,7 +51,7 @@ def __init__(
         self,
         maps_path: Path,
         parameters: Optional[Dict[str, Any]] = None,
-        verbose: str = "info",
+        verbose: Optional[str] = "info",
     ):
         """
@@ -569,13 +570,13 @@ def _mode_to_image_tsv(
     ###############################
     def _init_model(
         self,
-        transfer_path: Path = None,
-        transfer_selection=None,
-        nb_unfrozen_layer=0,
-        split=None,
-        resume=False,
-        gpu=None,
-        network=None,
+        transfer_path: Optional[Path] = None,
+        transfer_selection: Optional[str] = None,
+        nb_unfrozen_layer: int = 0,
+        split: Optional[int] = None,
+        resume: bool = False,
+        gpu: bool = False,
+        network: Optional[int] = None,
     ):
         """
         Instantiate the model
@@ -778,3 +779,67 @@ def std_amp(self) -> bool:
         then calls the internal FSDP AMP mechanisms.
         """
         return self.amp and not self.fully_sharded_data_parallel
+
+    def _erase_tmp(self, split: int):
+        """
+        Erases checkpoints of the model and optimizer at the end of training.
+
+        Parameters
+        ----------
+        split : int
+            The split on which the model has been trained.
+        """
+        tmp_path = self.maps_path / f"split-{split}" / "tmp"
+        shutil.rmtree(tmp_path)
+
+    def _write_weights(
+        self,
+        state: Dict[str, Any],
+        metrics_dict: Optional[Dict[str, bool]],
+        split: int,
+        network: Optional[int] = None,
+        filename: str = "checkpoint.pth.tar",
+        save_all_models: bool = False,
+    ) -> None:
+        """
+        Update checkpoint and save the best model according to a set of
+        metrics.
+
+        Parameters
+        ----------
+        state : Dict[str, Any]
+            The state of the training (model weights, epoch, etc.).
+        metrics_dict : Optional[Dict[str, bool]]
+            The output of RetainBest step. If None, only the checkpoint
+            is saved.
+        split : int
+            The split number.
+        network : int (optional, default=None)
+            The network number (multi-network framework).
+        filename : str (optional, default="checkpoint.pth.tar")
+            The name of the checkpoint file.
+        save_all_models : bool (optional, default=False)
+            Whether to save model weights at every epoch.
+            If False, only the best model will be saved.
+ """ + checkpoint_dir = self.maps_path / f"split-{split}" / "tmp" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + checkpoint_path = checkpoint_dir / filename + torch.save(state, checkpoint_path) + + if save_all_models: + all_models_dir = self.maps_path / f"split-{split}" / "all_models" + all_models_dir.mkdir(parents=True, exist_ok=True) + torch.save(state, all_models_dir / f"model_epoch_{state['epoch']}.pth.tar") + + best_filename = "model.pth.tar" + if network is not None: + best_filename = f"network-{network}_model.pth.tar" + + # Save model according to several metrics + if metrics_dict is not None: + for metric_name, metric_bool in metrics_dict.items(): + metric_path = self.maps_path / f"split-{split}" / f"best-{metric_name}" + if metric_bool: + metric_path.mkdir(parents=True, exist_ok=True) + shutil.copyfile(checkpoint_path, metric_path / best_filename) diff --git a/clinicadl/splitter/config.py b/clinicadl/splitter/config.py index a004dd510..53413fdda 100644 --- a/clinicadl/splitter/config.py +++ b/clinicadl/splitter/config.py @@ -23,7 +23,7 @@ class SplitConfig(BaseModel): n_splits: NonNegativeInt = 0 split: Optional[Tuple[NonNegativeInt, ...]] = None - tsv_path: Path # not needed in predict ? + tsv_path: Optional[Path] = None # not needed in interpret ! # pydantic config model_config = ConfigDict(validate_assignment=True) @@ -39,7 +39,7 @@ def adapt_cross_val_with_maps_manager_info( ): # maps_manager is of type MapsManager but need to be in a MapsConfig type in the future # TEMPORARY if not self.split: - self.split = find_splits(maps_manager.maps_path) + self.split = tuple(find_splits(maps_manager.maps_path)) logger.debug(f"List of splits {self.split}") diff --git a/clinicadl/splitter/split_utils.py b/clinicadl/splitter/split_utils.py index 1bf5ca457..3e0f09388 100644 --- a/clinicadl/splitter/split_utils.py +++ b/clinicadl/splitter/split_utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Optional +from typing import List def find_splits(maps_path: Path) -> List[int]: @@ -7,7 +7,7 @@ def find_splits(maps_path: Path) -> List[int]: splits = [ int(split.name.split("-")[1]) for split in list(maps_path.iterdir()) - if split.name.startswith(f"split-") + if split.name.startswith("split-") ] return splits diff --git a/clinicadl/splitter/splitter.py b/clinicadl/splitter/splitter.py index 83d1ab127..f263366eb 100644 --- a/clinicadl/splitter/splitter.py +++ b/clinicadl/splitter/splitter.py @@ -5,15 +5,7 @@ import pandas as pd -from clinicadl.caps_dataset.data_config import DataConfig -from clinicadl.splitter.config import SplitConfig, SplitterConfig -from clinicadl.splitter.validation import ValidationConfig -from clinicadl.utils.exceptions import ( - ClinicaDLArgumentError, - ClinicaDLConfigurationError, - ClinicaDLTSVError, -) -from clinicadl.utils.iotools.clinica_utils import check_caps_folder +from clinicadl.splitter.config import SplitterConfig logger = getLogger("clinicadl.split_manager") @@ -147,9 +139,7 @@ def get_dataframe_from_tsv_path(tsv_path: Path) -> pd.DataFrame: return df @staticmethod - def load_data( - tsv_path: Path, cohort_diagnoses: Optional[List[str]] = None - ) -> pd.DataFrame: + def load_data(tsv_path: Path, cohort_diagnoses: List[str]) -> pd.DataFrame: df = Splitter.get_dataframe_from_tsv_path(tsv_path) df = df[df.diagnosis.isin((cohort_diagnoses))] df.reset_index(inplace=True, drop=True) @@ -164,7 +154,7 @@ def concatenate_diagnoses( """Concatenated the diagnoses needed to form the train and validation sets.""" if cohort_diagnoses is 
None: - cohort_diagnoses = self.config.data.diagnoses + cohort_diagnoses = list(self.config.data.diagnoses) tmp_cohort_path = ( cohort_path if cohort_path is not None else self.config.split.tsv_path diff --git a/clinicadl/splitter/validation.py b/clinicadl/splitter/validation.py index 6724a48f3..1452b47da 100644 --- a/clinicadl/splitter/validation.py +++ b/clinicadl/splitter/validation.py @@ -1,12 +1,9 @@ from logging import getLogger -from pathlib import Path -from typing import Optional, Tuple +from typing import Tuple from pydantic import BaseModel, ConfigDict, field_validator from pydantic.types import NonNegativeInt -from clinicadl.splitter.split_utils import find_splits - logger = getLogger("clinicadl.validation_config") diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 814aba2f8..a7856e6fd 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -5,13 +5,13 @@ from datetime import datetime from logging import getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Callable import pandas as pd import torch import torch.distributed as dist -from torch.amp import GradScaler -from torch.amp import autocast +from torch.amp.grad_scaler import GradScaler +from torch.amp.autocast_mode import autocast from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -311,11 +311,11 @@ def _resume( else: self._train_single(split, split_df_dict, resume=True) - def init_first_network(self, resume, split): + def init_first_network(self, resume: bool, split: int): first_network = 0 if resume: training_logs = [ - int(network_folder.split("-")[1]) + int(str(network_folder).split("-")[1]) for network_folder in list( ( self.maps_manager.maps_path / f"split-{split}" / "training_logs" @@ -330,40 +330,29 @@ def init_first_network(self, resume, split): def get_dataloader( self, - input_dir: Path, data_df: pd.DataFrame, - preprocessing_dict: Dict[str, Any], - transforms_config: TransformsConfig, - label: Optional[str] = None, - label_code: Optional[Dict[str, int]] = None, cnn_index: Optional[int] = None, - label_presence: bool = True, - multi_cohort: bool = False, - network_task: Union[str, Task] = "classification", sampler_option: str = "random", - n_bins: int = 5, dp_degree: Optional[int] = None, rank: Optional[int] = None, - batch_size: Optional[int] = None, - n_proc: Optional[int] = None, - worker_init_fn: Optional[function] = None, - shuffle: Optional[bool] = None, + worker_init_fn: Optional[Callable[[int], None]] = None, + shuffle: bool = True, num_replicas: Optional[int] = None, homemade_sampler: bool = False, ): dataset = return_dataset( - input_dir=input_dir, + input_dir=self.config.data.caps_directory, data_df=data_df, - preprocessing_dict=preprocessing_dict, - transforms_config=transforms_config, - multi_cohort=multi_cohort, - label=label, - label_code=label_code, + preprocessing_dict=self.config.data.preprocessing_dict, + transforms_config=self.config.transforms, + multi_cohort=self.config.data.multi_cohort, + label=self.config.data.label, + label_code=self.maps_manager.label_code, cnn_index=cnn_index, ) if homemade_sampler: sampler = generate_sampler( - network_task=network_task, + network_task=self.maps_manager.network_task, dataset=dataset, sampler_option=sampler_option, dp_degree=dp_degree, @@ -379,9 +368,9 @@ def get_dataloader( train_loader = DataLoader( dataset=dataset, - 
batch_size=batch_size, + batch_size=self.config.dataloader.batch_size, sampler=sampler, - num_workers=n_proc, + num_workers=self.config.dataloader.n_proc, worker_init_fn=worker_init_fn, shuffle=shuffle, ) @@ -411,20 +400,11 @@ def _train_single( logger.debug("Loading training data...") train_loader = self.get_dataloader( - input_dir=self.config.data.caps_directory, data_df=split_df_dict["train"], - preprocessing_dict=self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, cnn_index=network, - network_task=self.maps_manager.network_task, sampler_option=self.config.dataloader.sampler, - dp_degree=cluster.world_size, - rank=cluster.rank, - batch_size=self.config.dataloader.batch_size, - n_proc=self.config.dataloader.n_proc, + dp_degree=cluster.world_size, # type: ignore + rank=cluster.rank, # type: ignore worker_init_fn=pl_worker_init_function, homemade_sampler=True, ) @@ -433,19 +413,10 @@ def _train_single( logger.debug("Loading validation data...") valid_loader = self.get_dataloader( - input_dir=self.config.data.caps_directory, data_df=split_df_dict["validation"], - preprocessing_dict=self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, cnn_index=network, - network_task=self.maps_manager.network_task, - num_replicas=cluster.world_size, - rank=cluster.rank, - batch_size=self.config.dataloader.batch_size, - n_proc=self.config.dataloader.n_proc, + num_replicas=cluster.world_size, # type: ignore + rank=cluster.rank, # type: ignore shuffle=False, homemade_sampler=False, ) @@ -700,7 +671,7 @@ def _train( split: int, network: Optional[int] = None, resume: bool = False, - callbacks: List[Callback] = [], + callbacks: list[Callback] = [], ): """ Core function shared by train and resume. @@ -770,9 +741,10 @@ def _train( beginning_epoch=beginning_epoch, network=network, ) - retain_best = RetainBest( - selection_metrics=list(self.config.validation.selection_metrics) - ) + # retain_best = RetainBest( + # selection_metrics=list(self.config.validation.selection_metrics) + # ) ??? 
+
         epoch = beginning_epoch
 
         retain_best = RetainBest(
@@ -789,9 +761,7 @@ def _train(
             from torch.optim.lr_scheduler import ReduceLROnPlateau
 
             # Initialize the ReduceLROnPlateau scheduler
-            scheduler = ReduceLROnPlateau(
-                optimizer, mode="min", factor=0.1, verbose=True
-            )
+            scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1)
 
         while epoch < self.config.optimization.epochs and not early_stopping.step(
             metrics_valid["loss"]
         ):
@@ -948,14 +918,14 @@ def _train(
             if cluster.master:
                 # Save checkpoints and best models
                 best_dict = retain_best.step(metrics_valid)
-                self._write_weights(
+                self.maps_manager._write_weights(
                     model_weights,
                     best_dict,
                     split,
                     network=network,
                     save_all_models=self.config.reproducibility.save_all_models,
                 )
-                self._write_weights(
+                self.maps_manager._write_weights(
                     optimizer_weights,
                     None,
                     split,
@@ -1353,7 +1323,7 @@ def _train_ssdann(
                 # Save checkpoints and best models
                 best_dict = retain_best.step(metrics_valid_target)
-                self._write_weights(
+                self.maps_manager._write_weights(
                     {
                         "model": model.state_dict(),
                         "epoch": epoch,
@@ -1364,7 +1334,7 @@ def _train_ssdann(
                     network=network,
                     save_all_models=False,
                 )
-                self._write_weights(
+                self.maps_manager._write_weights(
                     {
                         "optimizer": optimizer.state_dict(),  # TO MODIFY
                         "epoch": epoch,
@@ -1528,73 +1498,3 @@ def _init_profiler(self) -> torch.profiler.profile:
             profiler.step = lambda *args, **kwargs: None
 
         return profiler
-
-    def _erase_tmp(self, split: int):
-        """
-        Erases checkpoints of the model and optimizer at the end of training.
-
-        Parameters
-        ----------
-        split : int
-            The split on which the model has been trained.
-        """
-        tmp_path = self.maps_manager.maps_path / f"split-{split}" / "tmp"
-        shutil.rmtree(tmp_path)
-
-    def _write_weights(
-        self,
-        state: Dict[str, Any],
-        metrics_dict: Optional[Dict[str, bool]],
-        split: int,
-        network: Optional[int] = None,
-        filename: str = "checkpoint.pth.tar",
-        save_all_models: bool = False,
-    ) -> None:
-        """
-        Update checkpoint and save the best model according to a set of
-        metrics.
-
-        Parameters
-        ----------
-        state : Dict[str, Any]
-            The state of the training (model weights, epoch, etc.).
-        metrics_dict : Optional[Dict[str, bool]]
-            The output of RetainBest step. If None, only the checkpoint
-            is saved.
-        split : int
-            The split number.
-        network : int (optional, default=None)
-            The network number (multi-network framework).
-        filename : str (optional, default="checkpoint.pth.tar")
-            The name of the checkpoint file.
-        save_all_models : bool (optional, default=False)
-            Whether to save model weights at every epoch.
-            If False, only the best model will be saved.
- """ - checkpoint_dir = self.maps_manager.maps_path / f"split-{split}" / "tmp" - checkpoint_dir.mkdir(parents=True, exist_ok=True) - checkpoint_path = checkpoint_dir / filename - torch.save(state, checkpoint_path) - - if save_all_models: - all_models_dir = ( - self.maps_manager.maps_path / f"split-{split}" / "all_models" - ) - all_models_dir.mkdir(parents=True, exist_ok=True) - torch.save(state, all_models_dir / f"model_epoch_{state['epoch']}.pth.tar") - - best_filename = "model.pth.tar" - if network is not None: - best_filename = f"network-{network}_model.pth.tar" - - # Save model according to several metrics - if metrics_dict is not None: - for metric_name, metric_bool in metrics_dict.items(): - metric_path = ( - self.maps_manager.maps_path - / f"split-{split}" - / f"best-{metric_name}" - ) - if metric_bool: - metric_path.mkdir(parents=True, exist_ok=True) - shutil.copyfile(checkpoint_path, metric_path / best_filename) diff --git a/clinicadl/utils/early_stopping/early_stopping.py b/clinicadl/utils/early_stopping/early_stopping.py index 8ea73bd5e..73a2b67cf 100644 --- a/clinicadl/utils/early_stopping/early_stopping.py +++ b/clinicadl/utils/early_stopping/early_stopping.py @@ -1,5 +1,5 @@ class EarlyStopping(object): - def __init__(self, mode="min", min_delta=0, patience=10): + def __init__(self, mode: str = "min", min_delta: float = 0, patience: int = 10): self.mode = mode self.min_delta = min_delta self.patience = patience