From 9bcd8b2a360050d299b2aafc17e75b031b209c7f Mon Sep 17 00:00:00 2001 From: thibaultdvx <154365476+thibaultdvx@users.noreply.github.com> Date: Thu, 6 Jun 2024 14:29:00 +0200 Subject: [PATCH] Put resume in trainer (#612) --- .../pipelines/train/from_json/cli.py | 22 +--- .../commandline/pipelines/train/resume/cli.py | 6 +- clinicadl/config/config/cross_validation.py | 2 +- clinicadl/predict/predict_manager.py | 2 +- clinicadl/train/resume.py | 82 ------------ clinicadl/trainer/trainer.py | 122 ++++++++++++++++-- clinicadl/trainer/trainer_utils.py | 16 +++ clinicadl/utils/maps_manager/maps_manager.py | 36 +++++- clinicadl/utils/meta_maps/getter.py | 2 +- 9 files changed, 168 insertions(+), 122 deletions(-) delete mode 100644 clinicadl/train/resume.py diff --git a/clinicadl/commandline/pipelines/train/from_json/cli.py b/clinicadl/commandline/pipelines/train/from_json/cli.py index b37d2e253..ab613d330 100644 --- a/clinicadl/commandline/pipelines/train/from_json/cli.py +++ b/clinicadl/commandline/pipelines/train/from_json/cli.py @@ -1,14 +1,12 @@ from logging import getLogger -from pathlib import Path import click from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import ( cross_validation, - reproducibility, ) -from clinicadl.train.tasks_utils import create_training_config +from clinicadl.trainer.trainer import Trainer @click.command(name="from_json", no_args_is_help=True) @@ -24,23 +22,11 @@ def cli(**kwargs): OUTPUT_MAPS_DIRECTORY is the path to the MAPS folder where outputs and results will be saved. 
""" - from clinicadl.trainer.trainer import Trainer - from clinicadl.utils.maps_manager.maps_manager_utils import read_json logger = getLogger("clinicadl") logger.info(f"Reading JSON file at path {kwargs['config_file']}...") - config_dict = read_json(kwargs["config_file"]) - # temporary - config_dict["tsv_directory"] = config_dict["tsv_path"] - if ("track_exp" in config_dict) and (config_dict["track_exp"] == ""): - config_dict["track_exp"] = None - config_dict["maps_dir"] = kwargs["output_maps_directory"] - config_dict["preprocessing_json"] = config_dict["preprocessing_dict"][ - "extract_json" - ] - ### - config = create_training_config(config_dict["network_task"])( - output_maps_directory=kwargs["output_maps_directory"], **config_dict + + trainer = Trainer.from_json( + config_file=kwargs["config_file"], maps_path=kwargs["output_maps_directory"] ) - trainer = Trainer(config) trainer.train(split_list=kwargs["split"], overwrite=True) diff --git a/clinicadl/commandline/pipelines/train/resume/cli.py b/clinicadl/commandline/pipelines/train/resume/cli.py index ee5cea61e..88c4f6bc0 100644 --- a/clinicadl/commandline/pipelines/train/resume/cli.py +++ b/clinicadl/commandline/pipelines/train/resume/cli.py @@ -4,6 +4,7 @@ from clinicadl.commandline.modules_options import ( cross_validation, ) +from clinicadl.trainer import Trainer @click.command(name="resume", no_args_is_help=True) @@ -14,6 +15,5 @@ def cli(input_maps_directory, split): INPUT_MAPS_DIRECTORY is the path to the MAPS folder where training job has started. 
""" - from clinicadl.train.resume import automatic_resume - - automatic_resume(input_maps_directory, user_split_list=split) + trainer = Trainer.from_maps(input_maps_directory) + trainer.resume(split) diff --git a/clinicadl/config/config/cross_validation.py b/clinicadl/config/config/cross_validation.py index fd2b4cb40..3441d72d1 100644 --- a/clinicadl/config/config/cross_validation.py +++ b/clinicadl/config/config/cross_validation.py @@ -34,5 +34,5 @@ def validator_split(cls, v): def adapt_cross_val_with_maps_manager_info(self, maps_manager: MapsManager): # TEMPORARY if not self.split: - self.split = maps_manager._find_splits() + self.split = maps_manager.find_splits() logger.debug(f"List of splits {self.split}") diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predict/predict_manager.py index 8525e12de..79f964328 100644 --- a/clinicadl/predict/predict_manager.py +++ b/clinicadl/predict/predict_manager.py @@ -820,7 +820,7 @@ def _check_data_group( raise MAPSError("Cannot overwrite train or validation data group.") else: # if not split_list: - # split_list = self.maps_manager._find_splits() + # split_list = self.maps_manager.find_splits() assert self._config.split for split in self._config.split: selection_metrics = self.maps_manager._find_selection_metrics( diff --git a/clinicadl/train/resume.py b/clinicadl/train/resume.py deleted file mode 100644 index b4ec16ba8..000000000 --- a/clinicadl/train/resume.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -Automatic relaunch of jobs that were stopped before the end of training. -Unfinished splits are detected as they do not contain a "performances" sub-folder -""" -# TODO: Remove this file and put everything in trainer.resume() ?? 
-from logging import getLogger -from pathlib import Path - -from clinicadl.train.tasks_utils import create_training_config -from clinicadl.trainer.trainer import Trainer -from clinicadl.utils.maps_manager import MapsManager - - -def replace_arg(options, key_name, value): - if value is not None: - setattr(options, key_name, value) - - -def automatic_resume(model_path: Path, user_split_list=None, verbose=0): - logger = getLogger("clinicadl") - - verbose_list = ["warning", "info", "debug"] - maps_manager = MapsManager(model_path, verbose=verbose_list[verbose]) - config_dict = maps_manager.get_parameters() - # temporary, TODO - config_dict["tsv_directory"] = config_dict["tsv_path"] - if config_dict["track_exp"] == "": - config_dict["track_exp"] = None - if "label_code" not in config_dict or config_dict["label_code"] is None: - config_dict["label_code"] = {} - if "preprocessing_json" not in config_dict: - config_dict["preprocessing_json"] = config_dict["preprocessing_dict"][ - "extract_json" - ] - config_dict["maps_dir"] = model_path - ### - config = create_training_config(config_dict["network_task"])( - output_maps_directory=model_path, **config_dict - ) - trainer = Trainer(config, maps_manager=maps_manager) - - existing_split_list = maps_manager._find_splits() - stopped_splits = [ - split - for split in existing_split_list - if (model_path / f"{maps_manager.split_name}-{split}" / "tmp") - in list((model_path / f"{maps_manager.split_name}-{split}").iterdir()) - ] - - # Find finished split - finished_splits = list() - for split in existing_split_list: - if split not in stopped_splits: - performance_dir_list = [ - performance_dir - for performance_dir in list( - (model_path / f"{maps_manager.split_name}-{split}").iterdir() - ) - if "best-" in performance_dir.name - ] - if len(performance_dir_list) > 0: - finished_splits.append(split) - - split_manager = maps_manager._init_split_manager(split_list=user_split_list) - split_iterator = split_manager.split_iterator() - - 
absent_splits = [ - split - for split in split_iterator - if split not in finished_splits and split not in stopped_splits - ] - - # To ensure retro-compatibility with random search - logger.info( - f"Finished splits {finished_splits}\n" - f"Stopped splits {stopped_splits}\n" - f"Absent splits {absent_splits}" - ) - if len(stopped_splits) > 0: - trainer.resume(stopped_splits) - if len(absent_splits) > 0: - trainer.train(absent_splits, overwrite=True) diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 7cf37ff8f..22b8e01b4 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -19,14 +19,15 @@ from clinicadl.utils.exceptions import MAPSError from clinicadl.utils.maps_manager.ddp import DDP, cluster from clinicadl.utils.maps_manager.logwriter import LogWriter +from clinicadl.utils.maps_manager.maps_manager_utils import read_json from clinicadl.utils.metric_module import RetainBest from clinicadl.utils.seed import pl_worker_init_function, seed_everything from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.maps_manager import MapsManager from clinicadl.utils.seed import get_seed - from clinicadl.utils.enum import Task -from .trainer_utils import create_parameters_dict +from .trainer_utils import create_parameters_dict, patch_to_read_json +from clinicadl.train.tasks_utils import create_training_config if TYPE_CHECKING: from clinicadl.callbacks.callbacks import Callback @@ -42,33 +43,128 @@ class Trainer: def __init__( self, config: TrainConfig, - maps_manager: Optional[MapsManager] = None, ) -> None: """ Parameters ---------- - config : BaseTaskConfig + config : TrainConfig """ self.config = config - if maps_manager: - self.maps_manager = maps_manager - else: - self.maps_manager = self._init_maps_manager(config) + self.maps_manager = self._init_maps_manager(config) self._check_args() def _init_maps_manager(self, config) -> MapsManager: # temporary: to match CLI data. 
TODO : change CLI data
         parameters, maps_path = create_parameters_dict(config)
-        return MapsManager(
-            maps_path, parameters, verbose=None
-        )  # TODO : precise which parameters in config are useful
+        if maps_path.is_dir():
+            return MapsManager(
+                maps_path, verbose=None
+            )  # TODO : precise which parameters in config are useful
+        else:
+            return MapsManager(
+                maps_path, parameters, verbose=None
+            )  # TODO : precise which parameters in config are useful
+
+    @classmethod
+    def from_json(cls, config_file: str | Path, maps_path: str | Path) -> Trainer:
+        """
+        Creates a Trainer from a json configuration file.
+
+        Parameters
+        ----------
+        config_file : str | Path
+            The parameters, stored in a json file.
+        maps_path : str | Path
+            The folder where the results of a future training will be stored.
+
+        Returns
+        -------
+        Trainer
+            The Trainer object, instantiated with parameters found in config_file.
+
+        Raises
+        ------
+        FileNotFoundError
+            If config_file doesn't exist.
+        """
+        config_file = Path(config_file)
+
+        if not (config_file).is_file():
+            raise FileNotFoundError(f"No file found at {str(config_file)}.")
+        config_dict = patch_to_read_json(read_json(config_file))  # TODO : remove patch
+        config_dict["maps_dir"] = maps_path
+        config_object = create_training_config(config_dict["network_task"])(
+            **config_dict
+        )
+        return cls(config_object)
+
+    @classmethod
+    def from_maps(cls, maps_path: str | Path) -> Trainer:
+        """
+        Creates a Trainer from the maps.json file of a MAPS folder.
+
+        Parameters
+        ----------
+        maps_path : str | Path
+            The path of the MAPS folder.
+
+        Returns
+        -------
+        Trainer
+            The Trainer object, instantiated with parameters found in maps_path.
+
+        Raises
+        ------
+        MAPSError
+            If the maps_path folder doesn't exist or there is no maps.json file in it.
+        """
+        maps_path = Path(maps_path)
+
+        if not (maps_path / "maps.json").is_file():
+            raise MAPSError(
+                f"MAPS was not found at {str(maps_path)}."
+                f"To initiate a new MAPS please give a train_dict."
+ ) + return cls.from_json(maps_path / "maps.json", maps_path) + + def resume(self, splits: List[int]) -> None: + """ + Resume a prematurely stopped training. + + Parameters + ---------- + splits : List[int] + The splits that must be resumed. + """ + stopped_splits = set(self.maps_manager.find_stopped_splits()) + finished_splits = set(self.maps_manager.find_finished_splits()) + # TODO : check these two lines. Why do we need a split_manager? + split_manager = self.maps_manager._init_split_manager(split_list=splits) + split_iterator = split_manager.split_iterator() + ### + absent_splits = set(split_iterator) - stopped_splits - finished_splits + + logger.info( + f"Finished splits {finished_splits}\n" + f"Stopped splits {stopped_splits}\n" + f"Absent splits {absent_splits}" + ) + + if len(stopped_splits) == 0 and len(absent_splits) == 0: + raise ValueError( + "Training has been completed on all the splits you passed." + ) + if len(stopped_splits) > 0: + self._resume(list(stopped_splits)) + if len(absent_splits) > 0: + self.train(list(absent_splits), overwrite=True) def _check_args(self): self.config.reproducibility.seed = get_seed(self.config.reproducibility.seed) # if (len(self.config.data.label_code) == 0): # self.config.data.label_code = self.maps_manager.label_code - # TODO : deal with label_code and replace self.maps_manager.label_code + # TODO: deal with label_code and replace self.maps_manager.label_code def train( self, @@ -120,7 +216,7 @@ def train( else: self._train_single(split_list, resume=False) - def resume( + def _resume( self, split_list: Optional[List[int]] = None, ) -> None: diff --git a/clinicadl/trainer/trainer_utils.py b/clinicadl/trainer/trainer_utils.py index 82af17489..58c8103ef 100644 --- a/clinicadl/trainer/trainer_utils.py +++ b/clinicadl/trainer/trainer_utils.py @@ -69,3 +69,19 @@ def create_parameters_dict(config): if "train_transformations" in parameters: del parameters["train_transformations"] return parameters, maps_path + + +def 
patch_to_read_json(config_dict):
+    config_dict["tsv_directory"] = config_dict["tsv_path"]
+    if ("track_exp" in config_dict) and (config_dict["track_exp"] == ""):
+        config_dict["track_exp"] = None
+    config_dict["preprocessing_json"] = config_dict["preprocessing_dict"][
+        "extract_json"
+    ]
+    if "label_code" not in config_dict or config_dict["label_code"] is None:
+        config_dict["label_code"] = {}
+    if "preprocessing_json" not in config_dict:
+        config_dict["preprocessing_json"] = config_dict["preprocessing_dict"][
+            "extract_json"
+        ]
+    return config_dict
diff --git a/clinicadl/utils/maps_manager/maps_manager.py b/clinicadl/utils/maps_manager/maps_manager.py
index d710d955c..bebd98377 100644
--- a/clinicadl/utils/maps_manager/maps_manager.py
+++ b/clinicadl/utils/maps_manager/maps_manager.py
@@ -332,13 +332,43 @@ def _compute_output_tensors(
         torch.save(output, tensor_path / output_filename)
         logger.debug(f"File saved at {[input_filename, output_filename]}")
 
-    def _find_splits(self):
-        """Find which splits were trained in the MAPS."""
-        return [
+    def find_splits(self) -> List[int]:
+        """Find which splits were trained in the MAPS."""
+        splits = [
             int(split.name.split("-")[1])
             for split in list(self.maps_path.iterdir())
             if split.name.startswith(f"{self.split_name}-")
         ]
+        return splits
+
+    def find_stopped_splits(self) -> List[int]:
+        """Find the splits for which training was not completed."""
+        existing_split_list = self.find_splits()
+        stopped_splits = [
+            split
+            for split in existing_split_list
+            if (self.maps_path / f"{self.split_name}-{split}" / "tmp")
+            in list((self.maps_path / f"{self.split_name}-{split}").iterdir())
+        ]
+        return stopped_splits
+
+    def find_finished_splits(self) -> List[int]:
+        """Find the splits for which training was completed."""
+        finished_splits = list()
+        existing_split_list = self.find_splits()
+        stopped_splits = self.find_stopped_splits()
+        for split in existing_split_list:
+            if split not in stopped_splits:
+                performance_dir_list
= [ + performance_dir + for performance_dir in list( + (self.maps_path / f"{self.split_name}-{split}").iterdir() + ) + if "best-" in performance_dir.name + ] + if len(performance_dir_list) > 0: + finished_splits.append(split) + return finished_splits def _ensemble_prediction( self, diff --git a/clinicadl/utils/meta_maps/getter.py b/clinicadl/utils/meta_maps/getter.py index 42967c929..38307b11d 100644 --- a/clinicadl/utils/meta_maps/getter.py +++ b/clinicadl/utils/meta_maps/getter.py @@ -34,7 +34,7 @@ def meta_maps_analysis(launch_dir: Path, evaluation_metric="loss"): for job in jobs_list: performances_dict[job] = dict() maps_manager = MapsManager(launch_dir / job) - split_list = maps_manager._find_splits() + split_list = maps_manager.find_splits() split_set = split_set | set(split_list) for split in split_set: performances_dict[job][split] = dict()