Skip to content

Commit

Permalink
Merge remote-tracking branch 'ancestor-mithril/dev2'
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Dec 5, 2023
2 parents 6239882 + e9f14e7 commit 8212172
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 41 deletions.
7 changes: 3 additions & 4 deletions nnunetv2/training/loss/compound_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,18 @@ def forward(self, net_output: torch.Tensor, target: torch.Tensor):
if self.ignore_label is not None:
assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \
'(DC_and_CE_loss)'
mask = (target != self.ignore_label).bool()
mask = target != self.ignore_label
# remove ignore label from target, replace with one of the known labels. It doesn't matter because we
# ignore gradients in those areas anyway
target_dice = torch.clone(target)
target_dice[target == self.ignore_label] = 0
target_dice = torch.where(mask, target, 0)
num_fg = mask.sum()
else:
target_dice = target
mask = None

dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \
if self.weight_dice != 0 else 0
ce_loss = self.ce(net_output, target[:, 0].long()) \
ce_loss = self.ce(net_output, target[:, 0]) \
if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0

result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
Expand Down
23 changes: 9 additions & 14 deletions nnunetv2/training/loss/deep_supervision.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
from torch import nn


Expand All @@ -11,25 +12,19 @@ def __init__(self, loss, weight_factors=None):
If weights are None, all w will be 1.
"""
super(DeepSupervisionWrapper, self).__init__()
self.weight_factors = weight_factors
assert any([x != 0 for x in weight_factors]), "At least one weight factor should be != 0.0"
self.weight_factors = tuple(weight_factors)
self.loss = loss

def forward(self, *args):
for i in args:
assert isinstance(i, (tuple, list)), f"all args must be either tuple or list, got {type(i)}"
# we could check for equal lengths here as well but we really shouldn't overdo it with checks because
# this code is executed a lot of times!
assert all([isinstance(i, (tuple, list)) for i in args]), \
f"all args must be either tuple or list, got {[type(i) for i in args]}"
# we could check for equal lengths here as well, but we really shouldn't overdo it with checks because
# this code is executed a lot of times!

if self.weight_factors is None:
weights = [1] * len(args[0])
weights = (1, ) * len(args[0])
else:
weights = self.weight_factors

# we initialize the loss like this instead of 0 to ensure it sits on the correct device, not sure if that's
# really necessary
l = weights[0] * self.loss(*[j[0] for j in args])
for i, inputs in enumerate(zip(*args)):
if i == 0:
continue
l += weights[i] * self.loss(*inputs)
return l
return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0])
42 changes: 21 additions & 21 deletions nnunetv2/training/loss/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,18 @@ def forward(self, x, y, loss_mask=None):
x = self.apply_nonlin(x)

# make everything shape (b, c)
axes = list(range(2, len(x.shape)))
axes = tuple(range(2, x.ndim))

with torch.no_grad():
if len(x.shape) != len(y.shape):
if x.ndim != y.ndim:
y = y.view((y.shape[0], 1, *y.shape[1:]))

if x.shape == y.shape:
# if this is the case then gt is probably already a one hot encoding
y_onehot = y
else:
gt = y.long()
y_onehot = torch.zeros(x.shape, device=x.device, dtype=torch.bool)
y_onehot.scatter_(1, gt, 1)
y_onehot.scatter_(1, y.long(), 1)

if not self.do_bg:
y_onehot = y_onehot[:, 1:]
Expand All @@ -96,15 +96,19 @@ def forward(self, x, y, loss_mask=None):
if not self.do_bg:
x = x[:, 1:]

intersect = (x * y_onehot).sum(axes) if loss_mask is None else (x * y_onehot * loss_mask).sum(axes)
sum_pred = x.sum(axes) if loss_mask is None else (x * loss_mask).sum(axes)

if self.ddp and self.batch_dice:
intersect = AllGatherGrad.apply(intersect).sum(0)
sum_pred = AllGatherGrad.apply(sum_pred).sum(0)
sum_gt = AllGatherGrad.apply(sum_gt).sum(0)
if loss_mask is None:
intersect = (x * y_onehot).sum(axes)
sum_pred = x.sum(axes)
else:
intersect = (x * y_onehot * loss_mask).sum(axes)
sum_pred = (x * loss_mask).sum(axes)

if self.batch_dice:
if self.ddp:
intersect = AllGatherGrad.apply(intersect).sum(0)
sum_pred = AllGatherGrad.apply(sum_pred).sum(0)
sum_gt = AllGatherGrad.apply(sum_gt).sum(0)

intersect = intersect.sum(0)
sum_pred = sum_pred.sum(0)
sum_gt = sum_gt.sum(0)
Expand All @@ -128,22 +132,18 @@ def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
:return:
"""
if axes is None:
axes = tuple(range(2, len(net_output.size())))

shp_x = net_output.shape
shp_y = gt.shape
axes = tuple(range(2, net_output.ndim))

with torch.no_grad():
if len(shp_x) != len(shp_y):
gt = gt.view((shp_y[0], 1, *shp_y[1:]))
if net_output.ndim != gt.ndim:
gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))

if net_output.shape == gt.shape:
# if this is the case then gt is probably already a one hot encoding
y_onehot = gt
else:
gt = gt.long()
y_onehot = torch.zeros(shp_x, device=net_output.device)
y_onehot.scatter_(1, gt, 1)
y_onehot = torch.zeros(net_output.shape, device=net_output.device)
y_onehot.scatter_(1, gt.long(), 1)

tp = net_output * y_onehot
fp = net_output * (1 - y_onehot)
Expand All @@ -152,7 +152,7 @@ def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):

if mask is not None:
with torch.no_grad():
mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for i in range(2, len(tp.shape))]))
mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for _ in range(2, tp.ndim)]))
tp *= mask_here
fp *= mask_here
fn *= mask_here
Expand Down
3 changes: 1 addition & 2 deletions nnunetv2/training/loss/robust_ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
input must be logits, not probabilities!
"""
def forward(self, input: Tensor, target: Tensor) -> Tensor:
if len(target.shape) == len(input.shape):
if target.ndim == input.ndim:
assert target.shape[1] == 1
target = target[:, 0]
return super().forward(input, target.long())
Expand All @@ -30,4 +30,3 @@ def forward(self, inp, target):
num_voxels = np.prod(res.shape, dtype=np.int64)
res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False)
return res.mean()

0 comments on commit 8212172

Please sign in to comment.