Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Review optim module #656

Merged
merged 10 commits into from
Sep 24, 2024
Merged
3 changes: 3 additions & 0 deletions clinicadl/optim/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from .config import OptimizationConfig
from .early_stopping import EarlyStopping
from .lr_scheduler import create_lr_scheduler_config, get_lr_scheduler
from .optimizer import create_optimizer_config, get_optimizer
3 changes: 2 additions & 1 deletion clinicadl/optim/lr_scheduler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .config import ImplementedLRScheduler, LRSchedulerConfig
from .config import create_lr_scheduler_config
from .enum import ImplementedLRScheduler
from .factory import get_lr_scheduler
192 changes: 119 additions & 73 deletions clinicadl/optim/lr_scheduler/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from __future__ import annotations

from enum import Enum
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Type, Union

from pydantic import (
BaseModel,
Expand All @@ -10,71 +7,43 @@
NonNegativeInt,
PositiveFloat,
PositiveInt,
computed_field,
field_validator,
model_validator,
)

from clinicadl.utils.factories import DefaultFromLibrary

from .enum import ImplementedLRScheduler, Mode, ThresholdMode

class ImplementedLRScheduler(str, Enum):
"""Implemented LR schedulers in ClinicaDL."""

CONSTANT = "ConstantLR"
LINEAR = "LinearLR"
STEP = "StepLR"
MULTI_STEP = "MultiStepLR"
PLATEAU = "ReduceLROnPlateau"

@classmethod
def _missing_(cls, value):
raise ValueError(
f"{value} is not implemented. Implemented LR schedulers are: "
+ ", ".join([repr(m.value) for m in cls])
)


class Mode(str, Enum):
"""Supported mode for ReduceLROnPlateau."""

MIN = "min"
MAX = "max"


class ThresholdMode(str, Enum):
"""Supported threshold mode for ReduceLROnPlateau."""

ABS = "abs"
REL = "rel"
__all__ = [
"LRSchedulerConfig",
"ConstantLRConfig",
"LinearLRConfig",
"StepLRConfig",
"MultiStepLRConfig",
"ReduceLROnPlateauConfig",
"create_lr_scheduler_config",
]


class LRSchedulerConfig(BaseModel):
"""Config class to configure the optimizer."""
"""Base config class for the LR scheduler."""

scheduler: Optional[ImplementedLRScheduler] = None
step_size: Optional[PositiveInt] = None
gamma: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
milestones: Optional[List[PositiveInt]] = None
factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
start_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
end_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
total_iters: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES
last_epoch: Union[int, DefaultFromLibrary] = DefaultFromLibrary.YES

mode: Union[Mode, DefaultFromLibrary] = DefaultFromLibrary.YES
patience: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES
threshold: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
threshold_mode: Union[ThresholdMode, DefaultFromLibrary] = DefaultFromLibrary.YES
cooldown: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES
min_lr: Union[
NonNegativeFloat, Dict[str, PositiveFloat], DefaultFromLibrary
] = DefaultFromLibrary.YES
eps: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
# pydantic config
model_config = ConfigDict(
validate_assignment=True, use_enum_values=True, validate_default=True
)

@computed_field
@property
def scheduler(self) -> Optional[ImplementedLRScheduler]:
"""The name of the scheduler."""
return None

@field_validator("last_epoch")
@classmethod
def validator_last_epoch(cls, v):
Expand All @@ -84,31 +53,108 @@ def validator_last_epoch(cls, v):
), f"last_epoch must be -1 or a non-negative int but it has been set to {v}."
return v

@field_validator("milestones")

class ConstantLRConfig(LRSchedulerConfig):
"""Config class for ConstantLR scheduler."""

@computed_field
@property
def scheduler(self) -> ImplementedLRScheduler:
"""The name of the scheduler."""
return ImplementedLRScheduler.CONSTANT


class LinearLRConfig(LRSchedulerConfig):
"""Config class for LinearLR scheduler."""

start_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
end_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
@property
def scheduler(self) -> ImplementedLRScheduler:
"""The name of the scheduler."""
return ImplementedLRScheduler.LINEAR


class StepLRConfig(LRSchedulerConfig):
"""Config class for StepLR scheduler."""

step_size: PositiveInt

@computed_field
@property
def scheduler(self) -> ImplementedLRScheduler:
"""The name of the scheduler."""
return ImplementedLRScheduler.STEP


class MultiStepLRConfig(LRSchedulerConfig):
"""Config class for MultiStepLR scheduler."""

milestones: List[PositiveInt]

@computed_field
@property
def scheduler(self) -> ImplementedLRScheduler:
"""The name of the scheduler."""
return ImplementedLRScheduler.MULTI_STEP

@field_validator("milestones", mode="after")
@classmethod
def validator_milestones(cls, v):
import numpy as np

if v is not None:
assert len(np.unique(v)) == len(
v
), "Epoch(s) in milestones should be unique."
return sorted(v)
return v
assert len(np.unique(v)) == len(v), "Epoch(s) in milestones should be unique."
return sorted(v)


class ReduceLROnPlateauConfig(LRSchedulerConfig):
"""Config class for ReduceLROnPlateau scheduler."""

mode: Union[Mode, DefaultFromLibrary] = DefaultFromLibrary.YES
patience: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES
threshold: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
threshold_mode: Union[ThresholdMode, DefaultFromLibrary] = DefaultFromLibrary.YES
cooldown: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES
min_lr: Union[
NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary
] = DefaultFromLibrary.YES
eps: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES

@model_validator(mode="after")
def check_mandatory_args(self) -> LRSchedulerConfig:
if (
self.scheduler == ImplementedLRScheduler.MULTI_STEP
and self.milestones is None
):
raise ValueError(
"""If you chose MultiStepLR as LR scheduler, you should pass milestones
(see PyTorch documentation for more details)."""
)
elif self.scheduler == ImplementedLRScheduler.STEP and self.step_size is None:
raise ValueError(
"""If you chose StepLR as LR scheduler, you should pass a step_size
(see PyTorch documentation for more details)."""
)
return self
@property
def scheduler(self) -> ImplementedLRScheduler:
"""The name of the scheduler."""
return ImplementedLRScheduler.PLATEAU


def create_lr_scheduler_config(
scheduler: Optional[Union[str, ImplementedLRScheduler]],
) -> Type[LRSchedulerConfig]:
"""
A factory function to create a config class suited for the LR scheduler.

Parameters
----------
scheduler : Optional[Union[str, ImplementedLRScheduler]]
The name of the LR scheduler.
Can be None if no LR scheduler will be used.

Returns
-------
Type[LRSchedulerConfig]
The config class.

Raises
------
ValueError
If `scheduler` is not supported.
"""
if scheduler is None:
return LRSchedulerConfig

scheduler = ImplementedLRScheduler(scheduler)
config_name = "".join([scheduler, "Config"])
config = globals()[config_name]

return config
32 changes: 32 additions & 0 deletions clinicadl/optim/lr_scheduler/enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from enum import Enum


class ImplementedLRScheduler(str, Enum):
"""Implemented LR schedulers in ClinicaDL."""

CONSTANT = "ConstantLR"
LINEAR = "LinearLR"
STEP = "StepLR"
MULTI_STEP = "MultiStepLR"
PLATEAU = "ReduceLROnPlateau"

@classmethod
def _missing_(cls, value):
raise ValueError(
f"{value} is not implemented. Implemented LR schedulers are: "
+ ", ".join([repr(m.value) for m in cls])
)


class Mode(str, Enum):
"""Supported mode for ReduceLROnPlateau."""

MIN = "min"
MAX = "max"


class ThresholdMode(str, Enum):
"""Supported threshold mode for ReduceLROnPlateau."""

ABS = "abs"
REL = "rel"
2 changes: 1 addition & 1 deletion clinicadl/optim/lr_scheduler/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ def get_lr_scheduler(
) # ELSE must be the last group
scheduler = scheduler_class(optimizer, **config_dict_)

updated_config = LRSchedulerConfig(scheduler=config.scheduler, **config_dict)
updated_config = config.model_copy(update=config_dict)

return scheduler, updated_config
3 changes: 2 additions & 1 deletion clinicadl/optim/optimizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .config import ImplementedOptimizer, OptimizerConfig
from .config import create_optimizer_config
from .enum import ImplementedOptimizer
from .factory import get_optimizer
Loading
Loading