Skip to content

Commit

Permalink
Creation of the trainer (aramis-lab#559)
Browse files Browse the repository at this point in the history
* creation of the trainer

* remove trainer's methods from MAPSManager

* creation of the trainer

* introduce trainer in ClinicaDL's train and resume functions

* small improvements in docstrings

* omission

* other omissions
  • Loading branch information
thibaultdvx authored and camillebrianceau committed May 30, 2024
1 parent 2f304cf commit 595ad8a
Show file tree
Hide file tree
Showing 5 changed files with 1,446 additions and 1,228 deletions.
6 changes: 4 additions & 2 deletions clinicadl/train/resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path

from clinicadl import MapsManager
from clinicadl.utils.maps_manager.trainer import Trainer


def replace_arg(options, key_name, value):
Expand All @@ -19,6 +20,7 @@ def automatic_resume(model_path: Path, user_split_list=None, verbose=0):

verbose_list = ["warning", "info", "debug"]
maps_manager = MapsManager(model_path, verbose=verbose_list[verbose])
trainer = Trainer(maps_manager)

existing_split_list = maps_manager._find_splits()
stopped_splits = [
Expand Down Expand Up @@ -58,6 +60,6 @@ def automatic_resume(model_path: Path, user_split_list=None, verbose=0):
f"Absent splits {absent_splits}"
)
if len(stopped_splits) > 0:
maps_manager.resume(stopped_splits)
trainer.resume(stopped_splits)
if len(absent_splits) > 0:
maps_manager.train(absent_splits, overwrite=True)
trainer.train(absent_splits, overwrite=True)
4 changes: 3 additions & 1 deletion clinicadl/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, List

from clinicadl import MapsManager
from clinicadl.utils.maps_manager.trainer import Trainer


def train(
Expand All @@ -12,4 +13,5 @@ def train(
erase_existing: bool = True,
):
maps_manager = MapsManager(maps_dir, train_dict, verbose=None)
maps_manager.train(split_list=split_list, overwrite=erase_existing)
trainer = Trainer(maps_manager)
trainer.train(split_list=split_list, overwrite=erase_existing)
Loading

0 comments on commit 595ad8a

Please sign in to comment.