Review optim module (#656)
* create a config class for all optimizers

* config for all schedulers
thibaultdvx authored Sep 24, 2024
1 parent 0f20f7b commit 048ca39
Showing 13 changed files with 719 additions and 226 deletions.
3 changes: 3 additions & 0 deletions clinicadl/optim/__init__.py
@@ -1 +1,4 @@
from .config import OptimizationConfig
from .early_stopping import EarlyStopping
from .lr_scheduler import create_lr_scheduler_config, get_lr_scheduler
from .optimizer import create_optimizer_config, get_optimizer
3 changes: 2 additions & 1 deletion clinicadl/optim/lr_scheduler/__init__.py
@@ -1,2 +1,3 @@
from .config import ImplementedLRScheduler, LRSchedulerConfig
from .config import create_lr_scheduler_config
from .enum import ImplementedLRScheduler
from .factory import get_lr_scheduler
192 changes: 119 additions & 73 deletions clinicadl/optim/lr_scheduler/config.py
@@ -1,7 +1,4 @@
from __future__ import annotations

from enum import Enum
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Type, Union

from pydantic import (
    BaseModel,
@@ -10,71 +7,43 @@
    NonNegativeInt,
    PositiveFloat,
    PositiveInt,
    computed_field,
    field_validator,
    model_validator,
)

from clinicadl.utils.factories import DefaultFromLibrary

from .enum import ImplementedLRScheduler, Mode, ThresholdMode

class ImplementedLRScheduler(str, Enum):
    """Implemented LR schedulers in ClinicaDL."""

    CONSTANT = "ConstantLR"
    LINEAR = "LinearLR"
    STEP = "StepLR"
    MULTI_STEP = "MultiStepLR"
    PLATEAU = "ReduceLROnPlateau"

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not implemented. Implemented LR schedulers are: "
            + ", ".join([repr(m.value) for m in cls])
        )


class Mode(str, Enum):
    """Supported mode for ReduceLROnPlateau."""

    MIN = "min"
    MAX = "max"


class ThresholdMode(str, Enum):
    """Supported threshold mode for ReduceLROnPlateau."""

    ABS = "abs"
    REL = "rel"
__all__ = [
    "LRSchedulerConfig",
    "ConstantLRConfig",
    "LinearLRConfig",
    "StepLRConfig",
    "MultiStepLRConfig",
    "ReduceLROnPlateauConfig",
    "create_lr_scheduler_config",
]


class LRSchedulerConfig(BaseModel):
    """Config class to configure the optimizer."""
    """Base config class for the LR scheduler."""

    scheduler: Optional[ImplementedLRScheduler] = None
    step_size: Optional[PositiveInt] = None
    gamma: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
    milestones: Optional[List[PositiveInt]] = None
    factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
    start_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
    end_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
    total_iters: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES
    last_epoch: Union[int, DefaultFromLibrary] = DefaultFromLibrary.YES

    mode: Union[Mode, DefaultFromLibrary] = DefaultFromLibrary.YES
    patience: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES
    threshold: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
    threshold_mode: Union[ThresholdMode, DefaultFromLibrary] = DefaultFromLibrary.YES
    cooldown: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES
    min_lr: Union[
        NonNegativeFloat, Dict[str, PositiveFloat], DefaultFromLibrary
    ] = DefaultFromLibrary.YES
    eps: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
    # pydantic config
    model_config = ConfigDict(
        validate_assignment=True, use_enum_values=True, validate_default=True
    )

    @computed_field
    @property
    def scheduler(self) -> Optional[ImplementedLRScheduler]:
        """The name of the scheduler."""
        return None

    @field_validator("last_epoch")
    @classmethod
    def validator_last_epoch(cls, v):
@@ -84,31 +53,108 @@ def validator_last_epoch(cls, v):
        ), f"last_epoch must be -1 or a non-negative int but it has been set to {v}."
        return v

    @field_validator("milestones")

class ConstantLRConfig(LRSchedulerConfig):
    """Config class for ConstantLR scheduler."""

    @computed_field
    @property
    def scheduler(self) -> ImplementedLRScheduler:
        """The name of the scheduler."""
        return ImplementedLRScheduler.CONSTANT


class LinearLRConfig(LRSchedulerConfig):
    """Config class for LinearLR scheduler."""

    start_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
    end_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES

    @computed_field
    @property
    def scheduler(self) -> ImplementedLRScheduler:
        """The name of the scheduler."""
        return ImplementedLRScheduler.LINEAR


class StepLRConfig(LRSchedulerConfig):
    """Config class for StepLR scheduler."""

    step_size: PositiveInt

    @computed_field
    @property
    def scheduler(self) -> ImplementedLRScheduler:
        """The name of the scheduler."""
        return ImplementedLRScheduler.STEP


class MultiStepLRConfig(LRSchedulerConfig):
    """Config class for MultiStepLR scheduler."""

    milestones: List[PositiveInt]

    @computed_field
    @property
    def scheduler(self) -> ImplementedLRScheduler:
        """The name of the scheduler."""
        return ImplementedLRScheduler.MULTI_STEP

    @field_validator("milestones", mode="after")
    @classmethod
    def validator_milestones(cls, v):
        import numpy as np

        if v is not None:
            assert len(np.unique(v)) == len(
                v
            ), "Epoch(s) in milestones should be unique."
            return sorted(v)
        return v
        assert len(np.unique(v)) == len(v), "Epoch(s) in milestones should be unique."
        return sorted(v)


class ReduceLROnPlateauConfig(LRSchedulerConfig):
    """Config class for ReduceLROnPlateau scheduler."""

    mode: Union[Mode, DefaultFromLibrary] = DefaultFromLibrary.YES
    patience: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES
    threshold: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
    threshold_mode: Union[ThresholdMode, DefaultFromLibrary] = DefaultFromLibrary.YES
    cooldown: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES
    min_lr: Union[
        NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary
    ] = DefaultFromLibrary.YES
    eps: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES

    @model_validator(mode="after")
    def check_mandatory_args(self) -> LRSchedulerConfig:
        if (
            self.scheduler == ImplementedLRScheduler.MULTI_STEP
            and self.milestones is None
        ):
            raise ValueError(
                """If you chose MultiStepLR as LR scheduler, you should pass milestones
                (see PyTorch documentation for more details)."""
            )
        elif self.scheduler == ImplementedLRScheduler.STEP and self.step_size is None:
            raise ValueError(
                """If you chose StepLR as LR scheduler, you should pass a step_size
                (see PyTorch documentation for more details)."""
            )
        return self
    @property
    def scheduler(self) -> ImplementedLRScheduler:
        """The name of the scheduler."""
        return ImplementedLRScheduler.PLATEAU


def create_lr_scheduler_config(
    scheduler: Optional[Union[str, ImplementedLRScheduler]],
) -> Type[LRSchedulerConfig]:
    """
    A factory function to create a config class suited for the LR scheduler.
    Parameters
    ----------
    scheduler : Optional[Union[str, ImplementedLRScheduler]]
        The name of the LR scheduler.
        Can be None if no LR scheduler will be used.
    Returns
    -------
    Type[LRSchedulerConfig]
        The config class.
    Raises
    ------
    ValueError
        If `scheduler` is not supported.
    """
    if scheduler is None:
        return LRSchedulerConfig

    scheduler = ImplementedLRScheduler(scheduler)
    config_name = "".join([scheduler, "Config"])
    config = globals()[config_name]

    return config
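
For illustration only (not part of the diff), the new factory can be exercised as in the following sketch; the milestone values are made up:

    from clinicadl.optim.lr_scheduler import create_lr_scheduler_config

    # Look up the config class matching the scheduler name; an unknown name raises ValueError.
    config_class = create_lr_scheduler_config("MultiStepLR")   # -> MultiStepLRConfig
    config = config_class(milestones=[30, 10, 20])              # validated, stored sorted as [10, 20, 30]
    print(config.scheduler)                                     # computed field: ImplementedLRScheduler.MULTI_STEP

    # Passing None falls back to the generic base class.
    base_class = create_lr_scheduler_config(None)               # -> LRSchedulerConfig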
32 changes: 32 additions & 0 deletions clinicadl/optim/lr_scheduler/enum.py
@@ -0,0 +1,32 @@
from enum import Enum


class ImplementedLRScheduler(str, Enum):
    """Implemented LR schedulers in ClinicaDL."""

    CONSTANT = "ConstantLR"
    LINEAR = "LinearLR"
    STEP = "StepLR"
    MULTI_STEP = "MultiStepLR"
    PLATEAU = "ReduceLROnPlateau"

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not implemented. Implemented LR schedulers are: "
            + ", ".join([repr(m.value) for m in cls])
        )


class Mode(str, Enum):
    """Supported mode for ReduceLROnPlateau."""

    MIN = "min"
    MAX = "max"


class ThresholdMode(str, Enum):
    """Supported threshold mode for ReduceLROnPlateau."""

    ABS = "abs"
    REL = "rel"
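
The _missing_ hook above makes a wrong scheduler name fail loudly instead of producing a silent None. A quick illustration (not part of the diff):

    from clinicadl.optim.lr_scheduler import ImplementedLRScheduler

    ImplementedLRScheduler("StepLR")              # -> ImplementedLRScheduler.STEP
    ImplementedLRScheduler("CosineAnnealingLR")   # raises ValueError listing the implemented schedulers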
2 changes: 1 addition & 1 deletion clinicadl/optim/lr_scheduler/factory.py
@@ -56,6 +56,6 @@ def get_lr_scheduler(
        config_dict_["min_lr"].append(default_min_lr)
    scheduler = scheduler_class(optimizer, **config_dict_)

    updated_config = LRSchedulerConfig(scheduler=config.scheduler, **config_dict)
    updated_config = config.model_copy(update=config_dict)

    return scheduler, updated_config
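
The switch to pydantic's model_copy(update=...) keeps the concrete config subclass (and its computed scheduler field) instead of rebuilding a plain LRSchedulerConfig. A minimal sketch of the idiom, assuming last_epoch remains a field of the base config as the validator in this diff suggests:

    from clinicadl.optim.lr_scheduler.config import StepLRConfig

    config = StepLRConfig(step_size=10)
    updated = config.model_copy(update={"last_epoch": 5})
    assert type(updated) is StepLRConfig                      # subclass preserved, unlike the old code path
    assert updated.step_size == 10 and updated.last_epoch == 5
    # note: model_copy does not re-validate the updated values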
3 changes: 2 additions & 1 deletion clinicadl/optim/optimizer/__init__.py
@@ -1,2 +1,3 @@
from .config import ImplementedOptimizer, OptimizerConfig
from .config import create_optimizer_config
from .enum import ImplementedOptimizer
from .factory import get_optimizer
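
The optimizer package gets the same treatment as the scheduler package. Assuming create_optimizer_config mirrors create_lr_scheduler_config, and that the optimizer name below is one of the implemented ones (both are assumptions here, not confirmed by this diff), usage would look like:

    from clinicadl.optim.optimizer import create_optimizer_config

    optimizer_config_class = create_optimizer_config("Adam")   # hypothetical optimizer name
    optimizer_config = optimizer_config_class()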