Skip to content

Commit

Permalink
Put resume in trainer (aramis-lab#612)
Browse files Browse the repository at this point in the history
  • Loading branch information
thibaultdvx authored and camillebrianceau committed Jun 10, 2024
1 parent 981e9ff commit 9bcd8b2
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 122 deletions.
22 changes: 4 additions & 18 deletions clinicadl/commandline/pipelines/train/from_json/cli.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from logging import getLogger
from pathlib import Path

import click

from clinicadl.commandline import arguments
from clinicadl.commandline.modules_options import (
cross_validation,
reproducibility,
)
from clinicadl.train.tasks_utils import create_training_config
from clinicadl.trainer.trainer import Trainer


@click.command(name="from_json", no_args_is_help=True)
Expand All @@ -24,23 +22,11 @@ def cli(**kwargs):
OUTPUT_MAPS_DIRECTORY is the path to the MAPS folder where outputs and results will be saved.
"""
from clinicadl.trainer.trainer import Trainer
from clinicadl.utils.maps_manager.maps_manager_utils import read_json

logger = getLogger("clinicadl")
logger.info(f"Reading JSON file at path {kwargs['config_file']}...")
config_dict = read_json(kwargs["config_file"])
# temporary
config_dict["tsv_directory"] = config_dict["tsv_path"]
if ("track_exp" in config_dict) and (config_dict["track_exp"] == ""):
config_dict["track_exp"] = None
config_dict["maps_dir"] = kwargs["output_maps_directory"]
config_dict["preprocessing_json"] = config_dict["preprocessing_dict"][
"extract_json"
]
###
config = create_training_config(config_dict["network_task"])(
output_maps_directory=kwargs["output_maps_directory"], **config_dict

trainer = Trainer.from_json(
config_file=kwargs["config_file"], maps_path=kwargs["output_maps_directory"]
)
trainer = Trainer(config)
trainer.train(split_list=kwargs["split"], overwrite=True)
6 changes: 3 additions & 3 deletions clinicadl/commandline/pipelines/train/resume/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from clinicadl.commandline.modules_options import (
cross_validation,
)
from clinicadl.trainer import Trainer


@click.command(name="resume", no_args_is_help=True)
Expand All @@ -14,6 +15,5 @@ def cli(input_maps_directory, split):
INPUT_MAPS_DIRECTORY is the path to the MAPS folder where training job has started.
"""
from clinicadl.train.resume import automatic_resume

automatic_resume(input_maps_directory, user_split_list=split)
trainer = Trainer.from_maps(input_maps_directory)
trainer.resume(split)
2 changes: 1 addition & 1 deletion clinicadl/config/config/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ def validator_split(cls, v):
def adapt_cross_val_with_maps_manager_info(self, maps_manager: MapsManager):
# TEMPORARY
if not self.split:
self.split = maps_manager._find_splits()
self.split = maps_manager.find_splits()
logger.debug(f"List of splits {self.split}")
2 changes: 1 addition & 1 deletion clinicadl/predict/predict_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,7 @@ def _check_data_group(
raise MAPSError("Cannot overwrite train or validation data group.")
else:
# if not split_list:
# split_list = self.maps_manager._find_splits()
# split_list = self.maps_manager.find_splits()
assert self._config.split
for split in self._config.split:
selection_metrics = self.maps_manager._find_selection_metrics(
Expand Down
82 changes: 0 additions & 82 deletions clinicadl/train/resume.py

This file was deleted.

122 changes: 109 additions & 13 deletions clinicadl/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
from clinicadl.utils.exceptions import MAPSError
from clinicadl.utils.maps_manager.ddp import DDP, cluster
from clinicadl.utils.maps_manager.logwriter import LogWriter
from clinicadl.utils.maps_manager.maps_manager_utils import read_json
from clinicadl.utils.metric_module import RetainBest
from clinicadl.utils.seed import pl_worker_init_function, seed_everything
from clinicadl.transforms.config import TransformsConfig
from clinicadl.utils.maps_manager import MapsManager
from clinicadl.utils.seed import get_seed

from clinicadl.utils.enum import Task
from .trainer_utils import create_parameters_dict
from .trainer_utils import create_parameters_dict, patch_to_read_json
from clinicadl.train.tasks_utils import create_training_config

if TYPE_CHECKING:
from clinicadl.callbacks.callbacks import Callback
Expand All @@ -42,33 +43,128 @@ class Trainer:
def __init__(
self,
config: TrainConfig,
maps_manager: Optional[MapsManager] = None,
) -> None:
"""
Parameters
----------
config : BaseTaskConfig
config : TrainConfig
"""
self.config = config
if maps_manager:
self.maps_manager = maps_manager
else:
self.maps_manager = self._init_maps_manager(config)
self.maps_manager = self._init_maps_manager(config)
self._check_args()

def _init_maps_manager(self, config) -> MapsManager:
# temporary: to match CLI data. TODO : change CLI data

parameters, maps_path = create_parameters_dict(config)
return MapsManager(
maps_path, parameters, verbose=None
) # TODO : precise which parameters in config are useful
if maps_path.is_dir():
return MapsManager(
maps_path, verbose=None
) # TODO : precise which parameters in config are useful
else:
return MapsManager(
maps_path, parameters, verbose=None
) # TODO : precise which parameters in config are useful

@classmethod
def from_json(cls, config_file: str | Path, maps_path: str | Path) -> Trainer:
    """
    Creates a Trainer from a JSON configuration file.

    Parameters
    ----------
    config_file : str | Path
        The parameters, stored in a JSON file.
    maps_path : str | Path
        The folder where the results of a future training will be stored.

    Returns
    -------
    Trainer
        The Trainer object, instantiated with parameters found in config_file.

    Raises
    ------
    FileNotFoundError
        If config_file doesn't exist.
    """
    config_file = Path(config_file)

    if not config_file.is_file():
        raise FileNotFoundError(f"No file found at {str(config_file)}.")
    config_dict = patch_to_read_json(read_json(config_file))  # TODO : remove patch
    config_dict["maps_dir"] = maps_path
    # Build the task-specific config class from the task name, then
    # instantiate it with the patched parameters.
    config_object = create_training_config(config_dict["network_task"])(**config_dict)
    return cls(config_object)

@classmethod
def from_maps(cls, maps_path: str | Path) -> Trainer:
    """
    Creates a Trainer from an existing MAPS folder.

    Parameters
    ----------
    maps_path : str | Path
        The path of the MAPS folder.

    Returns
    -------
    Trainer
        The Trainer object, instantiated with parameters found in maps_path.

    Raises
    ------
    MAPSError
        If maps_path folder doesn't exist or there is no maps.json file in it.
    """
    maps_path = Path(maps_path)

    if not (maps_path / "maps.json").is_file():
        # Original message concatenated two sentences without a separator
        # ("...maps_path.To initiate..."); a space is added between them.
        raise MAPSError(
            f"MAPS was not found at {str(maps_path)}. "
            f"To initiate a new MAPS please give a train_dict."
        )
    return cls.from_json(maps_path / "maps.json", maps_path)

def resume(self, splits: List[int]) -> None:
    """
    Resume a prematurely stopped training.

    Parameters
    ----------
    splits : List[int]
        The splits that must be resumed.
    """
    stopped = set(self.maps_manager.find_stopped_splits())
    finished = set(self.maps_manager.find_finished_splits())
    # TODO : check these two lines. Why do we need a split_manager?
    split_manager = self.maps_manager._init_split_manager(split_list=splits)
    requested = set(split_manager.split_iterator())
    ###
    absent = requested - stopped - finished

    logger.info(
        f"Finished splits {finished}\n"
        f"Stopped splits {stopped}\n"
        f"Absent splits {absent}"
    )

    # Nothing to resume and nothing left to start: the caller asked for
    # work that is already done.
    if not stopped and not absent:
        raise ValueError(
            "Training has been completed on all the splits you passed."
        )
    if stopped:
        self._resume(list(stopped))
    if absent:
        self.train(list(absent), overwrite=True)

def _check_args(self):
    """Normalize configuration arguments before training (currently only resolves the random seed)."""
    self.config.reproducibility.seed = get_seed(self.config.reproducibility.seed)
    # if (len(self.config.data.label_code) == 0):
    #     self.config.data.label_code = self.maps_manager.label_code
    # TODO: deal with label_code and replace self.maps_manager.label_code

def train(
self,
Expand Down Expand Up @@ -120,7 +216,7 @@ def train(
else:
self._train_single(split_list, resume=False)

def resume(
def _resume(
self,
split_list: Optional[List[int]] = None,
) -> None:
Expand Down
16 changes: 16 additions & 0 deletions clinicadl/trainer/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,19 @@ def create_parameters_dict(config):
if "train_transformations" in parameters:
del parameters["train_transformations"]
return parameters, maps_path


def patch_to_read_json(config_dict):
    """
    Adapt a raw ``maps.json`` dict to the keys expected by the training config.

    Temporary shim (TODO: remove once CLI/config data match): renames
    ``tsv_path`` to ``tsv_directory``, normalizes an empty ``track_exp`` to
    ``None``, extracts ``preprocessing_json`` from ``preprocessing_dict``
    and defaults a missing/None ``label_code`` to an empty dict.

    Parameters
    ----------
    config_dict : dict
        Parameters read from a maps.json file. Mutated in place.

    Returns
    -------
    dict
        The same dict, patched.
    """
    config_dict["tsv_directory"] = config_dict["tsv_path"]
    if config_dict.get("track_exp") == "":
        config_dict["track_exp"] = None
    # NOTE: this unconditional assignment made the original second,
    # conditional assignment of "preprocessing_json" dead code; it is removed.
    config_dict["preprocessing_json"] = config_dict["preprocessing_dict"][
        "extract_json"
    ]
    if config_dict.get("label_code") is None:
        config_dict["label_code"] = {}
    return config_dict
36 changes: 33 additions & 3 deletions clinicadl/utils/maps_manager/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,43 @@ def _compute_output_tensors(
torch.save(output, tensor_path / output_filename)
logger.debug(f"File saved at {[input_filename, output_filename]}")

def _find_splits(self):
"""Find which splits were trained in the MAPS."""
return [
def find_splits(self) -> List[int]:
    """
    Find which splits were trained in the MAPS.

    Returns
    -------
    List[int]
        Indices of the split folders found in the MAPS (unordered,
        as returned by directory iteration).
    """
    # Split folders are named "<split_name>-<index>"; extract the index.
    return [
        int(split.name.split("-")[1])
        for split in self.maps_path.iterdir()
        if split.name.startswith(f"{self.split_name}-")
    ]

def find_stopped_splits(self) -> List[int]:
    """
    Find the splits for which training was started but not completed.

    A split is considered stopped when its folder still contains the
    temporary ``tmp`` checkpoint entry.

    Returns
    -------
    List[int]
        Indices of the stopped splits, in the order given by find_splits.
    """
    # Original checked membership of the "tmp" path in a full directory
    # listing; Path.exists() gives the same answer without listing.
    return [
        split
        for split in self.find_splits()
        if (self.maps_path / f"{self.split_name}-{split}" / "tmp").exists()
    ]

def find_finished_splits(self) -> List[int]:
    """
    Find the splits for which training was completed.

    A split is finished when it is not stopped and its folder contains at
    least one ``best-*`` selection-metric directory.

    Returns
    -------
    List[int]
        Indices of the finished splits, in the order given by find_splits.
    """
    stopped = set(self.find_stopped_splits())
    return [
        split
        for split in self.find_splits()
        if split not in stopped
        # any() short-circuits instead of materializing the full list of
        # "best-*" entries as the original loop did.
        and any(
            "best-" in child.name
            for child in (self.maps_path / f"{self.split_name}-{split}").iterdir()
        )
    ]

def _ensemble_prediction(
self,
Expand Down
Loading

0 comments on commit 9bcd8b2

Please sign in to comment.