diff --git a/clinicadl/losses/__init__.py b/clinicadl/losses/__init__.py index 5f0998372..ed3a07e0f 100644 --- a/clinicadl/losses/__init__.py +++ b/clinicadl/losses/__init__.py @@ -1,2 +1,3 @@ -from .config import ClassificationLoss, ImplementedLoss, LossConfig +from .config import create_loss_config +from .enum import ClassificationLoss, ImplementedLoss from .factory import get_loss_function diff --git a/clinicadl/losses/config.py b/clinicadl/losses/config.py index d6f5c45c7..0b6a698f4 100644 --- a/clinicadl/losses/config.py +++ b/clinicadl/losses/config.py @@ -1,87 +1,87 @@ -from enum import Enum -from typing import List, Optional, Union +from abc import ABC, abstractmethod +from typing import Any, List, Optional, Type, Union from pydantic import ( BaseModel, ConfigDict, NonNegativeFloat, PositiveFloat, + computed_field, field_validator, - model_validator, ) -from clinicadl.utils.enum import BaseEnum from clinicadl.utils.factories import DefaultFromLibrary +from .enum import ImplementedLoss, Order, Reduction -class ClassificationLoss(str, BaseEnum): - """Losses that can be used only for classification.""" +__all__ = [ + "LossConfig", + "NLLLossConfig", + "CrossEntropyLossConfig", + "BCELossConfig", + "BCEWithLogitsLossConfig", + "MultiMarginLossConfig", + "KLDivLossConfig", + "HuberLossConfig", + "SmoothL1LossConfig", + "L1LossConfig", + "MSELossConfig", + "create_loss_config", +] - CROSS_ENTROPY = "CrossEntropyLoss" # for multi-class classification, inputs are unormalized logits and targets are int (same dimension without the class channel) - MULTI_MARGIN = "MultiMarginLoss" # no particular restriction on the input, targets are int (same dimension without th class channel) - BCE = "BCELoss" # for binary classification, targets and inputs should be probabilities and have same shape - BCE_LOGITS = "BCEWithLogitsLoss" # for binary classification, targets should be probabilities and inputs logits, and have the same shape. 
More stable numerically +class LossConfig(BaseModel, ABC): + """Base config class for the loss function.""" -class ImplementedLoss(str, Enum): - """Implemented losses in ClinicaDL.""" - - CROSS_ENTROPY = "CrossEntropyLoss" - MULTI_MARGIN = "MultiMarginLoss" - BCE = "BCELoss" - BCE_LOGITS = "BCEWithLogitsLoss" - L1 = "L1Loss" - MSE = "MSELoss" - HUBER = "HuberLoss" - SMOOTH_L1 = "SmoothL1Loss" - KLDIV = "KLDivLoss" # if log_target=False, target must be positive - - @classmethod - def _missing_(cls, value): - raise ValueError( - f"{value} is not implemented. Implemented losses are: " - + ", ".join([repr(m.value) for m in cls]) - ) + reduction: Union[Reduction, DefaultFromLibrary] = DefaultFromLibrary.YES + weight: Union[ + Optional[List[NonNegativeFloat]], DefaultFromLibrary + ] = DefaultFromLibrary.YES + # pydantic config + model_config = ConfigDict( + validate_assignment=True, use_enum_values=True, validate_default=True + ) + @computed_field + @property + @abstractmethod + def loss(self) -> ImplementedLoss: + """The name of the loss.""" -class Reduction(str, Enum): - """Supported reduction method in ClinicaDL.""" - MEAN = "mean" - SUM = "sum" +class NLLLossConfig(LossConfig): + """Config class for Negative Log Likelihood loss.""" + ignore_index: Union[int, DefaultFromLibrary] = DefaultFromLibrary.YES -class Order(int, Enum): - """Supported order of L-norm for MultiMarginLoss.""" + @computed_field + @property + def loss(self) -> ImplementedLoss: + """The name of the loss.""" + return ImplementedLoss.NLL - ONE = 1 - TWO = 2 + @field_validator("ignore_index") + @classmethod + def validator_ignore_index(cls, v): + if isinstance(v, int): + assert ( + v == -100 or 0 <= v + ), "ignore_index must be a positive int (or -100 when disabled)." 
+ return v -class LossConfig(BaseModel): - """Config class to configure the loss function.""" +class CrossEntropyLossConfig(NLLLossConfig): + """Config class for Cross Entropy loss.""" - loss: ImplementedLoss = ImplementedLoss.MSE - reduction: Union[Reduction, DefaultFromLibrary] = DefaultFromLibrary.YES - delta: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - beta: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - p: Union[Order, DefaultFromLibrary] = DefaultFromLibrary.YES - margin: Union[float, DefaultFromLibrary] = DefaultFromLibrary.YES - weight: Union[ - Optional[List[NonNegativeFloat]], DefaultFromLibrary - ] = DefaultFromLibrary.YES # a weight for each class - ignore_index: Union[int, DefaultFromLibrary] = DefaultFromLibrary.YES label_smoothing: Union[ NonNegativeFloat, DefaultFromLibrary ] = DefaultFromLibrary.YES - log_target: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - pos_weight: Union[ - Optional[List[NonNegativeFloat]], DefaultFromLibrary - ] = DefaultFromLibrary.YES # a positive weight for each class - # pydantic config - model_config = ConfigDict( - validate_assignment=True, use_enum_values=True, validate_default=True - ) + + @computed_field + @property + def loss(self) -> ImplementedLoss: + """The name of the loss.""" + return ImplementedLoss.CROSS_ENTROPY @field_validator("label_smoothing") @classmethod @@ -92,26 +92,150 @@ def validator_label_smoothing(cls, v): ), f"label_smoothing must be between 0 and 1 but it has been set to {v}." 
return v - @field_validator("ignore_index") + +class BCELossConfig(LossConfig): + """Config class for Binary Cross Entropy loss.""" + + weight: Optional[List[NonNegativeFloat]] = None + + @computed_field + @property + def loss(self) -> ImplementedLoss: + """The name of the loss.""" + return ImplementedLoss.BCE + + @field_validator("weight") @classmethod - def validator_ignore_index(cls, v): - if isinstance(v, int): - assert ( - v == -100 or 0 <= v - ), "ignore_index must be a positive int (or -100 when disabled)." + def validator_weight(cls, v): + if v is not None: + raise ValueError( + "Cannot use weight with BCEWithLogitsLoss. If you want more flexibility, please use API mode." + ) return v - @model_validator(mode="after") - def model_validator(self): - if ( - self.loss == ImplementedLoss.BCE_LOGITS - and self.weight is not None - and self.weight != DefaultFromLibrary.YES - ): - raise ValueError("Cannot use weight with BCEWithLogitsLoss.") - elif ( - self.loss == ImplementedLoss.BCE - and self.weight is not None - and self.weight != DefaultFromLibrary.YES - ): - raise ValueError("Cannot use weight with BCELoss.") + +class BCEWithLogitsLossConfig(BCELossConfig): + """Config class for Binary Cross Entropy With Logits loss.""" + + pos_weight: Union[Optional[List[Any]], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def loss(self) -> ImplementedLoss: + """The name of the loss.""" + return ImplementedLoss.BCE_LOGITS + + @field_validator("pos_weight") + @classmethod + def validator_pos_weight(cls, v): + if isinstance(v, list): + check = cls._recursive_float_check(v) + if not check: + raise ValueError( + f"elements in pos_weight must be non-negative float, got: {v}" + ) + return v + + @classmethod + def _recursive_float_check(cls, item): + if isinstance(item, list): + return all(cls._recursive_float_check(i) for i in item) + else: + return (isinstance(item, float) or isinstance(item, int)) and item >= 0 + + +class 
MultiMarginLossConfig(LossConfig): + """Config class for Multi Margin loss.""" + + p: Union[Order, DefaultFromLibrary] = DefaultFromLibrary.YES + margin: Union[float, DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def loss(self) -> ImplementedLoss: + """The name of the loss.""" + return ImplementedLoss.MULTI_MARGIN + + +class KLDivLossConfig(LossConfig): + """Config class for Kullback-Leibler Divergence loss.""" + + log_target: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def loss(self) -> ImplementedLoss: + """The name of the loss.""" + return ImplementedLoss.KLDIV + + +class HuberLossConfig(LossConfig): + """Config class for Huber loss.""" + + delta: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def loss(self) -> ImplementedLoss: + """The name of the loss.""" + return ImplementedLoss.HUBER + + +class SmoothL1LossConfig(LossConfig): + """Config class for Smooth L1 loss.""" + + beta: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def loss(self) -> ImplementedLoss: + """The name of the loss.""" + return ImplementedLoss.SMOOTH_L1 + + +class L1LossConfig(LossConfig): + """Config class for L1 loss.""" + + @computed_field + @property + def loss(self) -> ImplementedLoss: + """The name of the loss.""" + return ImplementedLoss.L1 + + +class MSELossConfig(LossConfig): + """Config class for Mean Squared Error loss.""" + + @computed_field + @property + def loss(self) -> ImplementedLoss: + """The name of the loss.""" + return ImplementedLoss.MSE + + +def create_loss_config( + loss: Union[str, ImplementedLoss], +) -> Type[LossConfig]: + """ + A factory function to create a config class suited for the loss. + + Parameters + ---------- + loss : Union[str, ImplementedLoss] + The name of the loss. + + Returns + ------- + Type[LossConfig] + The config class. 
+ + Raises + ------ + ValueError + If `loss` is not supported. + """ + loss = ImplementedLoss(loss) + config_name = "".join([loss, "Config"]) + config = globals()[config_name] + + return config diff --git a/clinicadl/losses/enum.py b/clinicadl/losses/enum.py new file mode 100644 index 000000000..a38ff2707 --- /dev/null +++ b/clinicadl/losses/enum.py @@ -0,0 +1,50 @@ +from enum import Enum + +from clinicadl.utils.enum import BaseEnum + + +class ClassificationLoss(str, BaseEnum): + """Losses that can be used only for classification.""" + + CROSS_ENTROPY = "CrossEntropyLoss" # for multi-class classification, inputs are unnormalized logits and targets are int (same dimension without the class channel) + NLL = "NLLLoss" # for multi-class classification, inputs are log-probabilities and targets are int (same dimension without the class channel) + MULTI_MARGIN = "MultiMarginLoss" # no particular restriction on the input, targets are int (same dimension without the class channel) + BCE = "BCELoss" # for binary classification, targets and inputs should be probabilities and have same shape + BCE_LOGITS = "BCEWithLogitsLoss" # for binary classification, targets should be probabilities and inputs logits, and have the same shape. More stable numerically + + +class ImplementedLoss(str, Enum): + """Implemented losses in ClinicaDL.""" + + CROSS_ENTROPY = "CrossEntropyLoss" + NLL = "NLLLoss" + MULTI_MARGIN = "MultiMarginLoss" + BCE = "BCELoss" + BCE_LOGITS = "BCEWithLogitsLoss" + + L1 = "L1Loss" + MSE = "MSELoss" + HUBER = "HuberLoss" + SMOOTH_L1 = "SmoothL1Loss" + KLDIV = "KLDivLoss" # if log_target=False, target must be positive + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not implemented. 
Implemented losses are: " + + ", ".join([repr(m.value) for m in cls]) + ) + + +class Reduction(str, Enum): + """Supported reduction method in ClinicaDL.""" + + MEAN = "mean" + SUM = "sum" + + +class Order(int, Enum): + """Supported order of L-norm for MultiMarginLoss.""" + + ONE = 1 + TWO = 2 diff --git a/clinicadl/losses/factory.py b/clinicadl/losses/factory.py index a4758e716..6cca92476 100644 --- a/clinicadl/losses/factory.py +++ b/clinicadl/losses/factory.py @@ -39,6 +39,6 @@ def get_loss_function(config: LossConfig) -> Tuple[torch.nn.Module, LossConfig]: config_dict_["pos_weight"] = torch.Tensor(config_dict_["pos_weight"]) loss = loss_class(**config_dict_) - updated_config = LossConfig(loss=config.loss, **config_dict) + updated_config = config.model_copy(update=config_dict) return loss, updated_config diff --git a/clinicadl/optim/__init__.py b/clinicadl/optim/__init__.py index 6715835a5..185b1c418 100644 --- a/clinicadl/optim/__init__.py +++ b/clinicadl/optim/__init__.py @@ -1 +1,4 @@ from .config import OptimizationConfig +from .early_stopping import EarlyStopping +from .lr_scheduler import create_lr_scheduler_config, get_lr_scheduler +from .optimizer import create_optimizer_config, get_optimizer diff --git a/clinicadl/optim/lr_scheduler/__init__.py b/clinicadl/optim/lr_scheduler/__init__.py index 9c5fabd2c..c26899e69 100644 --- a/clinicadl/optim/lr_scheduler/__init__.py +++ b/clinicadl/optim/lr_scheduler/__init__.py @@ -1,2 +1,3 @@ -from .config import ImplementedLRScheduler, LRSchedulerConfig +from .config import create_lr_scheduler_config +from .enum import ImplementedLRScheduler from .factory import get_lr_scheduler diff --git a/clinicadl/optim/lr_scheduler/config.py b/clinicadl/optim/lr_scheduler/config.py index 93fb3d9e1..073f92db7 100644 --- a/clinicadl/optim/lr_scheduler/config.py +++ b/clinicadl/optim/lr_scheduler/config.py @@ -1,7 +1,4 @@ -from __future__ import annotations - -from enum import Enum -from typing import Dict, List, Optional, Union 
+from typing import Dict, List, Optional, Type, Union from pydantic import ( BaseModel, @@ -10,71 +7,43 @@ NonNegativeInt, PositiveFloat, PositiveInt, + computed_field, field_validator, - model_validator, ) from clinicadl.utils.factories import DefaultFromLibrary +from .enum import ImplementedLRScheduler, Mode, ThresholdMode -class ImplementedLRScheduler(str, Enum): - """Implemented LR schedulers in ClinicaDL.""" - - CONSTANT = "ConstantLR" - LINEAR = "LinearLR" - STEP = "StepLR" - MULTI_STEP = "MultiStepLR" - PLATEAU = "ReduceLROnPlateau" - - @classmethod - def _missing_(cls, value): - raise ValueError( - f"{value} is not implemented. Implemented LR schedulers are: " - + ", ".join([repr(m.value) for m in cls]) - ) - - -class Mode(str, Enum): - """Supported mode for ReduceLROnPlateau.""" - - MIN = "min" - MAX = "max" - - -class ThresholdMode(str, Enum): - """Supported threshold mode for ReduceLROnPlateau.""" - - ABS = "abs" - REL = "rel" +__all__ = [ + "LRSchedulerConfig", + "ConstantLRConfig", + "LinearLRConfig", + "StepLRConfig", + "MultiStepLRConfig", + "ReduceLROnPlateauConfig", + "create_lr_scheduler_config", +] class LRSchedulerConfig(BaseModel): - """Config class to configure the optimizer.""" + """Base config class for the LR scheduler.""" - scheduler: Optional[ImplementedLRScheduler] = None - step_size: Optional[PositiveInt] = None gamma: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - milestones: Optional[List[PositiveInt]] = None factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - start_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - end_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES total_iters: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES last_epoch: Union[int, DefaultFromLibrary] = DefaultFromLibrary.YES - - mode: Union[Mode, DefaultFromLibrary] = DefaultFromLibrary.YES - patience: Union[PositiveInt, DefaultFromLibrary] = 
DefaultFromLibrary.YES - threshold: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - threshold_mode: Union[ThresholdMode, DefaultFromLibrary] = DefaultFromLibrary.YES - cooldown: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES - min_lr: Union[ - NonNegativeFloat, Dict[str, PositiveFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - eps: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES # pydantic config model_config = ConfigDict( validate_assignment=True, use_enum_values=True, validate_default=True ) + @computed_field + @property + def scheduler(self) -> Optional[ImplementedLRScheduler]: + """The name of the scheduler.""" + return None + @field_validator("last_epoch") @classmethod def validator_last_epoch(cls, v): @@ -84,31 +53,108 @@ def validator_last_epoch(cls, v): ), f"last_epoch must be -1 or a non-negative int but it has been set to {v}." return v - @field_validator("milestones") + +class ConstantLRConfig(LRSchedulerConfig): + """Config class for ConstantLR scheduler.""" + + @computed_field + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.CONSTANT + + +class LinearLRConfig(LRSchedulerConfig): + """Config class for LinearLR scheduler.""" + + start_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES + end_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.LINEAR + + +class StepLRConfig(LRSchedulerConfig): + """Config class for StepLR scheduler.""" + + step_size: PositiveInt + + @computed_field + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.STEP + + +class MultiStepLRConfig(LRSchedulerConfig): + """Config class for MultiStepLR scheduler.""" + + 
milestones: List[PositiveInt] + + @computed_field + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.MULTI_STEP + + @field_validator("milestones", mode="after") @classmethod def validator_milestones(cls, v): import numpy as np - if v is not None: - assert len(np.unique(v)) == len( - v - ), "Epoch(s) in milestones should be unique." - return sorted(v) - return v + assert len(np.unique(v)) == len(v), "Epoch(s) in milestones should be unique." + return sorted(v) + + +class ReduceLROnPlateauConfig(LRSchedulerConfig): + """Config class for ReduceLROnPlateau scheduler.""" + + mode: Union[Mode, DefaultFromLibrary] = DefaultFromLibrary.YES + patience: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + threshold: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES + threshold_mode: Union[ThresholdMode, DefaultFromLibrary] = DefaultFromLibrary.YES + cooldown: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES + min_lr: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + eps: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - @model_validator(mode="after") - def check_mandatory_args(self) -> LRSchedulerConfig: - if ( - self.scheduler == ImplementedLRScheduler.MULTI_STEP - and self.milestones is None - ): - raise ValueError( - """If you chose MultiStepLR as LR scheduler, you should pass milestones - (see PyTorch documentation for more details).""" - ) - elif self.scheduler == ImplementedLRScheduler.STEP and self.step_size is None: - raise ValueError( - """If you chose StepLR as LR scheduler, you should pass a step_size - (see PyTorch documentation for more details).""" - ) - return self + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.PLATEAU + + +def create_lr_scheduler_config( + scheduler: 
Optional[Union[str, ImplementedLRScheduler]], +) -> Type[LRSchedulerConfig]: + """ + A factory function to create a config class suited for the LR scheduler. + + Parameters + ---------- + scheduler : Optional[Union[str, ImplementedLRScheduler]] + The name of the LR scheduler. + Can be None if no LR scheduler will be used. + + Returns + ------- + Type[LRSchedulerConfig] + The config class. + + Raises + ------ + ValueError + If `scheduler` is not supported. + """ + if scheduler is None: + return LRSchedulerConfig + + scheduler = ImplementedLRScheduler(scheduler) + config_name = "".join([scheduler, "Config"]) + config = globals()[config_name] + + return config diff --git a/clinicadl/optim/lr_scheduler/enum.py b/clinicadl/optim/lr_scheduler/enum.py new file mode 100644 index 000000000..a70bb1801 --- /dev/null +++ b/clinicadl/optim/lr_scheduler/enum.py @@ -0,0 +1,32 @@ +from enum import Enum + + +class ImplementedLRScheduler(str, Enum): + """Implemented LR schedulers in ClinicaDL.""" + + CONSTANT = "ConstantLR" + LINEAR = "LinearLR" + STEP = "StepLR" + MULTI_STEP = "MultiStepLR" + PLATEAU = "ReduceLROnPlateau" + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not implemented. 
Implemented LR schedulers are: " + + ", ".join([repr(m.value) for m in cls]) + ) + + +class Mode(str, Enum): + """Supported mode for ReduceLROnPlateau.""" + + MIN = "min" + MAX = "max" + + +class ThresholdMode(str, Enum): + """Supported threshold mode for ReduceLROnPlateau.""" + + ABS = "abs" + REL = "rel" diff --git a/clinicadl/optim/lr_scheduler/factory.py b/clinicadl/optim/lr_scheduler/factory.py index a26948deb..2eeaccd55 100644 --- a/clinicadl/optim/lr_scheduler/factory.py +++ b/clinicadl/optim/lr_scheduler/factory.py @@ -56,6 +56,6 @@ def get_lr_scheduler( config_dict_["min_lr"].append(default_min_lr) scheduler = scheduler_class(optimizer, **config_dict_) - updated_config = LRSchedulerConfig(scheduler=config.scheduler, **config_dict) + updated_config = config.model_copy(update=config_dict) return scheduler, updated_config diff --git a/clinicadl/optim/optimizer/__init__.py b/clinicadl/optim/optimizer/__init__.py index 2c9cce3ba..504c60999 100644 --- a/clinicadl/optim/optimizer/__init__.py +++ b/clinicadl/optim/optimizer/__init__.py @@ -1,2 +1,3 @@ -from .config import ImplementedOptimizer, OptimizerConfig +from .config import create_optimizer_config +from .enum import ImplementedOptimizer from .factory import get_optimizer diff --git a/clinicadl/optim/optimizer/config.py b/clinicadl/optim/optimizer/config.py index 46aa5958c..b0a55f034 100644 --- a/clinicadl/optim/optimizer/config.py +++ b/clinicadl/optim/optimizer/config.py @@ -1,38 +1,32 @@ -from enum import Enum -from typing import Dict, List, Optional, Tuple, Union +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple, Type, Union from pydantic import ( BaseModel, ConfigDict, NonNegativeFloat, PositiveFloat, + computed_field, field_validator, ) from clinicadl.utils.factories import DefaultFromLibrary +from .enum import ImplementedOptimizer -class ImplementedOptimizer(str, Enum): - """Implemented optimizers in ClinicaDL.""" +__all__ = [ + "OptimizerConfig", + "AdadeltaConfig", 
+ "AdagradConfig", + "AdamConfig", + "RMSpropConfig", + "SGDConfig", + "create_optimizer_config", +] - ADADELTA = "Adadelta" - ADAGRAD = "Adagrad" - ADAM = "Adam" - RMS_PROP = "RMSprop" - SGD = "SGD" - @classmethod - def _missing_(cls, value): - raise ValueError( - f"{value} is not implemented. Implemented optimizers are: " - + ", ".join([repr(m.value) for m in cls]) - ) - - -class OptimizerConfig(BaseModel): - """Config class to configure the optimizer.""" - - optimizer: ImplementedOptimizer = ImplementedOptimizer.ADAM +class OptimizerConfig(BaseModel, ABC): + """Base config class for the optimizer.""" lr: Union[ PositiveFloat, Dict[str, PositiveFloat], DefaultFromLibrary @@ -40,36 +34,9 @@ class OptimizerConfig(BaseModel): weight_decay: Union[ NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary ] = DefaultFromLibrary.YES - betas: Union[ - Tuple[NonNegativeFloat, NonNegativeFloat], - Dict[str, Tuple[NonNegativeFloat, NonNegativeFloat]], - DefaultFromLibrary, - ] = DefaultFromLibrary.YES - alpha: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - momentum: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - rho: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - lr_decay: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES eps: Union[ NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary ] = DefaultFromLibrary.YES - dampening: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - initial_accumulator_value: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - - centered: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES - nesterov: Union[bool, Dict[str, bool], DefaultFromLibrary] = 
DefaultFromLibrary.YES - amsgrad: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES foreach: Union[ Optional[bool], Dict[str, Optional[bool]], DefaultFromLibrary ] = DefaultFromLibrary.YES @@ -81,14 +48,19 @@ class OptimizerConfig(BaseModel): bool, Dict[str, bool], DefaultFromLibrary ] = DefaultFromLibrary.YES fused: Union[ - Optional[bool], Dict[str, bool], DefaultFromLibrary + Optional[bool], Dict[str, Optional[bool]], DefaultFromLibrary ] = DefaultFromLibrary.YES # pydantic config model_config = ConfigDict( validate_assignment=True, use_enum_values=True, validate_default=True ) - @field_validator("betas", "rho", "alpha", "dampening") + @computed_field + @property + @abstractmethod + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + @classmethod def validator_proba(cls, v, ctx): name = ctx.field_name @@ -128,3 +100,131 @@ def get_all_groups(self) -> List[str]: groups.update(set(value.keys())) return list(groups) + + +class AdadeltaConfig(OptimizerConfig): + """Config class for Adadelta optimizer.""" + + rho: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.ADADELTA + + @field_validator("rho") + def validator_rho(cls, v, ctx): + return cls.validator_proba(v, ctx) + + +class AdagradConfig(OptimizerConfig): + """Config class for Adagrad optimizer.""" + + lr_decay: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + initial_accumulator_value: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.ADAGRAD + + +class AdamConfig(OptimizerConfig): + """Config class for Adam 
optimizer.""" + + betas: Union[ + Tuple[NonNegativeFloat, NonNegativeFloat], + Dict[str, Tuple[NonNegativeFloat, NonNegativeFloat]], + DefaultFromLibrary, + ] = DefaultFromLibrary.YES + amsgrad: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.ADAM + + @field_validator("betas") + def validator_betas(cls, v, ctx): + return cls.validator_proba(v, ctx) + + +class RMSpropConfig(OptimizerConfig): + """Config class for RMSprop optimizer.""" + + alpha: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + momentum: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + centered: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.RMS_PROP + + @field_validator("alpha") + def validator_alpha(cls, v, ctx): + return cls.validator_proba(v, ctx) + + +class SGDConfig(OptimizerConfig): + """Config class for SGD optimizer.""" + + momentum: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + dampening: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + nesterov: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.SGD + + @field_validator("dampening") + def validator_dampening(cls, v, ctx): + return cls.validator_proba(v, ctx) + + +def create_optimizer_config( + optimizer: Union[str, ImplementedOptimizer], +) -> Type[OptimizerConfig]: + """ + A factory function to create a config 
class suited for the optimizer. + + Parameters + ---------- + optimizer : Union[str, ImplementedOptimizer] + The name of the optimizer. + + Returns + ------- + Type[OptimizerConfig] + The config class. + + Raises + ------ + ValueError + If `optimizer` is not supported. + """ + optimizer = ImplementedOptimizer(optimizer) + config_name = "".join([optimizer, "Config"]) + config = globals()[config_name] + + return config diff --git a/clinicadl/optim/optimizer/enum.py b/clinicadl/optim/optimizer/enum.py new file mode 100644 index 000000000..a397dbe85 --- /dev/null +++ b/clinicadl/optim/optimizer/enum.py @@ -0,0 +1,18 @@ +from enum import Enum + + +class ImplementedOptimizer(str, Enum): + """Implemented optimizers in ClinicaDL.""" + + ADADELTA = "Adadelta" + ADAGRAD = "Adagrad" + ADAM = "Adam" + RMS_PROP = "RMSprop" + SGD = "SGD" + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not implemented. Implemented optimizers are: " + + ", ".join([repr(m.value) for m in cls]) + ) diff --git a/clinicadl/optim/optimizer/factory.py b/clinicadl/optim/optimizer/factory.py index 3afd6a848..3123781ed 100644 --- a/clinicadl/optim/optimizer/factory.py +++ b/clinicadl/optim/optimizer/factory.py @@ -1,5 +1,6 @@ -from typing import Any, Dict, Tuple +from typing import Any, Dict, Iterable, Iterator, List, Tuple +import torch import torch.nn as nn import torch.optim as optim @@ -31,6 +32,11 @@ def get_optimizer( The updated config class: the arguments set to default will be updated with their effective values (the default values from the library). Useful for reproducibility. + + Raises + ------ + AttributeError + If a parameter group mentioned in the config class cannot be found in the network.
""" optimizer_class = getattr(optim, config.optimizer) expected_args, default_args = get_args_and_defaults(optimizer_class.__init__) @@ -58,7 +64,7 @@ def get_optimizer( list_args_groups.append({"params": other_params}) optimizer = optimizer_class(list_args_groups, **args_global) - updated_config = OptimizerConfig(optimizer=config.optimizer, **default_args) + updated_config = config.model_copy(update=default_args) return optimizer, updated_config @@ -119,3 +125,90 @@ def _regroup_args( args_global[arg] = value return args_groups, args_global + + +def _get_params_in_group( + network: nn.Module, group: str +) -> Tuple[Iterator[torch.Tensor], List[str]]: + """ + Gets the parameters of a specific group of a neural network. + + Parameters + ---------- + network : nn.Module + The neural network. + group : str + The name of the group, e.g. a layer or a block. + If it is a sub-block, the hierarchy should be + specified with "." (see examples). + Will work even if the group is reduced to a base layer + (e.g. group = "dense.weight" or "dense.bias"). + + Returns + ------- + Iterator[torch.Tensor] + A generator that contains the parameters of the group. + List[str] + The name of all the parameters in the group. + + Raises + ------ + AttributeError + If `group` cannot be found in the network. + + Examples + -------- + >>> net = nn.Sequential( + OrderedDict( + [ + ("conv1", nn.Conv2d(1, 1, kernel_size=3)), + ("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))), + ] + ) + ) + >>> generator, params_names = _get_params_in_group(net, "final.dense1") + >>> params_names + ["final.dense1.weight", "final.dense1.bias"] + """ + group_hierarchy = group.split(".") + for name in group_hierarchy: + try: + network = getattr(network, name) + except AttributeError as exc: + raise AttributeError( + f"There is no such group as {group} in the network."
+ ) from exc + + try: + params = network.parameters() + params_names = [ + ".".join([group, name]) for name, _ in network.named_parameters() + ] + except AttributeError: # we already reached params + params = (param for param in [network]) + params_names = [group] + + return params, params_names + + +def _get_params_not_in_group( + network: nn.Module, group: Iterable[str] +) -> Iterator[torch.Tensor]: + """ + Finds the parameters of a neural network that + are not in a group. + + Parameters + ---------- + network : nn.Module + The neural network. + group : Iterable[str] + The names of the parameters in the group. + + Returns + ------- + Iterator[torch.Tensor] + A generator of all the parameters that are not in the input + group. + """ + return (param[1] for param in network.named_parameters() if param[0] not in group) diff --git a/tests/unittests/losses/test_config.py b/tests/unittests/losses/test_config.py index fc40af9e7..237f07922 100644 --- a/tests/unittests/losses/test_config.py +++ b/tests/unittests/losses/test_config.py @@ -1,35 +1,104 @@ import pytest from pydantic import ValidationError -from clinicadl.losses import LossConfig +from clinicadl.losses import ImplementedLoss +from clinicadl.losses.config import ( + BCELossConfig, + BCEWithLogitsLossConfig, + CrossEntropyLossConfig, + HuberLossConfig, + KLDivLossConfig, + L1LossConfig, + MSELossConfig, + MultiMarginLossConfig, + NLLLossConfig, + SmoothL1LossConfig, + create_loss_config, +) -def test_LossConfig(): - config = LossConfig( - loss="SmoothL1Loss", margin=10.0, delta=2.0, reduction="sum", weight=None - ) - assert config.loss == "SmoothL1Loss" - assert config.margin == 10.0 - assert config.delta == 2.0 - assert config.reduction == "sum" - assert config.p == "DefaultFromLibrary" - - with pytest.raises(ValueError): - LossConfig(loss="abc") - with pytest.raises(ValueError): - LossConfig(weight=[0.1, -0.1, 0.8]) - with pytest.raises(ValueError): - LossConfig(p=3) - with pytest.raises(ValueError): -
LossConfig(reduction="abc") - with pytest.raises(ValidationError): - LossConfig(label_smoothing=1.1) - with pytest.raises(ValidationError): - LossConfig(ignore_index=-1) - with pytest.raises(ValidationError): - LossConfig(loss="BCEWithLogitsLoss", weight=[1, 2, 3]) +@pytest.mark.parametrize( + "config,args", + [ + (L1LossConfig, {"reduction": "none"}), + (MSELossConfig, {"reduction": "none"}), + (CrossEntropyLossConfig, {"reduction": "none"}), + (CrossEntropyLossConfig, {"weight": [1, -1, 2]}), + (CrossEntropyLossConfig, {"ignore_index": -1}), + (CrossEntropyLossConfig, {"label_smoothing": 1.1}), + (NLLLossConfig, {"reduction": "none"}), + (NLLLossConfig, {"weight": [1, -1, 2]}), + (NLLLossConfig, {"ignore_index": -1}), + (KLDivLossConfig, {"reduction": "none"}), + (BCELossConfig, {"reduction": "none"}), + (BCELossConfig, {"weight": [0, 1]}), + (BCEWithLogitsLossConfig, {"reduction": "none"}), + (BCEWithLogitsLossConfig, {"weight": [0, 1]}), + (BCEWithLogitsLossConfig, {"pos_weight": [[1, -1, 2]]}), + (BCEWithLogitsLossConfig, {"pos_weight": [["a", "b"]]}), + (HuberLossConfig, {"reduction": "none"}), + (HuberLossConfig, {"delta": 0.0}), + (SmoothL1LossConfig, {"reduction": "none"}), + (SmoothL1LossConfig, {"beta": -1.0}), + (MultiMarginLossConfig, {"reduction": "none"}), + (MultiMarginLossConfig, {"p": 3}), + (MultiMarginLossConfig, {"weight": [1, -1, 2]}), + ], +) +def test_validation_fail(config, args): with pytest.raises(ValidationError): - LossConfig(loss="BCELoss", weight=[1, 2, 3]) + config(**args) + + +@pytest.mark.parametrize( + "config,args", + [ + (L1LossConfig, {"reduction": "mean"}), + (MSELossConfig, {"reduction": "mean"}), + ( + CrossEntropyLossConfig, + { + "reduction": "mean", + "weight": [1, 0, 2], + "ignore_index": 1, + "label_smoothing": 0.5, + }, + ), + (NLLLossConfig, {"reduction": "mean", "weight": [1, 0, 2], "ignore_index": 1}), + (KLDivLossConfig, {"reduction": "mean", "log_target": True}), + (BCELossConfig, {"reduction": "sum", "weight": 
None}), + ( + BCEWithLogitsLossConfig, + {"reduction": "sum", "weight": None, "pos_weight": [[1, 0, 2]]}, + ), + (HuberLossConfig, {"reduction": "sum", "delta": 0.1}), + (SmoothL1LossConfig, {"reduction": "sum", "beta": 0.0}), + ( + MultiMarginLossConfig, + {"reduction": "sum", "p": 1, "margin": -0.1, "weight": [1, 0, 2]}, + ), + ], +) +def test_validation_pass(config, args): + c = config(**args) + for arg, value in args.items(): + assert getattr(c, arg) == value + - LossConfig(loss="BCELoss") - LossConfig(loss="BCEWithLogitsLoss", weight=None) +@pytest.mark.parametrize( + "name,config", + [ + ("BCELoss", BCELossConfig), + ("BCEWithLogitsLoss", BCEWithLogitsLossConfig), + ("CrossEntropyLoss", CrossEntropyLossConfig), + ("HuberLoss", HuberLossConfig), + ("KLDivLoss", KLDivLossConfig), + ("L1Loss", L1LossConfig), + ("MSELoss", MSELossConfig), + ("MultiMarginLoss", MultiMarginLossConfig), + ("NLLLoss", NLLLossConfig), + ("SmoothL1Loss", SmoothL1LossConfig), + ], +) +def test_create_loss_config(name, config): + assert create_loss_config(name) == config diff --git a/tests/unittests/losses/test_factory.py b/tests/unittests/losses/test_factory.py index 5ac786deb..e7d602346 100644 --- a/tests/unittests/losses/test_factory.py +++ b/tests/unittests/losses/test_factory.py @@ -1,15 +1,17 @@ from torch import Tensor from torch.nn import BCEWithLogitsLoss, MultiMarginLoss -from clinicadl.losses import ImplementedLoss, LossConfig, get_loss_function +from clinicadl.losses import ImplementedLoss, create_loss_config, get_loss_function def test_get_loss_function(): - for loss in [e.value for e in ImplementedLoss]: - config = LossConfig(loss=loss) - get_loss_function(config) + for loss in ImplementedLoss: + config = create_loss_config(loss=loss)() + _ = get_loss_function(config) - config = LossConfig(loss="MultiMarginLoss", reduction="sum", weight=[1, 2, 3], p=2) + config = create_loss_config("MultiMarginLoss")( + reduction="sum", weight=[1, 2, 3], p=2 + ) loss, updated_config = 
get_loss_function(config) assert isinstance(loss, MultiMarginLoss) assert loss.reduction == "sum" @@ -23,7 +25,7 @@ def test_get_loss_function(): assert updated_config.margin == 1.0 assert updated_config.weight == [1, 2, 3] - config = LossConfig(loss="BCEWithLogitsLoss", pos_weight=[1, 2, 3]) + config = create_loss_config("BCEWithLogitsLoss")(pos_weight=[1, 2, 3]) loss, updated_config = get_loss_function(config) assert isinstance(loss, BCEWithLogitsLoss) assert (loss.pos_weight == Tensor([1, 2, 3])).all() diff --git a/tests/unittests/optim/lr_scheduler/test_config.py b/tests/unittests/optim/lr_scheduler/test_config.py index 270ccc27c..dbf96ccc8 100644 --- a/tests/unittests/optim/lr_scheduler/test_config.py +++ b/tests/unittests/optim/lr_scheduler/test_config.py @@ -1,37 +1,114 @@ import pytest from pydantic import ValidationError -from clinicadl.optim.lr_scheduler import LRSchedulerConfig - - -def test_LRSchedulerConfig(): - config = LRSchedulerConfig( - scheduler="ReduceLROnPlateau", - mode="max", - patience=1, - threshold_mode="rel", - milestones=[4, 3, 2], - min_lr={"param_0": 1e-1, "ELSE": 1e-2}, - ) - assert config.scheduler == "ReduceLROnPlateau" - assert config.mode == "max" - assert config.patience == 1 - assert config.threshold_mode == "rel" - assert config.milestones == [2, 3, 4] - assert config.min_lr == {"param_0": 1e-1, "ELSE": 1e-2} - assert config.threshold == "DefaultFromLibrary" +from clinicadl.optim.lr_scheduler.config import ( + ConstantLRConfig, + LinearLRConfig, + MultiStepLRConfig, + ReduceLROnPlateauConfig, + StepLRConfig, + create_lr_scheduler_config, +) +BAD_INPUTS = { + "milestones": [3, 2, 4], + "gamma": 0, + "last_epoch": -2, + "step_size": 0, + "factor": 0, + "total_iters": 0, + "start_factor": 0, + "end_factor": 0, + "mode": "abc", + "patience": 0, + "threshold": -1, + "threshold_mode": "abc", + "cooldown": -1, + "eps": -0.1, + "min_lr": -0.1, +} + +GOOD_INPUTS = { + "milestones": [1, 4, 5], + "gamma": 0.1, + "last_epoch": -1, + 
"step_size": 1, + "factor": 0.1, + "total_iters": 1, + "start_factor": 0.1, + "end_factor": 0.2, + "mode": "min", + "patience": 1, + "threshold": 0, + "threshold_mode": "abs", + "cooldown": 0, + "eps": 0, + "min_lr": 0, +} + + +@pytest.mark.parametrize( + "config", + [ + ConstantLRConfig, + LinearLRConfig, + MultiStepLRConfig, + ReduceLROnPlateauConfig, + StepLRConfig, + ], +) +def test_validation_fail(config): + fields = config.model_fields + inputs = {key: value for key, value in BAD_INPUTS.items() if key in fields} with pytest.raises(ValidationError): - LRSchedulerConfig(last_epoch=-2) - with pytest.raises(ValueError): - LRSchedulerConfig(scheduler="abc") - with pytest.raises(ValueError): - LRSchedulerConfig(mode="abc") - with pytest.raises(ValueError): - LRSchedulerConfig(threshold_mode="abc") - with pytest.raises(ValidationError): - LRSchedulerConfig(milestones=[10, 10]) - with pytest.raises(ValidationError): - LRSchedulerConfig(scheduler="MultiStepLR") + config(**inputs) + + # test dict inputs for min_lr + if "min_lr" in inputs: + inputs["min_lr"] = {"group_1": inputs["min_lr"]} + with pytest.raises(ValidationError): + config(**inputs) + + +def test_validation_fail_special(): with pytest.raises(ValidationError): - LRSchedulerConfig(scheduler="StepLR") + MultiStepLRConfig(milestones=[0, 1]) + + +@pytest.mark.parametrize( + "config", + [ + ConstantLRConfig, + LinearLRConfig, + MultiStepLRConfig, + ReduceLROnPlateauConfig, + StepLRConfig, + ], +) +def test_validation_pass(config): + fields = config.model_fields + inputs = {key: value for key, value in GOOD_INPUTS.items() if key in fields} + c = config(**inputs) + for arg, value in inputs.items(): + assert getattr(c, arg) == value + + # test dict inputs + if "min_lr" in inputs: + inputs["min_lr"] = {"group_1": inputs["min_lr"]} + c = config(**inputs) + assert getattr(c, "min_lr") == inputs["min_lr"] + + +@pytest.mark.parametrize( + "name,expected_class", + [ + ("ConstantLR", ConstantLRConfig), + ("LinearLR", 
LinearLRConfig), + ("MultiStepLR", MultiStepLRConfig), + ("ReduceLROnPlateau", ReduceLROnPlateauConfig), + ("StepLR", StepLRConfig), + ], +) +def test_create_optimizer_config(name, expected_class): + config = create_lr_scheduler_config(name) + assert config == expected_class diff --git a/tests/unittests/optim/lr_scheduler/test_factory.py b/tests/unittests/optim/lr_scheduler/test_factory.py index 76df845cd..cffb3d138 100644 --- a/tests/unittests/optim/lr_scheduler/test_factory.py +++ b/tests/unittests/optim/lr_scheduler/test_factory.py @@ -6,7 +6,7 @@ from clinicadl.optim.lr_scheduler import ( ImplementedLRScheduler, - LRSchedulerConfig, + create_lr_scheduler_config, get_lr_scheduler, ) @@ -37,17 +37,12 @@ def test_get_lr_scheduler(): lr=10.0, ) - for scheduler in [e.value for e in ImplementedLRScheduler]: - if scheduler == "StepLR": - config = LRSchedulerConfig(scheduler=scheduler, step_size=1) - elif scheduler == "MultiStepLR": - config = LRSchedulerConfig(scheduler=scheduler, milestones=[1, 2, 3]) - else: - config = LRSchedulerConfig(scheduler=scheduler) - get_lr_scheduler(optimizer, config) + args = {"step_size": 1, "milestones": [1, 2]} + for scheduler in ImplementedLRScheduler: + config = create_lr_scheduler_config(scheduler=scheduler)(**args) + _ = get_lr_scheduler(optimizer, config) - config = LRSchedulerConfig( - scheduler="ReduceLROnPlateau", + config = create_lr_scheduler_config(scheduler="ReduceLROnPlateau")( mode="max", factor=0.123, threshold=1e-1, @@ -83,7 +78,8 @@ def test_get_lr_scheduler(): scheduler, updated_config = get_lr_scheduler(optimizer, config) assert scheduler.min_lrs == [1.0, 1.0, 1.0] - config = LRSchedulerConfig() + # no lr scheduler + config = create_lr_scheduler_config(None)() scheduler, updated_config = get_lr_scheduler(optimizer, config) assert isinstance(scheduler, LambdaLR) assert updated_config.scheduler is None diff --git a/tests/unittests/optim/optimizer/test_config.py b/tests/unittests/optim/optimizer/test_config.py index 
e162fe58d..bf1dbcd8f 100644 --- a/tests/unittests/optim/optimizer/test_config.py +++ b/tests/unittests/optim/optimizer/test_config.py @@ -1,34 +1,118 @@ import pytest from pydantic import ValidationError -from clinicadl.optim.optimizer import OptimizerConfig - - -def test_OptimizerConfig(): - config = OptimizerConfig( - optimizer="SGD", - lr=1e-3, - weight_decay={"param_0": 1e-3, "param_1": 1e-2}, - momentum={"param_1": 1e-1}, - lr_decay=1e-4, - ) - assert config.optimizer == "SGD" - assert config.lr == 1e-3 - assert config.weight_decay == {"param_0": 1e-3, "param_1": 1e-2} - assert config.momentum == {"param_1": 1e-1} - assert config.lr_decay == 1e-4 - assert config.alpha == "DefaultFromLibrary" - assert sorted(config.get_all_groups()) == ["param_0", "param_1"] +from clinicadl.optim.optimizer.config import ( + AdadeltaConfig, + AdagradConfig, + AdamConfig, + RMSpropConfig, + SGDConfig, + create_optimizer_config, +) +BAD_INPUTS = { + "lr": 0, + "rho": 1.1, + "eps": -0.1, + "weight_decay": -0.1, + "lr_decay": -0.1, + "initial_accumulator_value": -0.1, + "betas": (0.9, 1.0), + "alpha": 1.1, + "momentum": -0.1, + "dampening": 0.1, +} + +GOOD_INPUTS_1 = { + "lr": 0.1, + "rho": 0, + "eps": 0, + "weight_decay": 0, + "foreach": None, + "capturable": False, + "maximize": True, + "differentiable": False, + "fused": None, + "lr_decay": 0, + "initial_accumulator_value": 0, + "betas": (0.0, 0.0), + "amsgrad": True, + "alpha": 0.0, + "momentum": 0, + "centered": True, + "dampening": 0, + "nesterov": True, +} + +GOOD_INPUTS_2 = { + "foreach": True, + "fused": False, +} + + +@pytest.mark.parametrize( + "config", + [ + AdadeltaConfig, + AdagradConfig, + AdamConfig, + RMSpropConfig, + SGDConfig, + ], +) +def test_validation_fail(config): + fields = config.model_fields + inputs = {key: value for key, value in BAD_INPUTS.items() if key in fields} with pytest.raises(ValidationError): - OptimizerConfig(betas={"params_0": (0.9, 1.01), "params_1": (0.9, 0.99)}) - with 
pytest.raises(ValidationError): - OptimizerConfig(betas=0.9) - with pytest.raises(ValidationError): - OptimizerConfig(rho=1.01) - with pytest.raises(ValidationError): - OptimizerConfig(alpha=1.01) + config(**inputs) + + # test dict inputs + inputs = {key: {"group_1": value} for key, value in inputs.items()} with pytest.raises(ValidationError): - OptimizerConfig(dampening={"params_0": 0.1, "params_1": 2}) - with pytest.raises(ValueError): - OptimizerConfig(optimizer="abc") + config(**inputs) + + +@pytest.mark.parametrize( + "config", + [ + AdadeltaConfig, + AdagradConfig, + AdamConfig, + RMSpropConfig, + SGDConfig, + ], +) +@pytest.mark.parametrize( + "good_inputs", + [ + GOOD_INPUTS_1, + GOOD_INPUTS_2, + ], +) +def test_validation_pass(config, good_inputs): + fields = config.model_fields + inputs = {key: value for key, value in good_inputs.items() if key in fields} + c = config(**inputs) + for arg, value in inputs.items(): + assert getattr(c, arg) == value + + # test dict inputs + inputs = {key: {"group_1": value} for key, value in inputs.items()} + c = config(**inputs) + for arg, value in inputs.items(): + assert getattr(c, arg) == value + + +@pytest.mark.parametrize( + "name,expected_class", + [ + ("Adadelta", AdadeltaConfig), + ("Adagrad", AdagradConfig), + ("Adam", AdamConfig), + ("RMSprop", RMSpropConfig), + ("SGD", SGDConfig), + ], +) +def test_create_optimizer_config(name, expected_class): + config = create_optimizer_config(name) + assert config == expected_class diff --git a/tests/unittests/optim/optimizer/test_factory.py b/tests/unittests/optim/optimizer/test_factory.py index 7dbf149e9..47b44a00a 100644 --- a/tests/unittests/optim/optimizer/test_factory.py +++ b/tests/unittests/optim/optimizer/test_factory.py @@ -1,12 +1,25 @@ from collections import OrderedDict import pytest +import torch import torch.nn as nn +from torch.optim import Adagrad + +from clinicadl.optim.optimizer import ( + ImplementedOptimizer, + create_optimizer_config, + get_optimizer, +) 
+from clinicadl.optim.optimizer.factory import ( + _get_params_in_group, + _get_params_not_in_group, + _regroup_args, +) @pytest.fixture def network(): - network = nn.Sequential( + net = nn.Sequential( OrderedDict( [ ("conv1", nn.Conv2d(1, 1, kernel_size=3)), @@ -14,31 +27,22 @@ def network(): ] ) ) - network.add_module( + net.add_module( "final", nn.Sequential( OrderedDict([("dense2", nn.Linear(10, 5)), ("dense3", nn.Linear(5, 3))]) ), ) - return network + return net def test_get_optimizer(network): - from torch.optim import Adagrad - - from clinicadl.optim.optimizer import ( - ImplementedOptimizer, - OptimizerConfig, - get_optimizer, - ) - - for optimizer in [e.value for e in ImplementedOptimizer]: - config = OptimizerConfig(optimizer=optimizer) + for optimizer in ImplementedOptimizer: + config = create_optimizer_config(optimizer=optimizer)() optimizer, _ = get_optimizer(network, config) assert len(optimizer.param_groups) == 1 - config = OptimizerConfig( - optimizer="Adagrad", + config = create_optimizer_config("Adagrad")( lr=1e-5, weight_decay={"final.dense3.weight": 1.0, "dense1": 0.1}, lr_decay={"dense1": 10, "ELSE": 100}, @@ -82,28 +86,23 @@ def test_get_optimizer(network): assert not updated_config.maximize assert not updated_config.differentiable - # special case : only ELSE - config = OptimizerConfig( - optimizer="Adagrad", + # special case 1: only ELSE + config = create_optimizer_config("Adagrad")( lr_decay={"ELSE": 100}, ) optimizer, _ = get_optimizer(network, config) assert len(optimizer.param_groups) == 1 assert optimizer.param_groups[0]["lr_decay"] == 100 - # special case : the params mentioned form all the network - config = OptimizerConfig( - optimizer="Adagrad", + # special case 2: the mentioned groups cover the whole network + config = create_optimizer_config("Adagrad")( lr_decay={"conv1": 100, "dense1": 10, "final": 1}, ) optimizer, _ = get_optimizer(network, config) assert len(optimizer.param_groups) == 3 # special case : no ELSE mentioned - config = OptimizerConfig( - optimizer="Adagrad", -
lr_decay={"conv1": 100}, - ) + config = create_optimizer_config("Adagrad")(lr_decay={"conv1": 100}) optimizer, _ = get_optimizer(network, config) assert len(optimizer.param_groups) == 2 assert optimizer.param_groups[0]["lr_decay"] == 100 @@ -111,8 +110,6 @@ def test_get_optimizer(network): def test_regroup_args(): - from clinicadl.optim.optimizer.factory import _regroup_args - args = { "weight_decay": {"params_0": 0.0, "params_1": 1.0}, "alpha": {"params_1": 0.5, "ELSE": 0.1}, @@ -134,3 +131,48 @@ def test_regroup_args(): {"weight_decay": {"params_0": 0.0, "params_1": 1.0}} ) assert len(args_global) == 0 + + +def test_get_params_in_block(network): + generator, list_layers = _get_params_in_group(network, "dense1") + assert next(iter(generator)).shape == torch.Size((10, 10)) + assert next(iter(generator)).shape == torch.Size((10,)) + assert sorted(list_layers) == sorted(["dense1.weight", "dense1.bias"]) + + generator, list_layers = _get_params_in_group(network, "dense1.weight") + assert next(iter(generator)).shape == torch.Size((10, 10)) + assert sum(1 for _ in generator) == 0 + assert sorted(list_layers) == sorted(["dense1.weight"]) + + generator, list_layers = _get_params_in_group(network, "final.dense3") + assert next(iter(generator)).shape == torch.Size((3, 5)) + assert next(iter(generator)).shape == torch.Size((3,)) + assert sorted(list_layers) == sorted(["final.dense3.weight", "final.dense3.bias"]) + + generator, list_layers = _get_params_in_group(network, "final") + assert sum(1 for _ in generator) == 4 + assert sorted(list_layers) == sorted( + [ + "final.dense2.weight", + "final.dense2.bias", + "final.dense3.weight", + "final.dense3.bias", + ] + ) + + +def test_find_params_not_in_group(network): + params = _get_params_not_in_group( + network, + [ + "final.dense2.weight", + "final.dense2.bias", + "conv1.bias", + "final.dense3.weight", + "dense1.weight", + "dense1.bias", + ], + ) + assert next(iter(params)).shape == torch.Size((1, 1, 3, 3)) + assert 
next(iter(params)).shape == torch.Size((3,)) + assert sum(1 for _ in params) == 0 # no more params