Skip to content

Commit

Permalink
Merge branch 'refactoring' into metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
thibaultdvx committed Sep 24, 2024
2 parents da46988 + 048ca39 commit 23f0353
Show file tree
Hide file tree
Showing 19 changed files with 1,078 additions and 339 deletions.
3 changes: 2 additions & 1 deletion clinicadl/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .config import ClassificationLoss, ImplementedLoss, LossConfig
from .config import create_loss_config
from .enum import ClassificationLoss, ImplementedLoss
from .factory import get_loss_function
278 changes: 201 additions & 77 deletions clinicadl/losses/config.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,87 @@
from enum import Enum
from typing import List, Optional, Union
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Type, Union

from pydantic import (
BaseModel,
ConfigDict,
NonNegativeFloat,
PositiveFloat,
computed_field,
field_validator,
model_validator,
)

from clinicadl.utils.enum import BaseEnum
from clinicadl.utils.factories import DefaultFromLibrary

from .enum import ImplementedLoss, Order, Reduction

class ClassificationLoss(str, BaseEnum):
"""Losses that can be used only for classification."""
__all__ = [
"LossConfig",
"NLLLossConfig",
"CrossEntropyLossConfig",
"BCELossConfig",
"BCEWithLogitsLossConfig",
"MultiMarginLossConfig",
"KLDivLossConfig",
"HuberLossConfig",
"SmoothL1LossConfig",
"L1LossConfig",
"MSELossConfig",
"create_loss_config",
]

CROSS_ENTROPY = "CrossEntropyLoss" # for multi-class classification, inputs are unormalized logits and targets are int (same dimension without the class channel)
MULTI_MARGIN = "MultiMarginLoss" # no particular restriction on the input, targets are int (same dimension without th class channel)
BCE = "BCELoss" # for binary classification, targets and inputs should be probabilities and have same shape
BCE_LOGITS = "BCEWithLogitsLoss" # for binary classification, targets should be probabilities and inputs logits, and have the same shape. More stable numerically

class LossConfig(BaseModel, ABC):
"""Base config class for the loss function."""

class ImplementedLoss(str, Enum):
"""Implemented losses in ClinicaDL."""

CROSS_ENTROPY = "CrossEntropyLoss"
MULTI_MARGIN = "MultiMarginLoss"
BCE = "BCELoss"
BCE_LOGITS = "BCEWithLogitsLoss"
L1 = "L1Loss"
MSE = "MSELoss"
HUBER = "HuberLoss"
SMOOTH_L1 = "SmoothL1Loss"
KLDIV = "KLDivLoss" # if log_target=False, target must be positive

@classmethod
def _missing_(cls, value):
raise ValueError(
f"{value} is not implemented. Implemented losses are: "
+ ", ".join([repr(m.value) for m in cls])
)
reduction: Union[Reduction, DefaultFromLibrary] = DefaultFromLibrary.YES
weight: Union[
Optional[List[NonNegativeFloat]], DefaultFromLibrary
] = DefaultFromLibrary.YES
# pydantic config
model_config = ConfigDict(
validate_assignment=True, use_enum_values=True, validate_default=True
)

@computed_field
@property
@abstractmethod
def loss(self) -> ImplementedLoss:
"""ImplementedLoss.e name of the loss."""

class Reduction(str, Enum):
"""Supported reduction method in ClinicaDL."""

MEAN = "mean"
SUM = "sum"
class NLLLossConfig(LossConfig):
"""Config class for Negative Log Likelihood loss."""

ignore_index: Union[int, DefaultFromLibrary] = DefaultFromLibrary.YES

class Order(int, Enum):
"""Supported order of L-norm for MultiMarginLoss."""
@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.NLL

ONE = 1
TWO = 2
@field_validator("ignore_index")
@classmethod
def validator_ignore_index(cls, v):
if isinstance(v, int):
assert (
v == -100 or 0 <= v
), "ignore_index must be a positive int (or -100 when disabled)."
return v


class LossConfig(BaseModel):
"""Config class to configure the loss function."""
class CrossEntropyLossConfig(NLLLossConfig):
"""Config class for Cross Entropy loss."""

loss: ImplementedLoss = ImplementedLoss.MSE
reduction: Union[Reduction, DefaultFromLibrary] = DefaultFromLibrary.YES
delta: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
beta: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
p: Union[Order, DefaultFromLibrary] = DefaultFromLibrary.YES
margin: Union[float, DefaultFromLibrary] = DefaultFromLibrary.YES
weight: Union[
Optional[List[NonNegativeFloat]], DefaultFromLibrary
] = DefaultFromLibrary.YES # a weight for each class
ignore_index: Union[int, DefaultFromLibrary] = DefaultFromLibrary.YES
label_smoothing: Union[
NonNegativeFloat, DefaultFromLibrary
] = DefaultFromLibrary.YES
log_target: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
pos_weight: Union[
Optional[List[NonNegativeFloat]], DefaultFromLibrary
] = DefaultFromLibrary.YES # a positive weight for each class
# pydantic config
model_config = ConfigDict(
validate_assignment=True, use_enum_values=True, validate_default=True
)

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.CROSS_ENTROPY

@field_validator("label_smoothing")
@classmethod
Expand All @@ -92,26 +92,150 @@ def validator_label_smoothing(cls, v):
), f"label_smoothing must be between 0 and 1 but it has been set to {v}."
return v

@field_validator("ignore_index")

class BCELossConfig(LossConfig):
"""Config class for Binary Cross Entropy loss."""

weight: Optional[List[NonNegativeFloat]] = None

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.BCE

@field_validator("weight")
@classmethod
def validator_ignore_index(cls, v):
if isinstance(v, int):
assert (
v == -100 or 0 <= v
), "ignore_index must be a positive int (or -100 when disabled)."
def validator_weight(cls, v):
if v is not None:
raise ValueError(
"Cannot use weight with BCEWithLogitsLoss. If you want more flexibility, please use API mode."
)
return v

@model_validator(mode="after")
def model_validator(self):
if (
self.loss == ImplementedLoss.BCE_LOGITS
and self.weight is not None
and self.weight != DefaultFromLibrary.YES
):
raise ValueError("Cannot use weight with BCEWithLogitsLoss.")
elif (
self.loss == ImplementedLoss.BCE
and self.weight is not None
and self.weight != DefaultFromLibrary.YES
):
raise ValueError("Cannot use weight with BCELoss.")

class BCEWithLogitsLossConfig(BCELossConfig):
"""Config class for Binary Cross Entropy With Logits loss."""

pos_weight: Union[Optional[List[Any]], DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.BCE_LOGITS

@field_validator("pos_weight")
@classmethod
def validator_pos_weight(cls, v):
if isinstance(v, list):
check = cls._recursive_float_check(v)
if not check:
raise ValueError(
f"elements in pos_weight must be non-negative float, got: {v}"
)
return v

@classmethod
def _recursive_float_check(cls, item):
if isinstance(item, list):
return all(cls._recursive_float_check(i) for i in item)
else:
return (isinstance(item, float) or isinstance(item, int)) and item >= 0


class MultiMarginLossConfig(LossConfig):
"""Config class for Multi Margin loss."""

p: Union[Order, DefaultFromLibrary] = DefaultFromLibrary.YES
margin: Union[float, DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.MULTI_MARGIN


class KLDivLossConfig(LossConfig):
"""Config class for Kullback-Leibler Divergence loss."""

log_target: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.KLDIV


class HuberLossConfig(LossConfig):
"""Config class for Huber loss."""

delta: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.HUBER


class SmoothL1LossConfig(LossConfig):
"""Config class for Smooth L1 loss."""

beta: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.SMOOTH_L1


class L1LossConfig(LossConfig):
"""Config class for L1 loss."""

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.L1


class MSELossConfig(LossConfig):
"""Config class for Mean Squared Error loss."""

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.MSE


def create_loss_config(
loss: Union[str, ImplementedLoss],
) -> Type[LossConfig]:
"""
A factory function to create a config class suited for the loss.
Parameters
----------
loss : Union[str, ImplementedLoss]
The name of the loss.
Returns
-------
Type[LossConfig]
The config class.
Raises
------
ValueError
If `loss` is not supported.
"""
loss = ImplementedLoss(loss)
config_name = "".join([loss, "Config"])
config = globals()[config_name]

return config
50 changes: 50 additions & 0 deletions clinicadl/losses/enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from enum import Enum

from clinicadl.utils.enum import BaseEnum


class ClassificationLoss(str, BaseEnum):
"""Losses that can be used only for classification."""

CROSS_ENTROPY = "CrossEntropyLoss" # for multi-class classification, inputs are unormalized logits and targets are int (same dimension without the class channel)
NLL = "NLLLoss" # for multi-class classification, inputs are log-probabilities and targets are int (same dimension without the class channel)
MULTI_MARGIN = "MultiMarginLoss" # no particular restriction on the input, targets are int (same dimension without th class channel)
BCE = "BCELoss" # for binary classification, targets and inputs should be probabilities and have same shape
BCE_LOGITS = "BCEWithLogitsLoss" # for binary classification, targets should be probabilities and inputs logits, and have the same shape. More stable numerically


class ImplementedLoss(str, Enum):
"""Implemented losses in ClinicaDL."""

CROSS_ENTROPY = "CrossEntropyLoss"
NLL = "NLLLoss"
MULTI_MARGIN = "MultiMarginLoss"
BCE = "BCELoss"
BCE_LOGITS = "BCEWithLogitsLoss"

L1 = "L1Loss"
MSE = "MSELoss"
HUBER = "HuberLoss"
SMOOTH_L1 = "SmoothL1Loss"
KLDIV = "KLDivLoss" # if log_target=False, target must be positive

@classmethod
def _missing_(cls, value):
raise ValueError(
f"{value} is not implemented. Implemented losses are: "
+ ", ".join([repr(m.value) for m in cls])
)


class Reduction(str, Enum):
"""Supported reduction method in ClinicaDL."""

MEAN = "mean"
SUM = "sum"


class Order(int, Enum):
"""Supported order of L-norm for MultiMarginLoss."""

ONE = 1
TWO = 2
2 changes: 1 addition & 1 deletion clinicadl/losses/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ def get_loss_function(config: LossConfig) -> Tuple[torch.nn.Module, LossConfig]:
config_dict_["pos_weight"] = torch.Tensor(config_dict_["pos_weight"])
loss = loss_class(**config_dict_)

updated_config = LossConfig(loss=config.loss, **config_dict)
updated_config = config.model_copy(update=config_dict)

return loss, updated_config
3 changes: 3 additions & 0 deletions clinicadl/optim/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from .config import OptimizationConfig
from .early_stopping import EarlyStopping
from .lr_scheduler import create_lr_scheduler_config, get_lr_scheduler
from .optimizer import create_optimizer_config, get_optimizer
3 changes: 2 additions & 1 deletion clinicadl/optim/lr_scheduler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .config import ImplementedLRScheduler, LRSchedulerConfig
from .config import create_lr_scheduler_config
from .enum import ImplementedLRScheduler
from .factory import get_lr_scheduler
Loading

0 comments on commit 23f0353

Please sign in to comment.