diff --git a/configs/vision/radiology/offline/segmentation/lits.yaml b/configs/vision/radiology/offline/segmentation/lits.yaml
index fa7fe344..d9e0c490 100644
--- a/configs/vision/radiology/offline/segmentation/lits.yaml
+++ b/configs/vision/radiology/offline/segmentation/lits.yaml
@@ -12,7 +12,7 @@ trainer:
           refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1}
       - class_path: eva.vision.callbacks.SemanticSegmentationLogger
         init_args:
-          log_every_n_steps: 1000
+          log_every_n_epochs: 1
           log_images: false
       - class_path: lightning.pytorch.callbacks.ModelCheckpoint
         init_args:
@@ -57,10 +57,9 @@ model:
         in_features: ${oc.env:IN_FEATURES, 384}
         num_classes: &NUM_CLASSES 3
     criterion:
-      class_path: eva.vision.losses.DiceLoss
+      class_path: eva.core.losses.CrossEntropyLoss
       init_args:
-        softmax: true
-        batch: true
+        weight: [0.05, 0.1, 1.5]
     optimizer:
       class_path: torch.optim.AdamW
       init_args:
@@ -119,6 +118,7 @@ data:
           class_path: eva.vision.data.transforms.common.ResizeAndClamp
           init_args:
             size: ${oc.env:RESIZE_DIM, 224}
+            clamp_range: [-1008, 822]
             mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
             std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
       - class_path: eva.vision.datasets.LiTS
@@ -137,6 +137,7 @@ data:
     val:
       batch_size: *BATCH_SIZE
       num_workers: *N_DATA_WORKERS
+      shuffle: true
     test:
      batch_size: *BATCH_SIZE
       num_workers: *N_DATA_WORKERS
diff --git a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml
index 236cd70e..a0059e34 100644
--- a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml
+++ b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml
@@ -3,7 +3,7 @@ trainer:
   class_path: eva.Trainer
   init_args:
     n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
-    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits}
+    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced}
     max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000}
     callbacks:
       - class_path: eva.callbacks.ConfigurationLogger
@@ -29,7 +29,7 @@ trainer:
           mode: *MONITOR_METRIC_MODE
       - class_path: eva.callbacks.SegmentationEmbeddingsWriter
         init_args:
-          output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits
+          output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced
           dataloader_idx_map:
             0: train
             1: val
@@ -57,10 +57,9 @@ model:
         in_features: ${oc.env:IN_FEATURES, 384}
         num_classes: &NUM_CLASSES 3
     criterion:
-      class_path: eva.vision.losses.DiceLoss
+      class_path: eva.core.losses.CrossEntropyLoss
       init_args:
-        softmax: true
-        batch: true
+        weight: [0.05, 0.1, 1.5]
     optimizer:
       class_path: torch.optim.AdamW
       init_args:
@@ -119,6 +118,7 @@ data:
           class_path: eva.vision.data.transforms.common.ResizeAndClamp
           init_args:
             size: ${oc.env:RESIZE_DIM, 224}
+            clamp_range: [-1008, 822]
             mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
             std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
       - class_path: eva.vision.datasets.LiTSBalanced
diff --git a/configs/vision/radiology/online/segmentation/lits.yaml b/configs/vision/radiology/online/segmentation/lits.yaml
index 2bbcb36f..3d8d2fc5 100644
--- a/configs/vision/radiology/online/segmentation/lits.yaml
+++ b/configs/vision/radiology/online/segmentation/lits.yaml
@@ -49,10 +49,9 @@ model:
         in_features: ${oc.env:IN_FEATURES, 384}
         num_classes: &NUM_CLASSES 3
     criterion:
-      class_path: eva.vision.losses.DiceLoss
+      class_path: eva.core.losses.CrossEntropyLoss
       init_args:
-        softmax: true
-        batch: true
+        weight: [0.05, 0.1, 1.5]
     lr_multiplier_encoder: 0.0
     optimizer:
       class_path: torch.optim.AdamW
@@ -96,6 +95,7 @@ data:
               class_path: eva.vision.data.transforms.common.ResizeAndClamp
               init_args:
                 size: ${oc.env:RESIZE_DIM, 224}
+                clamp_range: [-1008, 822]
                 mean: *NORMALIZE_MEAN
                 std: *NORMALIZE_STD
         val:
diff --git a/configs/vision/radiology/online/segmentation/lits_balanced.yaml b/configs/vision/radiology/online/segmentation/lits_balanced.yaml
index 85ae15a7..cff4c88e 100644
--- a/configs/vision/radiology/online/segmentation/lits_balanced.yaml
+++ b/configs/vision/radiology/online/segmentation/lits_balanced.yaml
@@ -3,7 +3,7 @@ trainer:
   class_path: eva.Trainer
   init_args:
     n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
-    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits}
+    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced}
     max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000}
     log_every_n_steps: 6
     callbacks:
@@ -49,10 +49,9 @@ model:
         in_features: ${oc.env:IN_FEATURES, 384}
         num_classes: &NUM_CLASSES 3
     criterion:
-      class_path: eva.vision.losses.DiceLoss
+      class_path: eva.core.losses.CrossEntropyLoss
      init_args:
-        softmax: true
-        batch: true
+        weight: [0.05, 0.1, 1.5]
     lr_multiplier_encoder: 0.0
     optimizer:
       class_path: torch.optim.AdamW
@@ -96,6 +95,7 @@ data:
               class_path: eva.vision.data.transforms.common.ResizeAndClamp
               init_args:
                 size: ${oc.env:RESIZE_DIM, 224}
+                clamp_range: [-1008, 822]
                 mean: *NORMALIZE_MEAN
                 std: *NORMALIZE_STD
         val:
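
Note: all four LiTS configs swap the Dice criterion for a class-weighted cross-entropy. The weights [0.05, 0.1, 1.5] down-weight the dominant background class and up-weight the rare tumor class. A minimal sketch of what this criterion computes, assuming the usual LiTS label order (0 = background, 1 = liver, 2 = tumor; that mapping is an assumption, not stated in this patch):

import torch
from torch import nn

# Per-class weights from the configs above: errors on the tumor class
# cost 30x as much as errors on the background class (1.5 / 0.05).
weight = torch.tensor([0.05, 0.1, 1.5])
criterion = nn.CrossEntropyLoss(weight=weight)

logits = torch.randn(4, 3, 224, 224)           # B, C, H, W class scores
targets = torch.randint(0, 3, (4, 224, 224))   # B, H, W semantic labels
loss = criterion(logits, targets)
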
diff --git a/src/eva/core/losses/__init__.py b/src/eva/core/losses/__init__.py
new file mode 100644
index 00000000..1ea65fdd
--- /dev/null
+++ b/src/eva/core/losses/__init__.py
@@ -0,0 +1,5 @@
+"""Loss functions API."""
+
+from eva.core.losses.cross_entropy import CrossEntropyLoss
+
+__all__ = ["CrossEntropyLoss"]
diff --git a/src/eva/core/losses/cross_entropy.py b/src/eva/core/losses/cross_entropy.py
new file mode 100644
index 00000000..0101ad0d
--- /dev/null
+++ b/src/eva/core/losses/cross_entropy.py
@@ -0,0 +1,27 @@
+"""Cross-entropy based loss function."""
+
+from typing import Sequence
+
+import torch
+from torch import nn
+
+
+class CrossEntropyLoss(nn.CrossEntropyLoss):
+    """A wrapper around torch.nn.CrossEntropyLoss that accepts weights in list format.
+
+    Needed for .yaml file loading & class instantiation with jsonargparse.
+    """
+
+    def __init__(
+        self, *args, weight: Sequence[float] | torch.Tensor | None = None, **kwargs
+    ) -> None:
+        """Initialize the loss function.
+
+        Args:
+            args: Positional arguments from the base class.
+            weight: A list of weights to assign to each class.
+            kwargs: Key-word arguments from the base class.
+        """
+        if weight is not None and not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight)
+        super().__init__(*args, **kwargs, weight=weight)
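
Note: the wrapper exists because jsonargparse instantiates the criterion directly from the YAML `init_args`, so `weight` arrives as a plain Python list, which `torch.nn.CrossEntropyLoss` rejects (it expects a Tensor or None). A short sketch of the round trip:

import torch

from eva.core.losses import CrossEntropyLoss

# jsonargparse passes init_args through verbatim, so `weight` is a list here;
# the wrapper converts it to a tensor before calling the base constructor.
criterion = CrossEntropyLoss(weight=[0.05, 0.1, 1.5])
assert isinstance(criterion.weight, torch.Tensor)
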
diff --git a/src/eva/vision/losses/__init__.py b/src/eva/vision/losses/__init__.py
index ff791819..ff8afbc0 100644
--- a/src/eva/vision/losses/__init__.py
+++ b/src/eva/vision/losses/__init__.py
@@ -1,5 +1,5 @@
 """Loss functions API."""
 
-from eva.vision.losses.dice import DiceLoss
+from eva.vision.losses.dice import DiceCELoss, DiceLoss
 
-__all__ = ["DiceLoss"]
+__all__ = ["DiceLoss", "DiceCELoss"]
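
Note: both losses in dice.py below rely on the same BHW -> BCHW one-hot conversion, now factored into the `_to_one_hot` helper. A quick sketch of the shape behavior using MONAI's `one_hot` (toy tensor for illustration, not from this patch):

import torch
from monai.networks import one_hot

targets = torch.tensor([[[0, 2], [1, 0]]])               # (1, 2, 2): B, H, W
encoded = one_hot(targets[:, None, ...], num_classes=3)  # add channel dim first
print(encoded.shape)                                     # torch.Size([1, 3, 2, 2])
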
diff --git a/src/eva/vision/losses/dice.py b/src/eva/vision/losses/dice.py
index 79416040..8e6133b3 100644
--- a/src/eva/vision/losses/dice.py
+++ b/src/eva/vision/losses/dice.py
@@ -1,4 +1,6 @@
-"""Dice loss."""
+"""Dice based loss functions."""
+
+from typing import Sequence, Tuple
 
 import torch
 from monai import losses
@@ -12,29 +14,94 @@ class DiceLoss(losses.DiceLoss):  # type: ignore
     Extends the implementation from MONAI
     - to support semantic target labels (meaning targets of shape BHW)
     - to support `ignore_index` functionality
+    - accept weight argument in list format
     """
 
-    def __init__(self, *args, ignore_index: int | None = None, **kwargs) -> None:
-        """Initialize the DiceLoss with support for ignore_index.
+    def __init__(
+        self,
+        *args,
+        ignore_index: int | None = None,
+        weight: Sequence[float] | torch.Tensor | None = None,
+        **kwargs,
+    ) -> None:
+        """Initialize the DiceLoss.
 
         Args:
             args: Positional arguments from the base class.
             ignore_index: Specifies a target value that is ignored and
                 does not contribute to the input gradient.
+            weight: A list of weights to assign to each class.
             kwargs: Key-word arguments from the base class.
         """
-        super().__init__(*args, **kwargs)
+        if weight is not None and not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight)
+
+        super().__init__(*args, **kwargs, weight=weight)
 
         self.ignore_index = ignore_index
 
     @override
     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
-        if self.ignore_index is not None:
-            mask = targets != self.ignore_index
-            targets = targets * mask
-            inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
-
-        if targets.ndim == 3:
-            targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1])
+        inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
+        targets = _to_one_hot(targets, num_classes=inputs.shape[1])
 
         return super().forward(inputs, targets)
+
+
+class DiceCELoss(losses.dice.DiceCELoss):
+    """Combination of Dice and Cross Entropy Loss.
+
+    Extends the implementation from MONAI
+    - to support semantic target labels (meaning targets of shape BHW)
+    - to support `ignore_index` functionality
+    - accept weight argument in list format
+    """
+
+    def __init__(
+        self,
+        *args,
+        ignore_index: int | None = None,
+        weight: Sequence[float] | torch.Tensor | None = None,
+        **kwargs,
+    ) -> None:
+        """Initialize the DiceCELoss.
+
+        Args:
+            args: Positional arguments from the base class.
+            ignore_index: Specifies a target value that is ignored and
+                does not contribute to the input gradient.
+            weight: A list of weights to assign to each class.
+            kwargs: Key-word arguments from the base class.
+        """
+        if weight is not None and not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight)
+
+        super().__init__(*args, **kwargs, weight=weight)
+
+        self.ignore_index = ignore_index
+
+    @override
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
+        inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
+        targets = _to_one_hot(targets, num_classes=inputs.shape[1])
+
+        return super().forward(inputs, targets)
+
+
+def _apply_ignore_index(
+    inputs: torch.Tensor, targets: torch.Tensor, ignore_index: int | None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if ignore_index is not None:
+        mask = targets != ignore_index
+        targets = targets * mask
+        inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
+    return inputs, targets
+
+
+def _to_one_hot(tensor: torch.Tensor, num_classes: int) -> torch.Tensor:
+    if tensor.ndim == 3:
+        return one_hot(tensor[:, None, ...], num_classes=num_classes)
+    return tensor
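
Note: a usage sketch of the extended `DiceCELoss`, combining the list-format weights with `ignore_index` and semantic (BHW) targets. The weight values follow the configs above; `softmax=True` is MONAI's own flag, and the 255 ignore value is purely illustrative:

import torch

from eva.vision.losses import DiceCELoss

criterion = DiceCELoss(
    softmax=True,                # MONAI flag: apply softmax to the logits
    weight=[0.05, 0.1, 1.5],     # plain list, converted to a tensor internally
    ignore_index=255,            # hypothetical "unlabeled" voxel value
)

logits = torch.randn(2, 3, 224, 224)          # B, C, H, W
targets = torch.randint(0, 3, (2, 224, 224))  # B, H, W semantic labels
loss = criterion(logits, targets)             # one-hot encoded internally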