diff --git a/configs/vision/radiology/offline/segmentation/lits.yaml b/configs/vision/radiology/offline/segmentation/lits.yaml index fa7fe344..222bd4d9 100644 --- a/configs/vision/radiology/offline/segmentation/lits.yaml +++ b/configs/vision/radiology/offline/segmentation/lits.yaml @@ -57,10 +57,9 @@ model: in_features: ${oc.env:IN_FEATURES, 384} num_classes: &NUM_CLASSES 3 criterion: - class_path: eva.vision.losses.DiceLoss + class_path: eva.core.losses.CrossEntropyLoss init_args: - softmax: true - batch: true + weight: [0.01, 0.1, 1.5] optimizer: class_path: torch.optim.AdamW init_args: diff --git a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml index 236cd70e..44354672 100644 --- a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml +++ b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml @@ -57,10 +57,9 @@ model: in_features: ${oc.env:IN_FEATURES, 384} num_classes: &NUM_CLASSES 3 criterion: - class_path: eva.vision.losses.DiceLoss + class_path: eva.core.losses.CrossEntropyLoss init_args: - softmax: true - batch: true + weight: [0.05, 0.1, 1.5] optimizer: class_path: torch.optim.AdamW init_args: diff --git a/configs/vision/radiology/online/segmentation/lits.yaml b/configs/vision/radiology/online/segmentation/lits.yaml index 2bbcb36f..82b40b95 100644 --- a/configs/vision/radiology/online/segmentation/lits.yaml +++ b/configs/vision/radiology/online/segmentation/lits.yaml @@ -49,10 +49,9 @@ model: in_features: ${oc.env:IN_FEATURES, 384} num_classes: &NUM_CLASSES 3 criterion: - class_path: eva.vision.losses.DiceLoss + class_path: eva.core.losses.CrossEntropyLoss init_args: - softmax: true - batch: true + weight: [0.01, 0.1, 1.5] lr_multiplier_encoder: 0.0 optimizer: class_path: torch.optim.AdamW diff --git a/configs/vision/radiology/online/segmentation/lits_balanced.yaml b/configs/vision/radiology/online/segmentation/lits_balanced.yaml index 85ae15a7..1858922c 100644 --- a/configs/vision/radiology/online/segmentation/lits_balanced.yaml +++ b/configs/vision/radiology/online/segmentation/lits_balanced.yaml @@ -49,10 +49,9 @@ model: in_features: ${oc.env:IN_FEATURES, 384} num_classes: &NUM_CLASSES 3 criterion: - class_path: eva.vision.losses.DiceLoss + class_path: eva.core.losses.CrossEntropyLoss init_args: - softmax: true - batch: true + weight: [0.05, 0.1, 1.5] lr_multiplier_encoder: 0.0 optimizer: class_path: torch.optim.AdamW diff --git a/src/eva/core/losses/__init__.py b/src/eva/core/losses/__init__.py new file mode 100644 index 00000000..1ea65fdd --- /dev/null +++ b/src/eva/core/losses/__init__.py @@ -0,0 +1,5 @@ +"""Loss functions API.""" + +from eva.core.losses.cross_entropy import CrossEntropyLoss + +__all__ = ["CrossEntropyLoss"] diff --git a/src/eva/core/losses/cross_entropy.py b/src/eva/core/losses/cross_entropy.py new file mode 100644 index 00000000..0101ad0d --- /dev/null +++ b/src/eva/core/losses/cross_entropy.py @@ -0,0 +1,27 @@ +"""Cross-entropy based loss function.""" + +from typing import Sequence + +import torch +from torch import nn + + +class CrossEntropyLoss(nn.CrossEntropyLoss): + """A wrapper around torch.nn.CrossEntropyLoss that accepts weights in list format. + + Needed for .yaml file loading & class instantiation with jsonarparse. + """ + + def __init__( + self, *args, weight: Sequence[float] | torch.Tensor | None = None, **kwargs + ) -> None: + """Initialize the loss function. + + Args: + args: Positional arguments from the base class. + weight: A list of weights to assign to each class. + kwargs: Key-word arguments from the base class. + """ + if weight is not None and not isinstance(weight, torch.Tensor): + weight = torch.tensor(weight) + super().__init__(*args, **kwargs, weight=weight) diff --git a/src/eva/vision/losses/__init__.py b/src/eva/vision/losses/__init__.py index ff791819..ff8afbc0 100644 --- a/src/eva/vision/losses/__init__.py +++ b/src/eva/vision/losses/__init__.py @@ -1,5 +1,5 @@ """Loss functions API.""" -from eva.vision.losses.dice import DiceLoss +from eva.vision.losses.dice import DiceCELoss, DiceLoss -__all__ = ["DiceLoss"] +__all__ = ["DiceLoss", "DiceCELoss"] diff --git a/src/eva/vision/losses/dice.py b/src/eva/vision/losses/dice.py index 79416040..8e6133b3 100644 --- a/src/eva/vision/losses/dice.py +++ b/src/eva/vision/losses/dice.py @@ -1,4 +1,6 @@ -"""Dice loss.""" +"""Dice based loss functions.""" + +from typing import Sequence, Tuple import torch from monai import losses @@ -12,29 +14,94 @@ class DiceLoss(losses.DiceLoss): # type: ignore Extends the implementation from MONAI - to support semantic target labels (meaning targets of shape BHW) - to support `ignore_index` functionality + - accept weight argument in list format """ - def __init__(self, *args, ignore_index: int | None = None, **kwargs) -> None: - """Initialize the DiceLoss with support for ignore_index. + def __init__( + self, + *args, + ignore_index: int | None = None, + weight: Sequence[float] | torch.Tensor | None = None, + **kwargs, + ) -> None: + """Initialize the DiceLoss. Args: args: Positional arguments from the base class. ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. + weight: A list of weights to assign to each class. kwargs: Key-word arguments from the base class. """ - super().__init__(*args, **kwargs) + if weight is not None and not isinstance(weight, torch.Tensor): + weight = torch.tensor(weight) + + super().__init__(*args, **kwargs, weight=weight) self.ignore_index = ignore_index @override def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: # noqa - if self.ignore_index is not None: - mask = targets != self.ignore_index - targets = targets * mask - inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask) + inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index) + targets = _to_one_hot(targets, num_classes=inputs.shape[1]) if targets.ndim == 3: targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1]) return super().forward(inputs, targets) + + +class DiceCELoss(losses.dice.DiceCELoss): + """Combination of Dice and Cross Entropy Loss. + + Extends the implementation from MONAI + - to support semantic target labels (meaning targets of shape BHW) + - to support `ignore_index` functionality + - accept weight argument in list format + """ + + def __init__( + self, + *args, + ignore_index: int | None = None, + weight: Sequence[float] | torch.Tensor | None = None, + **kwargs, + ) -> None: + """Initialize the DiceCELoss. + + Args: + args: Positional arguments from the base class. + ignore_index: Specifies a target value that is ignored and + does not contribute to the input gradient. + weight: A list of weights to assign to each class. + kwargs: Key-word arguments from the base class. + """ + if weight is not None and not isinstance(weight, torch.Tensor): + weight = torch.tensor(weight) + + super().__init__(*args, **kwargs, weight=weight) + + self.ignore_index = ignore_index + + @override + def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: # noqa + inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index) + targets = _to_one_hot(targets, num_classes=inputs.shape[1]) + + return super().forward(inputs, targets) + + +def _apply_ignore_index( + inputs: torch.Tensor, targets: torch.Tensor, ignore_index: int | None +) -> Tuple[torch.Tensor, torch.Tensor]: + if ignore_index is not None: + mask = targets != ignore_index + targets = targets * mask + inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask) + return inputs, targets + + +def _to_one_hot(tensor: torch.Tensor, num_classes: int) -> torch.Tensor: + if tensor.ndim == 3: + return one_hot(tensor[:, None, ...], num_classes=num_classes) + return tensor