Put resume in trainer #612

Merged Jun 6, 2024 (37 commits)

Commits
a9a23b9
Bump sqlparse from 0.4.4 to 0.5.0 (#558)
dependabot[bot] Apr 22, 2024
36eb46f
Bump tqdm from 4.66.1 to 4.66.3 (#569)
dependabot[bot] May 4, 2024
fa7f0f1
Bump werkzeug from 3.0.1 to 3.0.3 (#570)
dependabot[bot] May 7, 2024
a05fcd5
Bump jinja2 from 3.1.3 to 3.1.4 (#571)
dependabot[bot] May 7, 2024
b2fc3e6
Bump mlflow from 2.10.1 to 2.12.1 (#575)
dependabot[bot] May 17, 2024
495d5b9
Bump gunicorn from 21.2.0 to 22.0.0 (#576)
dependabot[bot] May 17, 2024
bdd102a
Bump requests from 2.31.0 to 2.32.0 (#578)
dependabot[bot] May 21, 2024
beccd4c
[CI] Run tests through GitHub Actions (#573)
NicolasGensollen May 22, 2024
2861e9d
[CI] Skip tests when PR is in draft mode (#592)
NicolasGensollen May 23, 2024
f5de251
[CI] Test train workflow on GPU machine (#590)
NicolasGensollen May 23, 2024
69b3538
[CI] Port remaining GPU tests to GitHub Actions (#593)
NicolasGensollen May 23, 2024
c9d9252
[CI] Remove GPU pipeline from Jenkinsfile (#594)
NicolasGensollen May 24, 2024
753f04e
[CI] Port remaining non GPU tests to GitHub Actions (#581)
NicolasGensollen May 24, 2024
c424d77
[CI] Remove jenkins related things (#595)
NicolasGensollen May 24, 2024
4281c73
add simulate-gpu option
thibaultdvx May 28, 2024
52d7561
Add flags to run CI tests locally (#596)
thibaultdvx May 30, 2024
39d22fd
[CI] Remove duplicated verbose flag in test pipelines (#598)
NicolasGensollen May 30, 2024
571662c
[DOC] Update the Python version used for creating the conda environme…
NicolasGensollen May 30, 2024
567467e
Merge remote-tracking branch 'upstream/dev' into dev
thibaultdvx May 30, 2024
d54d59c
Flag for local tests (#608)
thibaultdvx May 31, 2024
f641f30
add whole resume pipeline in trainer
thibaultdvx May 31, 2024
78f2928
Merge remote-tracking branch 'upstream/dev' into dev
thibaultdvx May 31, 2024
a6c336f
change from_json cli
thibaultdvx Jun 4, 2024
f20e7fb
Update quality_check.py (#609)
HuguesRoy Jun 4, 2024
f6f382a
Fix issue in compare_folders (#610)
thibaultdvx Jun 4, 2024
cd3a538
Merge remote-tracking branch 'upstream/dev' into dev
thibaultdvx Jun 4, 2024
f7eb225
Merge branch 'dev' into refactoring
thibaultdvx Jun 4, 2024
523563d
revert change on poetry
thibaultdvx Jun 4, 2024
4971fa7
correction of wrong conflict choice in rebasing
thibaultdvx Jun 4, 2024
c60d53c
Merge remote-tracking branch 'upstream/refactoring' into refactoring
thibaultdvx Jun 4, 2024
c1b6e5b
Merge branch 'refactoring' into put_resume_in_trainer
thibaultdvx Jun 4, 2024
fdae3dd
restore split_manager
thibaultdvx Jun 4, 2024
9a57198
trigger tests
thibaultdvx Jun 4, 2024
375e67e
delete automatic resume function
thibaultdvx Jun 5, 2024
919f930
change find_splits to _find_splits
thibaultdvx Jun 5, 2024
02c4e30
Merge branch 'refactoring' of https://github.com/aramis-lab/clinicadl…
thibaultdvx Jun 6, 2024
34625fc
Merge branch 'refactoring' into put_resume_in_trainer
thibaultdvx Jun 6, 2024
22 changes: 4 additions & 18 deletions clinicadl/commandline/pipelines/train/from_json/cli.py
@@ -1,14 +1,12 @@
from logging import getLogger
-from pathlib import Path

import click

from clinicadl.commandline import arguments
from clinicadl.commandline.modules_options import (
    cross_validation,
    reproducibility,
)
-from clinicadl.train.tasks_utils import create_training_config
from clinicadl.trainer.trainer import Trainer

@click.command(name="from_json", no_args_is_help=True)
@@ -24,23 +22,11 @@ def cli(**kwargs):

    OUTPUT_MAPS_DIRECTORY is the path to the MAPS folder where outputs and results will be saved.
    """
-    from clinicadl.trainer.trainer import Trainer
-    from clinicadl.utils.maps_manager.maps_manager_utils import read_json
-
-    logger = getLogger("clinicadl")
-    logger.info(f"Reading JSON file at path {kwargs['config_file']}...")
-    config_dict = read_json(kwargs["config_file"])
-    # temporary
-    config_dict["tsv_directory"] = config_dict["tsv_path"]
-    if ("track_exp" in config_dict) and (config_dict["track_exp"] == ""):
-        config_dict["track_exp"] = None
-    config_dict["maps_dir"] = kwargs["output_maps_directory"]
-    config_dict["preprocessing_json"] = config_dict["preprocessing_dict"][
-        "extract_json"
-    ]
-    ###
-    config = create_training_config(config_dict["network_task"])(
-        output_maps_directory=kwargs["output_maps_directory"], **config_dict
-    )
-    trainer = Trainer(config)
+    trainer = Trainer.from_json(
+        config_file=kwargs["config_file"], maps_path=kwargs["output_maps_directory"]
+    )
    trainer.train(split_list=kwargs["split"], overwrite=True)
6 changes: 3 additions & 3 deletions clinicadl/commandline/pipelines/train/resume/cli.py
@@ -4,6 +4,7 @@
from clinicadl.commandline.modules_options import (
    cross_validation,
)
+from clinicadl.trainer import Trainer


@click.command(name="resume", no_args_is_help=True)
@@ -14,6 +15,5 @@ def cli(input_maps_directory, split):

    INPUT_MAPS_DIRECTORY is the path to the MAPS folder where the training job was started.
    """
-    from clinicadl.train.resume import automatic_resume
-
-    automatic_resume(input_maps_directory, user_split_list=split)
+    trainer = Trainer.from_maps(input_maps_directory)
+    trainer.resume(split)
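
The resume command therefore routes through the Trainer as well. A minimal sketch of the equivalent Python calls, assuming an existing MAPS folder left by an interrupted run (the path is hypothetical):

    from clinicadl.trainer import Trainer

    trainer = Trainer.from_maps("results/maps")  # expects results/maps/maps.json
    trainer.resume([0, 1])                       # resume (or start) splits 0 and 1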
2 changes: 1 addition & 1 deletion clinicadl/config/config/cross_validation.py
@@ -34,5 +34,5 @@ def validator_split(cls, v):
    def adapt_cross_val_with_maps_manager_info(self, maps_manager: MapsManager):
        # TEMPORARY
        if not self.split:
-            self.split = maps_manager._find_splits()
+            self.split = maps_manager.find_splits()
            logger.debug(f"List of splits {self.split}")
2 changes: 1 addition & 1 deletion clinicadl/predict/predict_manager.py
@@ -795,7 +795,7 @@ def _check_data_group(
                raise MAPSError("Cannot overwrite train or validation data group.")
            else:
                # if not split_list:
-                #     split_list = self.maps_manager._find_splits()
+                #     split_list = self.maps_manager.find_splits()
                assert self._config.split
                for split in self._config.split:
                    selection_metrics = self.maps_manager._find_selection_metrics(
82 changes: 0 additions & 82 deletions clinicadl/train/resume.py

This file was deleted.

124 changes: 110 additions & 14 deletions clinicadl/trainer/trainer.py
@@ -14,19 +14,20 @@
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

+from clinicadl.transforms.transforms import get_transforms
from clinicadl.caps_dataset.data_utils import return_dataset
from clinicadl.utils.early_stopping import EarlyStopping
from clinicadl.utils.exceptions import MAPSError
from clinicadl.utils.maps_manager.ddp import DDP, cluster
from clinicadl.utils.maps_manager.logwriter import LogWriter
from clinicadl.utils.maps_manager.maps_manager_utils import read_json
from clinicadl.utils.metric_module import RetainBest
from clinicadl.utils.seed import pl_worker_init_function, seed_everything
-from clinicadl.transforms.transforms import get_transforms
from clinicadl.utils.maps_manager import MapsManager
from clinicadl.utils.seed import get_seed

from clinicadl.utils.enum import Task
-from .trainer_utils import create_parameters_dict
+from .trainer_utils import create_parameters_dict, patch_to_read_json
+from clinicadl.train.tasks_utils import create_training_config

if TYPE_CHECKING:
    from clinicadl.callbacks.callbacks import Callback
@@ -42,33 +43,128 @@ class Trainer:
    def __init__(
        self,
        config: TrainConfig,
-        maps_manager: Optional[MapsManager] = None,
    ) -> None:
        """
        Parameters
        ----------
-        config : BaseTaskConfig
+        config : TrainConfig
        """
        self.config = config
-        if maps_manager:
-            self.maps_manager = maps_manager
-        else:
-            self.maps_manager = self._init_maps_manager(config)
+        self.maps_manager = self._init_maps_manager(config)
        self._check_args()

    def _init_maps_manager(self, config) -> MapsManager:
        # temporary: to match CLI data. TODO : change CLI data

        parameters, maps_path = create_parameters_dict(config)
-        return MapsManager(
-            maps_path, parameters, verbose=None
-        )  # TODO : precise which parameters in config are useful
+        if maps_path.is_dir():
+            return MapsManager(
+                maps_path, verbose=None
+            )  # TODO : precise which parameters in config are useful
+        else:
+            return MapsManager(
+                maps_path, parameters, verbose=None
+            )  # TODO : precise which parameters in config are useful

+    @classmethod
+    def from_json(cls, config_file: str | Path, maps_path: str | Path) -> Trainer:
+        """
+        Creates a Trainer from a json configuration file.
+
+        Parameters
+        ----------
+        config_file : str | Path
+            The training parameters, stored in a json file.
+        maps_path : str | Path
+            The folder where the results of a future training will be stored.
+
+        Returns
+        -------
+        Trainer
+            The Trainer object, instantiated with parameters found in config_file.
+
+        Raises
+        ------
+        FileNotFoundError
+            If config_file doesn't exist.
+        """
+        config_file = Path(config_file)
+
+        if not (config_file).is_file():
+            raise FileNotFoundError(f"No file found at {str(config_file)}.")
+        config_dict = patch_to_read_json(read_json(config_file))  # TODO : remove patch
+        config_dict["maps_dir"] = maps_path
+        config_object = create_training_config(config_dict["network_task"])(
+            **config_dict
+        )
+        return cls(config_object)

+    @classmethod
+    def from_maps(cls, maps_path: str | Path) -> Trainer:
+        """
+        Creates a Trainer from a MAPS folder.
+
+        Parameters
+        ----------
+        maps_path : str | Path
+            The path of the MAPS folder.
+
+        Returns
+        -------
+        Trainer
+            The Trainer object, instantiated with parameters found in maps_path.
+
+        Raises
+        ------
+        MAPSError
+            If maps_path folder doesn't exist or there is no maps.json file in it.
+        """
+        maps_path = Path(maps_path)
+
+        if not (maps_path / "maps.json").is_file():
+            raise MAPSError(
+                f"MAPS was not found at {str(maps_path)}. "
+                f"To initiate a new MAPS please give a train_dict."
+            )
+        return cls.from_json(maps_path / "maps.json", maps_path)

+    def resume(self, splits: List[int]) -> None:
+        """
+        Resume a prematurely stopped training.
+
+        Parameters
+        ----------
+        splits : List[int]
+            The splits that must be resumed.
+        """
+        stopped_splits = set(self.maps_manager.find_stopped_splits())
+        finished_splits = set(self.maps_manager.find_finished_splits())
+        # TODO : check these two lines. Why do we need a split_manager?
+        split_manager = self.maps_manager._init_split_manager(split_list=splits)
+        split_iterator = split_manager.split_iterator()
+        ###
+        absent_splits = set(split_iterator) - stopped_splits - finished_splits
+
+        logger.info(
+            f"Finished splits {finished_splits}\n"
+            f"Stopped splits {stopped_splits}\n"
+            f"Absent splits {absent_splits}"
+        )
+
+        if len(stopped_splits) == 0 and len(absent_splits) == 0:
+            raise ValueError(
+                "Training has been completed on all the splits you passed."
+            )
+        if len(stopped_splits) > 0:
+            self._resume(list(stopped_splits))
+        if len(absent_splits) > 0:
+            self.train(list(absent_splits), overwrite=True)

    def _check_args(self):
        self.config.reproducibility.seed = get_seed(self.config.reproducibility.seed)
        # if (len(self.config.data.label_code) == 0):
        #     self.config.data.label_code = self.maps_manager.label_code
-        # TODO : deal with label_code and replace self.maps_manager.label_code
+        # TODO: deal with label_code and replace self.maps_manager.label_code

    def train(
        self,
@@ -120,7 +216,7 @@ def train(
        else:
            self._train_single(split_list, resume=False)

-    def resume(
+    def _resume(
        self,
        split_list: Optional[List[int]] = None,
    ) -> None:
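
The new resume method partitions the requested splits into three groups before dispatching. An illustration of that logic with made-up values:

    # Suppose the user asks to resume splits 0, 1 and 2:
    requested = {0, 1, 2}
    stopped = {0}    # split-0 still contains a tmp/ checkpoint folder
    finished = {1}   # split-1 already has a best-* selection folder
    absent = requested - stopped - finished  # {2}: never started

    # resume() restarts split 0 from its checkpoint via _resume([0]),
    # trains split 2 from scratch via train([2], overwrite=True), and
    # leaves the finished split 1 untouched. If nothing is stopped or
    # absent, it raises ValueError instead.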
16 changes: 16 additions & 0 deletions clinicadl/trainer/trainer_utils.py
@@ -65,3 +65,19 @@ def create_parameters_dict(config):
        del parameters["mask_path"]

    return parameters, maps_path
+
+
+def patch_to_read_json(config_dict):
+    config_dict["tsv_directory"] = config_dict["tsv_path"]
+    if ("track_exp" in config_dict) and (config_dict["track_exp"] == ""):
+        config_dict["track_exp"] = None
+    config_dict["preprocessing_json"] = config_dict["preprocessing_dict"][
+        "extract_json"
+    ]
+    if "label_code" not in config_dict or config_dict["label_code"] is None:
+        config_dict["label_code"] = {}
+    if "preprocessing_json" not in config_dict:
+        config_dict["preprocessing_json"] = config_dict["preprocessing_dict"][
+            "extract_json"
+        ]
+    return config_dict
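
For reference, a sketch of what patch_to_read_json does to a legacy config dict (the input values are invented; the keys are the ones handled above):

    from clinicadl.trainer.trainer_utils import patch_to_read_json

    legacy = {
        "network_task": "classification",
        "tsv_path": "data/labels",
        "track_exp": "",
        "preprocessing_dict": {"extract_json": "extract_image.json"},
    }
    patched = patch_to_read_json(legacy)
    assert patched["tsv_directory"] == "data/labels"  # copied from tsv_path
    assert patched["track_exp"] is None               # "" normalized to None
    assert patched["preprocessing_json"] == "extract_image.json"
    assert patched["label_code"] == {}                # missing key defaulted

Note that the final "preprocessing_json" branch can never fire, since that key is set unconditionally a few lines earlier.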
36 changes: 33 additions & 3 deletions clinicadl/utils/maps_manager/maps_manager.py
@@ -332,13 +332,43 @@ def _compute_output_tensors(
        torch.save(output, tensor_path / output_filename)
        logger.debug(f"File saved at {[input_filename, output_filename]}")

-    def _find_splits(self):
-        """Find which splits were trained in the MAPS."""
-        return [
+    def find_splits(self) -> List[int]:
+        """Find which splits were trained in the MAPS."""
+        splits = [
            int(split.name.split("-")[1])
            for split in list(self.maps_path.iterdir())
            if split.name.startswith(f"{self.split_name}-")
        ]
+        return splits
+
+    def find_stopped_splits(self) -> List[int]:
+        """Find the splits for which training was not completed."""
+        existing_split_list = self.find_splits()
+        stopped_splits = [
+            split
+            for split in existing_split_list
+            if (self.maps_path / f"{self.split_name}-{split}" / "tmp")
+            in list((self.maps_path / f"{self.split_name}-{split}").iterdir())
+        ]
+        return stopped_splits
+
+    def find_finished_splits(self) -> List[int]:
+        """Find the splits for which training was completed."""
+        finished_splits = list()
+        existing_split_list = self.find_splits()
+        stopped_splits = self.find_stopped_splits()
+        for split in existing_split_list:
+            if split not in stopped_splits:
+                performance_dir_list = [
+                    performance_dir
+                    for performance_dir in list(
+                        (self.maps_path / f"{self.split_name}-{split}").iterdir()
+                    )
+                    if "best-" in performance_dir.name
+                ]
+                if len(performance_dir_list) > 0:
+                    finished_splits.append(split)
+        return finished_splits

    def _ensemble_prediction(
        self,
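
These helpers encode the MAPS on-disk convention: a split folder that still contains a tmp/ checkpoint directory was stopped mid-training, and a completed split has at least one best-<metric> selection folder. A standalone sketch of the same discovery logic, assuming split folders are named "split-<i>" (i.e. self.split_name == "split") and a hypothetical MAPS path:

    from pathlib import Path

    maps_path = Path("results/maps")

    splits = [
        int(p.name.split("-")[1])
        for p in maps_path.iterdir()
        if p.name.startswith("split-")
    ]
    stopped = [s for s in splits if (maps_path / f"split-{s}" / "tmp").is_dir()]
    finished = [
        s
        for s in splits
        if s not in stopped
        and any("best-" in p.name for p in (maps_path / f"split-{s}").iterdir())
    ]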