Merge pull request #66 from aecelaya/main
Add learning rate to train progress bar. Minor performance improvements.
aecelaya authored Dec 3, 2024
2 parents dd192b5 + e5dab93 commit 275d5c6
Showing 13 changed files with 322 additions and 175 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,4 +1,4 @@
-FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime
+FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime

# Set environment variables for non-interactive installation.
ENV DEBIAN_FRONTEND=noninteractive
2 changes: 1 addition & 1 deletion mist/main.py
@@ -144,7 +144,7 @@ def main(arguments: argparse.Namespace) -> None:
    # Check that the specified fold indices are compatible with the
    # number of folds.
if (
-        np.max(mist_arguments.folds) + 1 < mist_arguments.nfolds or
+        np.max(mist_arguments.folds) + 1 > mist_arguments.nfolds or
len(mist_arguments.folds) > mist_arguments.nfolds
):
raise AssertionError(
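The corrected condition rejects fold indices that overflow the requested number of folds. A minimal standalone sketch of the same check (the function name and message are illustrative, not MIST's):

import numpy as np

def check_folds(folds: list, nfolds: int) -> None:
    # Fold index k requires at least k + 1 folds, and the list cannot
    # name more folds than nfolds.
    if np.max(folds) + 1 > nfolds or len(folds) > nfolds:
        raise AssertionError(
            f"Folds {folds} are incompatible with nfolds={nfolds}."
        )

check_folds([0, 1, 4], nfolds=5)   # OK: highest index 4 fits in 5 folds.
# check_folds([0, 5], nfolds=5)    # Raises: index 5 would need nfolds >= 6.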
30 changes: 14 additions & 16 deletions mist/runtime/args.py
@@ -208,6 +208,17 @@ def get_main_args():
type=int,
help="Height, width, and depth of patch size"
)
+    parser.arg(
+        "--validate-every-n-epochs",
+        type=positive_int,
+        help="Validate every n epochs"
+    )
+    parser.arg(
+        "--validate-after-n-epochs",
+        type=int,
+        default=1,
+        help="Start validation after n epochs. If -1, validate only on the last epoch"
+    )
parser.arg(
"--max-patch-size",
default=[256, 256, 256],
@@ -227,12 +238,6 @@ def get_main_args():
default=3e-4,
help="Learning rate"
)
-    parser.arg(
-        "--exp_decay",
-        type=positive_float,
-        default=0.9999,
-        help="Exponential decay factor"
-    )
parser.arg(
"--lr-scheduler",
type=str,
@@ -241,17 +246,10 @@ def get_main_args():
"constant",
"polynomial",
"cosine",
"cosine_warm_restarts",
"exponential"
"cosine-warm-restarts",
],
help="Learning rate scheduler"
)
-    parser.arg(
-        "--cosine-first-steps",
-        type=positive_int,
-        default=500,
-        help="Length of a cosine decay cycle in steps, only with cosine_annealing scheduler"
-    )

# Optimizer parameters.
parser.arg(
@@ -380,9 +378,9 @@ def get_main_args():
parser.arg(
"--loss",
type=str,
default="dice_ce",
default="dice-ce",
choices=[
"dice_ce",
"dice-ce",
"dice",
"bl",
"hdl",
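Taken together, the renamed hyphenated options parse as in this small sketch using plain argparse (MIST's parser.arg wrapper is assumed to delegate to add_argument; the defaults here are illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--validate-every-n-epochs", type=int, default=1)
parser.add_argument("--validate-after-n-epochs", type=int, default=1)
parser.add_argument(
    "--lr-scheduler",
    type=str,
    default="constant",
    choices=["constant", "polynomial", "cosine", "cosine-warm-restarts"],
)
parser.add_argument("--loss", type=str, default="dice-ce")

args = parser.parse_args(
    ["--lr-scheduler", "cosine-warm-restarts", "--validate-every-n-epochs", "2"]
)
# Hyphens become underscores on the parsed namespace.
print(args.lr_scheduler, args.validate_every_n_epochs)  # cosine-warm-restarts 2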
21 changes: 20 additions & 1 deletion mist/runtime/exceptions.py
@@ -3,11 +3,30 @@

class InsufficientValidationSetError(Exception):
"""Raised if validation set size is smaller than the number of GPUs."""

+    def __init__(self, val_size: int, world_size: int) -> None:
+        self.message = (
+            f"Validation set size of {val_size} is too small for {world_size} "
+            "GPUs. Please increase the validation set size or reduce the "
+            "number of GPUs."
+        )
+        super().__init__(self.message)


class NaNLossError(Exception):
"""Raised if the loss is NaN."""
+    def __init__(self, epoch) -> None:
+        self.message = (
+            f"Encountered NaN loss value in epoch {epoch}. Stopping training. "
+            "Consider using a different optimizer, reducing the learning rate, "
+            "or using gradient clipping."
+        )
+        super().__init__(self.message)


class NoGPUsAvailableError(Exception):
"""Raised if no GPU is available."""
+    def __init__(self) -> None:
+        self.message = (
+            "No GPU available. Please check your hardware configuration."
+        )
+        super().__init__(self.message)
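A short sketch of how a training loop might surface the new NaNLossError (the guard function is illustrative, not MIST's trainer):

import torch

from mist.runtime.exceptions import NaNLossError

def check_loss(loss: torch.Tensor, epoch: int) -> None:
    # Abort training as soon as the loss becomes NaN.
    if torch.isnan(loss).item():
        raise NaNLossError(epoch)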
82 changes: 80 additions & 2 deletions mist/runtime/loss_functions.py
@@ -1,5 +1,5 @@
"""Loss function implementations for training segmentation models."""
-from typing import Tuple
+from typing import Tuple, Optional, Callable

import argparse
import torch
@@ -9,6 +9,84 @@
from mist.runtime import loss_utils


class DeepSupervisionLoss(nn.Module):
    """Loss function for deep supervision in segmentation tasks.

    This class calculates the loss for the main output and additional deep
    supervision heads using a geometric weighting scheme. Deep supervision
    provides intermediate outputs during training to guide the model's learning
    at multiple stages.

    Attributes:
        loss_fn: The base loss function to apply (e.g., Dice loss).
        scaling_fn: A function to scale the loss for each supervision head.
            Defaults to geometric scaling by 0.5 ** k, where k is the index.
    """
def __init__(
self,
loss_fn: nn.Module,
        scaling_fn: Optional[Callable[[int], float]] = None
):
super().__init__()
self.loss_fn = loss_fn
self.scaling_fn = scaling_fn or (lambda k: 0.5 ** k)

def apply_loss(self, y_true, y_pred, alpha=None, dtm=None):
"""Applies the configured loss function with appropriate arguments."""
if dtm is not None:
return self.loss_fn(y_true, y_pred, dtm, alpha)
elif alpha is not None:
return self.loss_fn(y_true, y_pred, alpha)
return self.loss_fn(y_true, y_pred)

def forward(
self,
y_true: torch.Tensor,
y_pred: torch.Tensor,
y_supervision: Optional[Tuple[torch.Tensor, ...]] = None,
alpha: Optional[float] = None,
dtm: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Computes the total loss, including contributions from deep supervision.
Args:
y_true: Ground truth mask of shape (batch_size, 1, height, width,
depth). This mask is not one-hot encoded because of the we way
we construct the data loader. The one-hot encoding is applied
in the forward pass of the loss function.
y_pred: Predicted main output of shape (batch_size, num_classes,
height, width, depth). This is the main output of the network.
We assume that the predicted mask is the raw output of a network
that has not been passed through a softmax function. We apply
the softmax function in the forward pass of the loss function.
y_supervision (optional): Deep supervision outputs, each of shape
(batch_size, num_classes, height, width, depth). Like y_pred,
these are raw outputs of the network. We apply the softmax
function in the forward pass of the loss function.
alpha (optional): Balances region and boundary losses. This is a
hyperparameter that should be in the interval [0, 1].
dtm (optional): Distance transform maps for boundary-based loss.
Returns:
The total weighted loss.
"""
# Collect main prediction and deep supervision outputs.
_y_pred = [y_pred] + (list(y_supervision) if y_supervision else [])

# Compute weighted loss.
losses = torch.stack(
[
self.scaling_fn(k) * self.apply_loss(y_true, pred, alpha, dtm)
for k, pred in enumerate(_y_pred)
]
)

# Normalize using the sum of the scaling factors.
normalization = sum(self.scaling_fn(k) for k in range(len(_y_pred)))
return losses.sum() / normalization


class DiceLoss(nn.Module):
"""Soft Dice loss function for segmentation tasks.
@@ -638,7 +716,7 @@ def get_loss(args: argparse.Namespace) -> nn.Module:
"""
if args.loss == "dice":
return DiceLoss(exclude_background=args.exclude_background)
if args.loss == "dice_ce":
if args.loss == "dice-ce":
return DiceCELoss(exclude_background=args.exclude_background)
if args.loss == "bl":
return BoundaryLoss(exclude_background=args.exclude_background)
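As a usage sketch, DeepSupervisionLoss wraps any base loss with the (y_true, y_pred) call signature; the toy loss and tensor shapes below are illustrative assumptions, not MIST defaults:

import torch
import torch.nn as nn

from mist.runtime.loss_functions import DeepSupervisionLoss

class ToyLoss(nn.Module):
    # Stand-in base loss with the (y_true, y_pred) signature used above.
    def forward(self, y_true, y_pred):
        return torch.mean((y_pred - y_true.float()) ** 2)

ds_loss = DeepSupervisionLoss(ToyLoss())
y_true = torch.zeros(2, 1, 8, 8, 8)   # (batch, 1, H, W, D) label mask
y_pred = torch.randn(2, 3, 8, 8, 8)   # raw logits from the main head
heads = tuple(torch.randn(2, 3, 8, 8, 8) for _ in range(2))

# Heads are weighted 1, 0.5, 0.25 and the weighted sum is divided by 1.75.
total = ds_loss(y_true, y_pred, y_supervision=heads)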
1 change: 0 additions & 1 deletion mist/runtime/loss_utils.py
@@ -127,4 +127,3 @@ def forward(self, img: torch.Tensor) -> torch.Tensor:
The soft skeleton of the input tensor.
"""
return self.soft_skeletonize(img)

21 changes: 16 additions & 5 deletions mist/runtime/progress_bar.py
@@ -1,11 +1,12 @@
"""Progress bars for MIST training and validation loops."""
from rich.progress import (
BarColumn,
TextColumn,
Progress,
MofNCompleteColumn,
TimeElapsedColumn
)

+import numpy as np

class TrainProgressBar(Progress):
def __init__(self, current_epoch, fold, epochs, train_steps):
@@ -20,16 +21,26 @@ def __init__(self, current_epoch, fold, epochs, train_steps):
TimeElapsedColumn(),
TextColumn("•"),
TextColumn("{task.fields[loss]}"),
TextColumn("•"),
TextColumn("{task.fields[lr]}"), # Learning rate column
)

+        # Initialize tasks with loss and learning rate fields.
        self.task = self.progress.add_task(
-            description="Training",
+            description="Training (loss)",
            total=train_steps,
-            loss=f"loss: "
+            loss=f"loss: ",
+            lr=f"lr: ",
        )

-    def update(self, loss):
-        self.progress.update(self.task, advance=1, loss=f"loss: {loss:.4f}")
+    def update(self, loss, lr):
+        # Update the loss and learning rate fields.
+        self.progress.update(
+            self.task,
+            advance=1,
+            loss=f"loss: {loss:.4f}",
+            lr=f"lr: {np.format_float_scientific(lr, precision=3)}",
+        )

def __enter__(self):
self.progress.start()
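For reference, a small sketch of the learning-rate string the updated bar renders (the value is illustrative):

import numpy as np

lr = 2.97e-4
# precision=3 keeps at most three significant digits, e.g. "2.97e-04".
print(f"lr: {np.format_float_scientific(lr, precision=3)}")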
(Diffs for the remaining six changed files were not loaded.)
