Skip to content

Commit

Permalink
Trainer config (aramis-lab#561)
Browse files Browse the repository at this point in the history
* get trainer out of mapsmanager folder

* base training config class

* task specific config classes

* unit test for config classes

* changes in cli to have default values from task config objects

* rename and simplify build_train_dict

* unit test for train_utils

* small modification in training config toml template

* rename build_train_dict in the other parts of the project

* modify task_launcher to use config objects

* Bump sqlparse from 0.4.4 to 0.5.0 (aramis-lab#558)

Bumps [sqlparse](https://github.com/andialbrecht/sqlparse) from 0.4.4 to 0.5.0.
- [Changelog](https://github.com/andialbrecht/sqlparse/blob/master/CHANGELOG)
- [Commits](andialbrecht/sqlparse@0.4.4...0.5.0)

---
updated-dependencies:
- dependency-name: sqlparse
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* typo

* change _network_task attribute

* omissions

* patches to match CLI data

* small modifications

* correction

* correction reconstruction default loss

* add architecture command specific to the task

* add use_extracted_features parameter

* add VAE parameters in reconstruction

* add condition on whether cli arg is default or from user

* correct wrong import in resume

* validators on assignment

* reformat

* replace literal with enum

* review on CLI options

* convert enum to str for train function

* correct track exp issue

* test for ci

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and camillebrianceau committed May 30, 2024
1 parent 595ad8a commit 61c00d2
Show file tree
Hide file tree
Showing 24 changed files with 1,216 additions and 331 deletions.
4 changes: 2 additions & 2 deletions clinicadl/random_search/random_search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import toml

from clinicadl.train.train_utils import build_train_dict
from clinicadl.train.train_utils import extract_config_from_toml_file
from clinicadl.utils.exceptions import ClinicaDLConfigurationError
from clinicadl.utils.preprocessing import path_decoder, read_preprocessing

Expand Down Expand Up @@ -49,7 +49,7 @@ def get_space_dict(launch_directory: Path) -> Dict[str, Any]:
space_dict.setdefault("n_conv", 1)
space_dict.setdefault("wd_bool", True)

train_default = build_train_dict(toml_path, space_dict["network_task"])
train_default = extract_config_from_toml_file(toml_path, space_dict["network_task"])

# Mode and preprocessing
preprocessing_json = (
Expand Down
2 changes: 1 addition & 1 deletion clinicadl/resources/config/train_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ diagnoses = ["AD", "CN"]
baseline = false
valid_longitudinal = false
normalize = true
data_augmentation = false
data_augmentation = []
sampler = "random"
size_reduction=false
size_reduction_factor=2
Expand Down
2 changes: 1 addition & 1 deletion clinicadl/train/resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pathlib import Path

from clinicadl import MapsManager
from clinicadl.utils.maps_manager.trainer import Trainer
from clinicadl.utils.trainer import Trainer


def replace_arg(options, key_name, value):
Expand Down
194 changes: 194 additions & 0 deletions clinicadl/train/tasks/base_training_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from enum import Enum
from logging import getLogger
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel, PrivateAttr, field_validator

logger = getLogger("clinicadl.base_training_config")


# Available compensations in clinicaDL, built with the Enum functional API;
# the ``type=str`` mixin makes members compare equal to their string values.
Compensation = Enum(
    "Compensation",
    {"MEMORY": "memory", "TIME": "time"},
    type=str,
    module=__name__,
)
Compensation.__doc__ = """Available compensations in clinicaDL."""


# Available size reduction factors in ClinicaDL, built with the Enum
# functional API; the ``type=int`` mixin makes members behave as ints.
SizeReductionFactor = Enum(
    "SizeReductionFactor",
    {"TWO": 2, "THREE": 3, "FOUR": 4, "FIVE": 5},
    type=int,
    module=__name__,
)
SizeReductionFactor.__doc__ = """Available size reduction factors in ClinicaDL."""


# Available tools for experiment tracking in ClinicaDL (functional Enum
# construction; ``type=str`` keeps members interchangeable with plain strings).
ExperimentTracking = Enum(
    "ExperimentTracking",
    {"MLFLOW": "mlflow", "WANDB": "wandb"},
    type=str,
    module=__name__,
)
ExperimentTracking.__doc__ = """Available tools for experiment tracking in ClinicaDL."""


# Available samplers in ClinicaDL (functional Enum construction;
# ``type=str`` keeps members interchangeable with plain strings).
Sampler = Enum(
    "Sampler",
    {"RANDOM": "random", "WEIGHTED": "weighted"},
    type=str,
    module=__name__,
)
Sampler.__doc__ = """Available samplers in ClinicaDL."""


# Available modes in ClinicaDL (functional Enum construction;
# ``type=str`` keeps members interchangeable with plain strings).
Mode = Enum(
    "Mode",
    {"IMAGE": "image", "PATCH": "patch", "ROI": "roi", "SLICE": "slice"},
    type=str,
    module=__name__,
)
Mode.__doc__ = """Available modes in ClinicaDL."""


class BaseTaskConfig(BaseModel):
    """
    Base class to handle parameters of the training pipeline.

    Task-specific configurations (classification, regression, reconstruction)
    inherit from this class and add their own fields. Thanks to
    ``validate_assignment`` below, pydantic re-runs the field validators not
    only at construction time but also every time an attribute is reassigned.
    """

    # Mandatory input/output paths of the training pipeline (no defaults).
    caps_directory: Path
    preprocessing_json: Path
    tsv_directory: Path
    output_maps_directory: Path
    # Computational
    gpu: bool = True
    n_proc: int = 2
    batch_size: int = 8
    evaluation_steps: int = 0
    fully_sharded_data_parallel: bool = False
    amp: bool = False
    # Reproducibility
    seed: int = 0
    deterministic: bool = False
    compensation: Compensation = Compensation.MEMORY
    save_all_models: bool = False
    track_exp: Optional[ExperimentTracking] = None
    # Model
    multi_network: bool = False
    ssda_network: bool = False
    # Data
    multi_cohort: bool = False
    diagnoses: Tuple[str, ...] = ("AD", "CN")
    baseline: bool = False
    valid_longitudinal: bool = False
    normalize: bool = True
    data_augmentation: Tuple[str, ...] = ()
    sampler: Sampler = Sampler.RANDOM
    size_reduction: bool = False
    size_reduction_factor: SizeReductionFactor = (
        SizeReductionFactor.TWO
    ) # TODO : change to optional and remove size_reduction parameter
    caps_target: Path = Path("")
    tsv_target_lab: Path = Path("")
    tsv_target_unlab: Path = Path("")
    preprocessing_dict_target: Path = Path(
        ""
    ) ## TODO : change name in commandline. preprocessing_json_target?
    # Cross validation
    n_splits: int = 0
    split: Tuple[int, ...] = ()
    # Optimization
    optimizer: str = "Adam"
    epochs: int = 20
    learning_rate: float = 1e-4
    adaptive_learning_rate: bool = False
    weight_decay: float = 1e-4
    dropout: float = 0.0
    patience: int = 0
    tolerance: float = 0.0
    accumulation_steps: int = 1
    profiler: bool = False
    # Transfer Learning
    transfer_path: Optional[Path] = None
    transfer_selection_metric: str = "loss"
    nb_unfrozen_layer: int = 0
    # Information
    emissions_calculator: bool = False
    # Mode
    use_extracted_features: bool = False  # unused. TODO : remove
    # Private attributes: filled later by the pipeline, not validated fields.
    _preprocessing_dict: Dict[str, Any] = PrivateAttr()
    _preprocessing_dict_target: Dict[str, Any] = PrivateAttr()
    _mode: Mode = PrivateAttr()

    # BUG FIX: the original declared an inner ``class ConfigDict`` holding
    # ``validate_assignment = True``. Pydantic v2 only reads configuration
    # from the ``model_config`` attribute (or the deprecated inner
    # ``class Config``), so that nested class was silently ignored and
    # on-assignment validation never ran. The supported spelling is below.
    model_config = {"validate_assignment": True}

    @field_validator("diagnoses", "split", "data_augmentation", mode="before")
    def list_to_tuples(cls, v):
        """Normalize lists (e.g. parsed from TOML or the CLI) into tuples."""
        if isinstance(v, list):
            return tuple(v)
        return v

    @field_validator("transfer_path", mode="before")
    def false_to_none(cls, v):
        """Map a literal ``False`` (presumably a CLI default meaning
        "no transfer learning" — TODO confirm against the CLI) to ``None``."""
        if v is False:
            return None
        return v

    @classmethod
    def get_available_optimizers(cls) -> List[str]:
        """To get the list of available optimizers."""
        available_optimizers = [ # TODO : connect to PyTorch to have available optimizers
            "Adadelta",
            "Adagrad",
            "Adam",
            "AdamW",
            "Adamax",
            "ASGD",
            "NAdam",
            "RAdam",
            "RMSprop",
            "SGD",
        ]
        return available_optimizers

    # NOTE(review): the validators below rely on ``assert``, which is stripped
    # when Python runs with -O. Inside pydantic validators the checks work in
    # normal runs (AssertionError is turned into a ValidationError), but
    # raising ValueError explicitly would be more robust.
    @field_validator("optimizer")
    def validator_optimizer(cls, v):
        """Reject optimizer names that are not in the supported list."""
        available_optimizers = cls.get_available_optimizers()
        assert (
            v in available_optimizers
        ), f"Optimizer '{v}' not supported. Please choose among: {available_optimizers}"
        return v

    @classmethod
    def get_available_transforms(cls) -> List[str]:
        """To get the list of available transforms."""
        available_transforms = [ # TODO : connect to transforms module
            "Noise",
            "Erasing",
            "CropPad",
            "Smoothing",
            "Motion",
            "Ghosting",
            "Spike",
            "BiasField",
            "RandomBlur",
            "RandomSwap",
        ]
        return available_transforms

    @field_validator("data_augmentation", mode="before")
    def validator_data_augmentation(cls, v):
        """Map ``False`` to an empty tuple and check each requested transform."""
        if v is False:
            return ()

        available_transforms = cls.get_available_transforms()
        for transform in v:
            assert (
                transform in available_transforms
            ), f"Transform '{transform}' not supported. Please pick among: {available_transforms}"
        return v

    @field_validator("dropout")
    def validator_dropout(cls, v):
        """Dropout is a probability, hence constrained to [0, 1]."""
        assert (
            0 <= v <= 1
        ), f"dropout must be between 0 and 1 but it has been set to {v}."
        return v
33 changes: 19 additions & 14 deletions clinicadl/train/tasks/classification_cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from pathlib import Path

import click
from click.core import ParameterSource

from clinicadl.train.train_utils import extract_config_from_toml_file
from clinicadl.utils.cli_param import train_option

from .classification_config import ClassificationConfig
from .task_utils import task_launcher


Expand All @@ -26,7 +31,7 @@
@train_option.compensation
@train_option.save_all_models
# Model
@train_option.architecture
@train_option.classification_architecture
@train_option.multi_network
@train_option.ssda_network
# Data
Expand Down Expand Up @@ -61,8 +66,8 @@
@train_option.transfer_selection_metric
@train_option.nb_unfrozen_layer
# Task-related
@train_option.label
@train_option.selection_metrics
@train_option.classification_label
@train_option.classification_selection_metrics
@train_option.selection_threshold
@train_option.classification_loss
# information
Expand All @@ -84,14 +89,14 @@ def cli(**kwargs):
configuration file in TOML format. For more details, please visit the documentation:
https://clinicadl.readthedocs.io/en/stable/Train/Introduction/#configuration-file
"""
task_specific_options = [
"label",
"selection_metrics",
"selection_threshold",
"loss",
]
task_launcher("classification", task_specific_options, **kwargs)


if __name__ == "__main__":
cli()
options = {}
if kwargs["config_file"]:
options = extract_config_from_toml_file(
Path(kwargs["config_file"]),
"classification",
)
for arg in kwargs:
if click.get_current_context().get_parameter_source(arg) == ParameterSource.COMMANDLINE:
options[arg] = kwargs[arg]
config = ClassificationConfig(**options)
task_launcher(config)
55 changes: 55 additions & 0 deletions clinicadl/train/tasks/classification_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from logging import getLogger
from typing import Dict, List, Tuple

from pydantic import PrivateAttr, field_validator

from .base_training_config import BaseTaskConfig

logger = getLogger("clinicadl.classification_config")


class ClassificationConfig(BaseTaskConfig):
    """
    Configuration specific to the classification task.

    Extends :class:`BaseTaskConfig` with the network architecture, the loss,
    the target label and the metrics used to select the best models.
    """

    architecture: str = "Conv5_FC3"
    loss: str = "CrossEntropyLoss"
    label: str = "diagnosis"
    label_code: Dict[str, int] = {}
    selection_threshold: float = 0.0
    selection_metrics: Tuple[str, ...] = ("loss",)
    # private
    _network_task: str = PrivateAttr(default="classification")

    @field_validator("selection_metrics", mode="before")
    def list_to_tuples(cls, v):
        # Normalize lists coming from the CLI/TOML into immutable tuples.
        return tuple(v) if isinstance(v, list) else v

    @classmethod
    def get_compatible_losses(cls) -> List[str]:
        """Return the losses implemented and compatible with this task."""
        # TODO : connect to the Loss module
        return [
            "CrossEntropyLoss",
            "MultiMarginLoss",
        ]

    @field_validator("loss")
    def validator_loss(cls, v):
        # Reject losses that are not implemented for classification.
        compatible_losses = cls.get_compatible_losses()
        assert (
            v in compatible_losses
        ), f"Loss '{v}' can't be used for this task. Please choose among: {compatible_losses}"
        return v

    @field_validator("selection_threshold")
    def validator_threshold(cls, v):
        # The selection threshold is a probability, hence constrained to [0, 1].
        assert (
            0 <= v <= 1
        ), f"selection_threshold must be between 0 and 1 but it has been set to {v}."
        return v

    @field_validator("architecture")
    def validator_architecture(cls, v):
        return v  # TODO : connect to network module to have list of available architectures
22 changes: 18 additions & 4 deletions clinicadl/train/tasks/reconstruction_cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from pathlib import Path

import click
from click.core import ParameterSource

from clinicadl.train.train_utils import extract_config_from_toml_file
from clinicadl.utils.cli_param import train_option

from .reconstruction_config import ReconstructionConfig
from .task_utils import task_launcher


Expand All @@ -26,7 +31,7 @@
@train_option.compensation
@train_option.save_all_models
# Model
@train_option.architecture
@train_option.reconstruction_architecture
@train_option.multi_network
@train_option.ssda_network
# Data
Expand Down Expand Up @@ -61,7 +66,7 @@
@train_option.transfer_selection_metric
@train_option.nb_unfrozen_layer
# Task-related
@train_option.selection_metrics
@train_option.reconstruction_selection_metrics
@train_option.reconstruction_loss
# information
@train_option.emissions_calculator
Expand All @@ -82,5 +87,14 @@ def cli(**kwargs):
configuration file in TOML format. For more details, please visit the documentation:
https://clinicadl.readthedocs.io/en/stable/Train/Introduction/#configuration-file
"""
task_specific_options = ["selection_metrics", "loss"]
task_launcher("reconstruction", task_specific_options, **kwargs)
options = {}
if kwargs["config_file"]:
options = extract_config_from_toml_file(
Path(kwargs["config_file"]),
"reconstruction",
)
for arg in kwargs:
if click.get_current_context().get_parameter_source(arg) == ParameterSource.COMMANDLINE:
options[arg] = kwargs[arg]
config = ReconstructionConfig(**options)
task_launcher(config)
Loading

0 comments on commit 61c00d2

Please sign in to comment.