From 381b005aca3955cb6aa2dfeea7d4eaa943627a7e Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Thu, 19 Sep 2024 17:45:16 +0200 Subject: [PATCH 1/9] create enum for optimizer --- clinicadl/optim/optimizer/enum.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 clinicadl/optim/optimizer/enum.py diff --git a/clinicadl/optim/optimizer/enum.py b/clinicadl/optim/optimizer/enum.py new file mode 100644 index 000000000..a397dbe85 --- /dev/null +++ b/clinicadl/optim/optimizer/enum.py @@ -0,0 +1,18 @@ +from enum import Enum + + +class ImplementedOptimizer(str, Enum): + """Implemented optimizers in ClinicaDL.""" + + ADADELTA = "Adadelta" + ADAGRAD = "Adagrad" + ADAM = "Adam" + RMS_PROP = "RMSprop" + SGD = "SGD" + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not implemented. Implemented optimizers are: " + + ", ".join([repr(m.value) for m in cls]) + ) From ff14c7a51eb4a6d29f09a8300fc4ce310e86afac Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Thu, 19 Sep 2024 17:45:48 +0200 Subject: [PATCH 2/9] create a config class for all optimizers --- clinicadl/optim/optimizer/config.py | 200 +++++++++++++++++++++------- 1 file changed, 150 insertions(+), 50 deletions(-) diff --git a/clinicadl/optim/optimizer/config.py b/clinicadl/optim/optimizer/config.py index 46aa5958c..b0a55f034 100644 --- a/clinicadl/optim/optimizer/config.py +++ b/clinicadl/optim/optimizer/config.py @@ -1,38 +1,32 @@ -from enum import Enum -from typing import Dict, List, Optional, Tuple, Union +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple, Type, Union from pydantic import ( BaseModel, ConfigDict, NonNegativeFloat, PositiveFloat, + computed_field, field_validator, ) from clinicadl.utils.factories import DefaultFromLibrary +from .enum import ImplementedOptimizer -class ImplementedOptimizer(str, Enum): - """Implemented optimizers in ClinicaDL.""" +__all__ = [ + "OptimizerConfig", + "AdadeltaConfig", + "AdagradConfig", + "AdamConfig", + "RMSpropConfig", + "SGDConfig", + "create_optimizer_config", +] - ADADELTA = "Adadelta" - ADAGRAD = "Adagrad" - ADAM = "Adam" - RMS_PROP = "RMSprop" - SGD = "SGD" - @classmethod - def _missing_(cls, value): - raise ValueError( - f"{value} is not implemented. 
Implemented optimizers are: " - + ", ".join([repr(m.value) for m in cls]) - ) - - -class OptimizerConfig(BaseModel): - """Config class to configure the optimizer.""" - - optimizer: ImplementedOptimizer = ImplementedOptimizer.ADAM +class OptimizerConfig(BaseModel, ABC): + """Base config class for the optimizer.""" lr: Union[ PositiveFloat, Dict[str, PositiveFloat], DefaultFromLibrary @@ -40,36 +34,9 @@ class OptimizerConfig(BaseModel): weight_decay: Union[ NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary ] = DefaultFromLibrary.YES - betas: Union[ - Tuple[NonNegativeFloat, NonNegativeFloat], - Dict[str, Tuple[NonNegativeFloat, NonNegativeFloat]], - DefaultFromLibrary, - ] = DefaultFromLibrary.YES - alpha: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - momentum: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - rho: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - lr_decay: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES eps: Union[ NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary ] = DefaultFromLibrary.YES - dampening: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - initial_accumulator_value: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - - centered: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES - nesterov: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES - amsgrad: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES foreach: Union[ Optional[bool], Dict[str, Optional[bool]], DefaultFromLibrary ] = DefaultFromLibrary.YES @@ -81,14 +48,19 @@ class OptimizerConfig(BaseModel): bool, Dict[str, bool], DefaultFromLibrary ] = DefaultFromLibrary.YES fused: Union[ - Optional[bool], Dict[str, bool], DefaultFromLibrary + Optional[bool], Dict[str, Optional[bool]], DefaultFromLibrary ] = DefaultFromLibrary.YES # pydantic config model_config = ConfigDict( validate_assignment=True, use_enum_values=True, validate_default=True ) - @field_validator("betas", "rho", "alpha", "dampening") + @computed_field + @property + @abstractmethod + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + @classmethod def validator_proba(cls, v, ctx): name = ctx.field_name @@ -128,3 +100,131 @@ def get_all_groups(self) -> List[str]: groups.update(set(value.keys())) return list(groups) + + +class AdadeltaConfig(OptimizerConfig): + """Config class for Adadelta optimizer.""" + + rho: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.ADADELTA + + @field_validator("rho") + def validator_rho(cls, v, ctx): + return cls.validator_proba(v, ctx) + + +class AdagradConfig(OptimizerConfig): + """Config class for Adagrad optimizer.""" + + lr_decay: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + initial_accumulator_value: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of 
the optimizer.""" + return ImplementedOptimizer.ADAGRAD + + +class AdamConfig(OptimizerConfig): + """Config class for Adam optimizer.""" + + betas: Union[ + Tuple[NonNegativeFloat, NonNegativeFloat], + Dict[str, Tuple[NonNegativeFloat, NonNegativeFloat]], + DefaultFromLibrary, + ] = DefaultFromLibrary.YES + amsgrad: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.ADAM + + @field_validator("betas") + def validator_betas(cls, v, ctx): + return cls.validator_proba(v, ctx) + + +class RMSpropConfig(OptimizerConfig): + """Config class for RMSprop optimizer.""" + + alpha: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + momentum: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + centered: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.RMS_PROP + + @field_validator("alpha") + def validator_alpha(cls, v, ctx): + return cls.validator_proba(v, ctx) + + +class SGDConfig(OptimizerConfig): + """Config class for SGD optimizer.""" + + momentum: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + dampening: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + nesterov: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.SGD + + @field_validator("dampening") + def validator_dampening(cls, v, ctx): + return cls.validator_proba(v, ctx) + + +def create_optimizer_config( + optimizer: Union[str, ImplementedOptimizer], +) -> Type[OptimizerConfig]: + """ + A factory function to create a config class suited for the optimizer. + + Parameters + ---------- + optimizer : Union[str, ImplementedOptimizer] + The name of the optimizer. + + Returns + ------- + Type[OptimizerConfig] + The config class. + + Raises + ------ + ValueError + If `optimizer` is not supported. 
+ """ + optimizer = ImplementedOptimizer(optimizer) + config_name = "".join([optimizer, "Config"]) + config = globals()[config_name] + + return config From cf7b404eb5e2982926a9979404a218d6f2c53968 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Thu, 19 Sep 2024 17:46:09 +0200 Subject: [PATCH 3/9] fix optimizer factory --- clinicadl/optim/optimizer/__init__.py | 3 ++- clinicadl/optim/optimizer/factory.py | 19 +++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/clinicadl/optim/optimizer/__init__.py b/clinicadl/optim/optimizer/__init__.py index 2c9cce3ba..504c60999 100644 --- a/clinicadl/optim/optimizer/__init__.py +++ b/clinicadl/optim/optimizer/__init__.py @@ -1,2 +1,3 @@ -from .config import ImplementedOptimizer, OptimizerConfig +from .config import create_optimizer_config +from .enum import ImplementedOptimizer from .factory import get_optimizer diff --git a/clinicadl/optim/optimizer/factory.py b/clinicadl/optim/optimizer/factory.py index 15ffd8a52..47e2959d1 100644 --- a/clinicadl/optim/optimizer/factory.py +++ b/clinicadl/optim/optimizer/factory.py @@ -31,6 +31,11 @@ def get_optimizer( The updated config class: the arguments set to default will be updated with their effective values (the default values from the library). Useful for reproducibility. + + Raises + ------ + AttributeError + If a parameter group mentioned in the config class cannot be found in the network. """ optimizer_class = getattr(optim, config.optimizer) expected_args, default_args = get_args_and_defaults(optimizer_class.__init__) @@ -65,7 +70,7 @@ def get_optimizer( list_args_groups.append({"params": other_params}) optimizer = optimizer_class(list_args_groups, **args_global) - updated_config = OptimizerConfig(optimizer=config.optimizer, **default_args) + updated_config = config.model_copy(update=default_args) return optimizer, updated_config @@ -152,6 +157,11 @@ def _get_params_in_group( List[str] The name of all the parameters in the group. + Raises + ------ + AttributeError + If `group` cannot be found in the network. + Examples -------- >>> net = nn.Sequential( @@ -168,7 +178,12 @@ def _get_params_in_group( """ group_hierarchy = group.split(".") for name in group_hierarchy: - network = getattr(network, name) + try: + network = getattr(network, name) + except AttributeError as exc: + raise AttributeError( + f"There is no such group as {group} in the network." 
+ ) from exc try: params = network.parameters() From 67c6efec63171d783941dc188b2f160cf24fabc8 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Thu, 19 Sep 2024 17:46:20 +0200 Subject: [PATCH 4/9] update optimizer unittests --- .../unittests/optim/optimizer/test_config.py | 141 ++++++++++++++---- .../unittests/optim/optimizer/test_factory.py | 53 +++---- 2 files changed, 136 insertions(+), 58 deletions(-) diff --git a/tests/unittests/optim/optimizer/test_config.py b/tests/unittests/optim/optimizer/test_config.py index e162fe58d..dd26d877e 100644 --- a/tests/unittests/optim/optimizer/test_config.py +++ b/tests/unittests/optim/optimizer/test_config.py @@ -1,34 +1,119 @@ import pytest +import torch from pydantic import ValidationError -from clinicadl.optim.optimizer import OptimizerConfig - - -def test_OptimizerConfig(): - config = OptimizerConfig( - optimizer="SGD", - lr=1e-3, - weight_decay={"param_0": 1e-3, "param_1": 1e-2}, - momentum={"param_1": 1e-1}, - lr_decay=1e-4, - ) - assert config.optimizer == "SGD" - assert config.lr == 1e-3 - assert config.weight_decay == {"param_0": 1e-3, "param_1": 1e-2} - assert config.momentum == {"param_1": 1e-1} - assert config.lr_decay == 1e-4 - assert config.alpha == "DefaultFromLibrary" - assert sorted(config.get_all_groups()) == ["param_0", "param_1"] +from clinicadl.optim.optimizer.config import ( + AdadeltaConfig, + AdagradConfig, + AdamConfig, + RMSpropConfig, + SGDConfig, + create_optimizer_config, +) +BAD_INPUTS = { + "lr": 0, + "rho": 1.1, + "eps": -0.1, + "weight_decay": -0.1, + "lr_decay": -0.1, + "initial_accumulator_value": -0.1, + "betas": (0.9, 1.0), + "alpha": 1.1, + "momentum": -0.1, + "dampening": 0.1, +} + +GOOD_INPUTS_1 = { + "lr": 0.1, + "rho": 0, + "eps": 0, + "weight_decay": 0, + "foreach": None, + "capturable": False, + "maximize": True, + "differentiable": False, + "fused": None, + "lr_decay": 0, + "initial_accumulator_value": 0, + "betas": (0.0, 0.0), + "amsgrad": True, + "alpha": 0.0, + "momentum": 0, + "centered": True, + "dampening": 0, + "nesterov": True, +} + +GOOD_INPUTS_2 = { + "foreach": True, + "fused": False, +} + + +@pytest.mark.parametrize( + "config", + [ + AdadeltaConfig, + AdagradConfig, + AdamConfig, + RMSpropConfig, + SGDConfig, + ], +) +def test_validation_fail(config): + fields = config.model_fields + inputs = {key: value for key, value in BAD_INPUTS.items() if key in fields} with pytest.raises(ValidationError): - OptimizerConfig(betas={"params_0": (0.9, 1.01), "params_1": (0.9, 0.99)}) - with pytest.raises(ValidationError): - OptimizerConfig(betas=0.9) - with pytest.raises(ValidationError): - OptimizerConfig(rho=1.01) - with pytest.raises(ValidationError): - OptimizerConfig(alpha=1.01) + config(**inputs) + + # test dict inputs + inputs = {key: {"group_1": value} for key, value in inputs.items()} with pytest.raises(ValidationError): - OptimizerConfig(dampening={"params_0": 0.1, "params_1": 2}) - with pytest.raises(ValueError): - OptimizerConfig(optimizer="abc") + config(**inputs) + + +@pytest.mark.parametrize( + "config", + [ + AdadeltaConfig, + AdagradConfig, + AdamConfig, + RMSpropConfig, + SGDConfig, + ], +) +@pytest.mark.parametrize( + "good_inputs", + [ + GOOD_INPUTS_1, + GOOD_INPUTS_2, + ], +) +def test_validation_pass(config, good_inputs): + fields = config.model_fields + inputs = {key: value for key, value in good_inputs.items() if key in fields} + c = config(**inputs) + for arg, value in inputs.items(): + assert getattr(c, arg) == value + + # test dict inputs + inputs = {key: {"group_1": value} for 
key, value in inputs.items()} + c = config(**inputs) + for arg, value in inputs.items(): + assert getattr(c, arg) == value + + +@pytest.mark.parametrize( + "name,expected_class", + [ + ("Adadelta", AdadeltaConfig), + ("Adagrad", AdagradConfig), + ("Adam", AdamConfig), + ("RMSprop", RMSpropConfig), + ("SGD", SGDConfig), + ], +) +def test_create_optimizer_config(name, expected_class): + config = create_optimizer_config(name) + assert config == expected_class diff --git a/tests/unittests/optim/optimizer/test_factory.py b/tests/unittests/optim/optimizer/test_factory.py index 6387b7f5a..c958e27d6 100644 --- a/tests/unittests/optim/optimizer/test_factory.py +++ b/tests/unittests/optim/optimizer/test_factory.py @@ -1,12 +1,25 @@ from collections import OrderedDict import pytest +import torch import torch.nn as nn +from torch.optim import Adagrad + +from clinicadl.optim.optimizer import ( + ImplementedOptimizer, + create_optimizer_config, + get_optimizer, +) +from clinicadl.optim.optimizer.factory import ( + _get_params_in_group, + _get_params_not_in_group, + _regroup_args, +) @pytest.fixture def network(): - network = nn.Sequential( + net = nn.Sequential( OrderedDict( [ ("conv1", nn.Conv2d(1, 1, kernel_size=3)), @@ -14,31 +27,22 @@ def network(): ] ) ) - network.add_module( + net.add_module( "final", nn.Sequential( OrderedDict([("dense2", nn.Linear(10, 5)), ("dense3", nn.Linear(5, 3))]) ), ) - return network + return net def test_get_optimizer(network): - from torch.optim import Adagrad - - from clinicadl.optim.optimizer import ( - ImplementedOptimizer, - OptimizerConfig, - get_optimizer, - ) - - for optimizer in [e.value for e in ImplementedOptimizer]: - config = OptimizerConfig(optimizer=optimizer) + for optimizer in ImplementedOptimizer: + config = create_optimizer_config(optimizer=optimizer)() optimizer, _ = get_optimizer(network, config) assert len(optimizer.param_groups) == 1 - config = OptimizerConfig( - optimizer="Adagrad", + config = create_optimizer_config("Adagrad")( lr=1e-5, weight_decay={"final.dense3.weight": 1.0, "dense1": 0.1}, lr_decay={"dense1": 10, "ELSE": 100}, @@ -82,17 +86,16 @@ def test_get_optimizer(network): assert not updated_config.maximize assert not updated_config.differentiable - # special cases - config = OptimizerConfig( - optimizer="Adagrad", + # special cases 1 + config = create_optimizer_config("Adagrad")( lr_decay={"ELSE": 100}, ) optimizer, _ = get_optimizer(network, config) assert len(optimizer.param_groups) == 1 assert optimizer.param_groups[0]["lr_decay"] == 100 - config = OptimizerConfig( - optimizer="Adagrad", + # special cases 2 + config = create_optimizer_config("Adagrad")( lr_decay={"conv1": 100, "dense1": 10, "final": 1}, ) optimizer, _ = get_optimizer(network, config) @@ -100,8 +103,6 @@ def test_get_optimizer(network): def test_regroup_args(): - from clinicadl.optim.optimizer.factory import _regroup_args - args = { "weight_decay": {"params_0": 0.0, "params_1": 1.0}, "alpha": {"params_1": 0.5, "ELSE": 0.1}, @@ -126,10 +127,6 @@ def test_regroup_args(): def test_get_params_in_block(network): - import torch - - from clinicadl.optim.optimizer.factory import _get_params_in_group - generator, list_layers = _get_params_in_group(network, "dense1") assert next(iter(generator)).shape == torch.Size((10, 10)) assert next(iter(generator)).shape == torch.Size((10,)) @@ -158,10 +155,6 @@ def test_get_params_in_block(network): def test_find_params_not_in_group(network): - import torch - - from clinicadl.optim.optimizer.factory import _get_params_not_in_group - 
params = _get_params_not_in_group( network, [ From 3873f482b92e986812544f0a7afdf02124f304b6 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Fri, 20 Sep 2024 13:29:53 +0200 Subject: [PATCH 5/9] enum for lr scheduler --- clinicadl/optim/lr_scheduler/enum.py | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 clinicadl/optim/lr_scheduler/enum.py diff --git a/clinicadl/optim/lr_scheduler/enum.py b/clinicadl/optim/lr_scheduler/enum.py new file mode 100644 index 000000000..a70bb1801 --- /dev/null +++ b/clinicadl/optim/lr_scheduler/enum.py @@ -0,0 +1,32 @@ +from enum import Enum + + +class ImplementedLRScheduler(str, Enum): + """Implemented LR schedulers in ClinicaDL.""" + + CONSTANT = "ConstantLR" + LINEAR = "LinearLR" + STEP = "StepLR" + MULTI_STEP = "MultiStepLR" + PLATEAU = "ReduceLROnPlateau" + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not implemented. Implemented LR schedulers are: " + + ", ".join([repr(m.value) for m in cls]) + ) + + +class Mode(str, Enum): + """Supported mode for ReduceLROnPlateau.""" + + MIN = "min" + MAX = "max" + + +class ThresholdMode(str, Enum): + """Supported threshold mode for ReduceLROnPlateau.""" + + ABS = "abs" + REL = "rel" From 2245b6cca13a72bbb5b7b0320d4311860d7c7496 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Fri, 20 Sep 2024 13:30:21 +0200 Subject: [PATCH 6/9] config for all schedulers --- clinicadl/optim/lr_scheduler/config.py | 192 +++++++++++++++---------- 1 file changed, 119 insertions(+), 73 deletions(-) diff --git a/clinicadl/optim/lr_scheduler/config.py b/clinicadl/optim/lr_scheduler/config.py index 93fb3d9e1..073f92db7 100644 --- a/clinicadl/optim/lr_scheduler/config.py +++ b/clinicadl/optim/lr_scheduler/config.py @@ -1,7 +1,4 @@ -from __future__ import annotations - -from enum import Enum -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Type, Union from pydantic import ( BaseModel, @@ -10,71 +7,43 @@ NonNegativeInt, PositiveFloat, PositiveInt, + computed_field, field_validator, - model_validator, ) from clinicadl.utils.factories import DefaultFromLibrary +from .enum import ImplementedLRScheduler, Mode, ThresholdMode -class ImplementedLRScheduler(str, Enum): - """Implemented LR schedulers in ClinicaDL.""" - - CONSTANT = "ConstantLR" - LINEAR = "LinearLR" - STEP = "StepLR" - MULTI_STEP = "MultiStepLR" - PLATEAU = "ReduceLROnPlateau" - - @classmethod - def _missing_(cls, value): - raise ValueError( - f"{value} is not implemented. 
Implemented LR schedulers are: " - + ", ".join([repr(m.value) for m in cls]) - ) - - -class Mode(str, Enum): - """Supported mode for ReduceLROnPlateau.""" - - MIN = "min" - MAX = "max" - - -class ThresholdMode(str, Enum): - """Supported threshold mode for ReduceLROnPlateau.""" - - ABS = "abs" - REL = "rel" +__all__ = [ + "LRSchedulerConfig", + "ConstantLRConfig", + "LinearLRConfig", + "StepLRConfig", + "MultiStepLRConfig", + "ReduceLROnPlateauConfig", + "create_lr_scheduler_config", +] class LRSchedulerConfig(BaseModel): - """Config class to configure the optimizer.""" + """Base config class for the LR scheduler.""" - scheduler: Optional[ImplementedLRScheduler] = None - step_size: Optional[PositiveInt] = None gamma: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - milestones: Optional[List[PositiveInt]] = None factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - start_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - end_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES total_iters: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES last_epoch: Union[int, DefaultFromLibrary] = DefaultFromLibrary.YES - - mode: Union[Mode, DefaultFromLibrary] = DefaultFromLibrary.YES - patience: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - threshold: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - threshold_mode: Union[ThresholdMode, DefaultFromLibrary] = DefaultFromLibrary.YES - cooldown: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES - min_lr: Union[ - NonNegativeFloat, Dict[str, PositiveFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - eps: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES # pydantic config model_config = ConfigDict( validate_assignment=True, use_enum_values=True, validate_default=True ) + @computed_field + @property + def scheduler(self) -> Optional[ImplementedLRScheduler]: + """The name of the scheduler.""" + return None + @field_validator("last_epoch") @classmethod def validator_last_epoch(cls, v): @@ -84,31 +53,108 @@ def validator_last_epoch(cls, v): ), f"last_epoch must be -1 or a non-negative int but it has been set to {v}." 
return v - @field_validator("milestones") + +class ConstantLRConfig(LRSchedulerConfig): + """Config class for ConstantLR scheduler.""" + + @computed_field + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.CONSTANT + + +class LinearLRConfig(LRSchedulerConfig): + """Config class for LinearLR scheduler.""" + + start_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES + end_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.LINEAR + + +class StepLRConfig(LRSchedulerConfig): + """Config class for StepLR scheduler.""" + + step_size: PositiveInt + + @computed_field + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.STEP + + +class MultiStepLRConfig(LRSchedulerConfig): + """Config class for MultiStepLR scheduler.""" + + milestones: List[PositiveInt] + + @computed_field + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.MULTI_STEP + + @field_validator("milestones", mode="after") @classmethod def validator_milestones(cls, v): import numpy as np - if v is not None: - assert len(np.unique(v)) == len( - v - ), "Epoch(s) in milestones should be unique." - return sorted(v) - return v + assert len(np.unique(v)) == len(v), "Epoch(s) in milestones should be unique." + return sorted(v) + + +class ReduceLROnPlateauConfig(LRSchedulerConfig): + """Config class for ReduceLROnPlateau scheduler.""" + + mode: Union[Mode, DefaultFromLibrary] = DefaultFromLibrary.YES + patience: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + threshold: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES + threshold_mode: Union[ThresholdMode, DefaultFromLibrary] = DefaultFromLibrary.YES + cooldown: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES + min_lr: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + eps: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - @model_validator(mode="after") - def check_mandatory_args(self) -> LRSchedulerConfig: - if ( - self.scheduler == ImplementedLRScheduler.MULTI_STEP - and self.milestones is None - ): - raise ValueError( - """If you chose MultiStepLR as LR scheduler, you should pass milestones - (see PyTorch documentation for more details).""" - ) - elif self.scheduler == ImplementedLRScheduler.STEP and self.step_size is None: - raise ValueError( - """If you chose StepLR as LR scheduler, you should pass a step_size - (see PyTorch documentation for more details).""" - ) - return self + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.PLATEAU + + +def create_lr_scheduler_config( + scheduler: Optional[Union[str, ImplementedLRScheduler]], +) -> Type[LRSchedulerConfig]: + """ + A factory function to create a config class suited for the LR scheduler. + + Parameters + ---------- + scheduler : Optional[Union[str, ImplementedLRScheduler]] + The name of the LR scheduler. + Can be None if no LR scheduler will be used. + + Returns + ------- + Type[LRSchedulerConfig] + The config class. + + Raises + ------ + ValueError + If `scheduler` is not supported. 
+ """ + if scheduler is None: + return LRSchedulerConfig + + scheduler = ImplementedLRScheduler(scheduler) + config_name = "".join([scheduler, "Config"]) + config = globals()[config_name] + + return config From 37ec4a2dc5171ee420fe9cd4d87149d06233a57c Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Fri, 20 Sep 2024 13:30:53 +0200 Subject: [PATCH 7/9] fix factory lr schedulers --- clinicadl/optim/lr_scheduler/__init__.py | 3 ++- clinicadl/optim/lr_scheduler/factory.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/clinicadl/optim/lr_scheduler/__init__.py b/clinicadl/optim/lr_scheduler/__init__.py index 9c5fabd2c..c26899e69 100644 --- a/clinicadl/optim/lr_scheduler/__init__.py +++ b/clinicadl/optim/lr_scheduler/__init__.py @@ -1,2 +1,3 @@ -from .config import ImplementedLRScheduler, LRSchedulerConfig +from .config import create_lr_scheduler_config +from .enum import ImplementedLRScheduler from .factory import get_lr_scheduler diff --git a/clinicadl/optim/lr_scheduler/factory.py b/clinicadl/optim/lr_scheduler/factory.py index f07b07f32..9ffd73007 100644 --- a/clinicadl/optim/lr_scheduler/factory.py +++ b/clinicadl/optim/lr_scheduler/factory.py @@ -51,6 +51,6 @@ def get_lr_scheduler( ) # ELSE must be the last group scheduler = scheduler_class(optimizer, **config_dict_) - updated_config = LRSchedulerConfig(scheduler=config.scheduler, **config_dict) + updated_config = config.model_copy(update=config_dict) return scheduler, updated_config From 926aecb739a07d9b3122900b5bacaf2a65cc102b Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Fri, 20 Sep 2024 13:31:10 +0200 Subject: [PATCH 8/9] unittests --- clinicadl/optim/__init__.py | 3 + .../optim/lr_scheduler/test_config.py | 139 ++++++++++++++---- .../optim/lr_scheduler/test_factory.py | 20 +-- .../unittests/optim/optimizer/test_config.py | 1 - 4 files changed, 119 insertions(+), 44 deletions(-) diff --git a/clinicadl/optim/__init__.py b/clinicadl/optim/__init__.py index 6715835a5..185b1c418 100644 --- a/clinicadl/optim/__init__.py +++ b/clinicadl/optim/__init__.py @@ -1 +1,4 @@ from .config import OptimizationConfig +from .early_stopping import EarlyStopping +from .lr_scheduler import create_lr_scheduler_config, get_lr_scheduler +from .optimizer import create_optimizer_config, get_optimizer diff --git a/tests/unittests/optim/lr_scheduler/test_config.py b/tests/unittests/optim/lr_scheduler/test_config.py index 270ccc27c..dbf96ccc8 100644 --- a/tests/unittests/optim/lr_scheduler/test_config.py +++ b/tests/unittests/optim/lr_scheduler/test_config.py @@ -1,37 +1,114 @@ import pytest from pydantic import ValidationError -from clinicadl.optim.lr_scheduler import LRSchedulerConfig - - -def test_LRSchedulerConfig(): - config = LRSchedulerConfig( - scheduler="ReduceLROnPlateau", - mode="max", - patience=1, - threshold_mode="rel", - milestones=[4, 3, 2], - min_lr={"param_0": 1e-1, "ELSE": 1e-2}, - ) - assert config.scheduler == "ReduceLROnPlateau" - assert config.mode == "max" - assert config.patience == 1 - assert config.threshold_mode == "rel" - assert config.milestones == [2, 3, 4] - assert config.min_lr == {"param_0": 1e-1, "ELSE": 1e-2} - assert config.threshold == "DefaultFromLibrary" +from clinicadl.optim.lr_scheduler.config import ( + ConstantLRConfig, + LinearLRConfig, + MultiStepLRConfig, + ReduceLROnPlateauConfig, + StepLRConfig, + create_lr_scheduler_config, +) +BAD_INPUTS = { + "milestones": [3, 2, 4], + "gamma": 0, + "last_epoch": -2, + "step_size": 0, + "factor": 0, + "total_iters": 0, + "start_factor": 0, + "end_factor": 
0, + "mode": "abc", + "patience": 0, + "threshold": -1, + "threshold_mode": "abc", + "cooldown": -1, + "eps": -0.1, + "min_lr": -0.1, +} + +GOOD_INPUTS = { + "milestones": [1, 4, 5], + "gamma": 0.1, + "last_epoch": -1, + "step_size": 1, + "factor": 0.1, + "total_iters": 1, + "start_factor": 0.1, + "end_factor": 0.2, + "mode": "min", + "patience": 1, + "threshold": 0, + "threshold_mode": "abs", + "cooldown": 0, + "eps": 0, + "min_lr": 0, +} + + +@pytest.mark.parametrize( + "config", + [ + ConstantLRConfig, + LinearLRConfig, + MultiStepLRConfig, + ReduceLROnPlateauConfig, + StepLRConfig, + ], +) +def test_validation_fail(config): + fields = config.model_fields + inputs = {key: value for key, value in BAD_INPUTS.items() if key in fields} with pytest.raises(ValidationError): - LRSchedulerConfig(last_epoch=-2) - with pytest.raises(ValueError): - LRSchedulerConfig(scheduler="abc") - with pytest.raises(ValueError): - LRSchedulerConfig(mode="abc") - with pytest.raises(ValueError): - LRSchedulerConfig(threshold_mode="abc") - with pytest.raises(ValidationError): - LRSchedulerConfig(milestones=[10, 10]) - with pytest.raises(ValidationError): - LRSchedulerConfig(scheduler="MultiStepLR") + config(**inputs) + + # test dict inputs for min_lr + if "min_lr" in inputs: + inputs["min_lr"] = {"group_1": inputs["min_lr"]} + with pytest.raises(ValidationError): + config(**inputs) + + +def test_validation_fail_special(): with pytest.raises(ValidationError): - LRSchedulerConfig(scheduler="StepLR") + MultiStepLRConfig(milestones=[0, 1]) + + +@pytest.mark.parametrize( + "config", + [ + ConstantLRConfig, + LinearLRConfig, + MultiStepLRConfig, + ReduceLROnPlateauConfig, + StepLRConfig, + ], +) +def test_validation_pass(config): + fields = config.model_fields + inputs = {key: value for key, value in GOOD_INPUTS.items() if key in fields} + c = config(**inputs) + for arg, value in inputs.items(): + assert getattr(c, arg) == value + + # test dict inputs + if "min_lr" in inputs: + inputs["min_lr"] = {"group_1": inputs["min_lr"]} + c = config(**inputs) + assert getattr(c, "min_lr") == inputs["min_lr"] + + +@pytest.mark.parametrize( + "name,expected_class", + [ + ("ConstantLR", ConstantLRConfig), + ("LinearLR", LinearLRConfig), + ("MultiStepLR", MultiStepLRConfig), + ("ReduceLROnPlateau", ReduceLROnPlateauConfig), + ("StepLR", StepLRConfig), + ], +) +def test_create_optimizer_config(name, expected_class): + config = create_lr_scheduler_config(name) + assert config == expected_class diff --git a/tests/unittests/optim/lr_scheduler/test_factory.py b/tests/unittests/optim/lr_scheduler/test_factory.py index c59759fcd..771d844f7 100644 --- a/tests/unittests/optim/lr_scheduler/test_factory.py +++ b/tests/unittests/optim/lr_scheduler/test_factory.py @@ -6,7 +6,7 @@ from clinicadl.optim.lr_scheduler import ( ImplementedLRScheduler, - LRSchedulerConfig, + create_lr_scheduler_config, get_lr_scheduler, ) @@ -33,17 +33,12 @@ def test_get_lr_scheduler(): lr=10.0, ) - for scheduler in [e.value for e in ImplementedLRScheduler]: - if scheduler == "StepLR": - config = LRSchedulerConfig(scheduler=scheduler, step_size=1) - elif scheduler == "MultiStepLR": - config = LRSchedulerConfig(scheduler=scheduler, milestones=[1, 2, 3]) - else: - config = LRSchedulerConfig(scheduler=scheduler) - get_lr_scheduler(optimizer, config) + args = {"step_size": 1, "milestones": [1, 2]} + for scheduler in ImplementedLRScheduler: + config = create_lr_scheduler_config(scheduler=scheduler)(**args) + _ = get_lr_scheduler(optimizer, config) - config = 
LRSchedulerConfig( - scheduler="ReduceLROnPlateau", + config = create_lr_scheduler_config(scheduler="ReduceLROnPlateau")( mode="max", factor=0.123, threshold=1e-1, @@ -77,7 +72,8 @@ def test_get_lr_scheduler(): scheduler, updated_config = get_lr_scheduler(optimizer, config) assert scheduler.min_lrs == [0.1, 0.01, 1] - config = LRSchedulerConfig() + # no lr scheduler + config = create_lr_scheduler_config(None)() scheduler, updated_config = get_lr_scheduler(optimizer, config) assert isinstance(scheduler, LambdaLR) assert updated_config.scheduler is None diff --git a/tests/unittests/optim/optimizer/test_config.py b/tests/unittests/optim/optimizer/test_config.py index dd26d877e..bf1dbcd8f 100644 --- a/tests/unittests/optim/optimizer/test_config.py +++ b/tests/unittests/optim/optimizer/test_config.py @@ -1,5 +1,4 @@ import pytest -import torch from pydantic import ValidationError from clinicadl.optim.optimizer.config import ( From 00752ff451b007e5ec9c2df1b5bc761c75040dbd Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Mon, 23 Sep 2024 19:02:12 +0200 Subject: [PATCH 9/9] fix import issue --- clinicadl/optim/optimizer/factory.py | 3 ++- tests/unittests/optim/optimizer/test_factory.py | 5 +---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/clinicadl/optim/optimizer/factory.py b/clinicadl/optim/optimizer/factory.py index b7dee05de..3123781ed 100644 --- a/clinicadl/optim/optimizer/factory.py +++ b/clinicadl/optim/optimizer/factory.py @@ -1,5 +1,6 @@ -from typing import Any, Dict, Tuple +from typing import Any, Dict, Iterable, Iterator, List, Tuple +import torch import torch.nn as nn import torch.optim as optim diff --git a/tests/unittests/optim/optimizer/test_factory.py b/tests/unittests/optim/optimizer/test_factory.py index 935e1c1ea..47b44a00a 100644 --- a/tests/unittests/optim/optimizer/test_factory.py +++ b/tests/unittests/optim/optimizer/test_factory.py @@ -102,10 +102,7 @@ def test_get_optimizer(network): assert len(optimizer.param_groups) == 3 # special case : no ELSE mentioned - config = OptimizerConfig( - optimizer="Adagrad", - lr_decay={"conv1": 100}, - ) + config = create_optimizer_config("Adagrad")(lr_decay={"conv1": 100}) optimizer, _ = get_optimizer(network, config) assert len(optimizer.param_groups) == 2 assert optimizer.param_groups[0]["lr_decay"] == 100
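
Note: the snippet below is a minimal usage sketch of the two factories introduced in this series (`create_optimizer_config` / `get_optimizer` and `create_lr_scheduler_config` / `get_lr_scheduler`), pieced together from the unit tests above. It assumes the `clinicadl.optim` package layout shown in the diffs; the toy network and the parameter-group name "dense1" are illustrative only.

    from collections import OrderedDict

    import torch.nn as nn

    from clinicadl.optim.lr_scheduler import create_lr_scheduler_config, get_lr_scheduler
    from clinicadl.optim.optimizer import create_optimizer_config, get_optimizer

    # Toy network with named submodules, so hyperparameters can target parameter groups.
    network = nn.Sequential(
        OrderedDict(
            [
                ("conv1", nn.Conv2d(1, 1, kernel_size=3)),
                ("dense1", nn.Linear(10, 10)),
            ]
        )
    )

    # Pick the config class matching the optimizer name, then instantiate it.
    # Dict-valued fields map parameter-group names to values; "ELSE" covers the rest.
    optim_config = create_optimizer_config("Adagrad")(
        lr=1e-5,
        lr_decay={"dense1": 10, "ELSE": 100},
    )
    optimizer, optim_config = get_optimizer(network, optim_config)

    # Same pattern for the LR scheduler; scheduler-specific mandatory arguments
    # (e.g. step_size for StepLR) are now enforced by the dedicated config class.
    scheduler_config = create_lr_scheduler_config("StepLR")(step_size=1, gamma=0.5)
    scheduler, scheduler_config = get_lr_scheduler(optimizer, scheduler_config)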