From e9f14e7bc8fb57f49baaa148afc6992163b847a8 Mon Sep 17 00:00:00 2001
From: ancestor-mithril
Date: Tue, 10 Oct 2023 14:07:27 +0300
Subject: [PATCH] Making loss calculation faster

---
 nnunetv2/training/loss/compound_losses.py  |  7 ++--
 nnunetv2/training/loss/deep_supervision.py | 23 +++++-------
 nnunetv2/training/loss/dice.py             | 42 +++++++++++-----------
 nnunetv2/training/loss/robust_ce_loss.py   |  3 +-
 4 files changed, 34 insertions(+), 41 deletions(-)

diff --git a/nnunetv2/training/loss/compound_losses.py b/nnunetv2/training/loss/compound_losses.py
index 9db0a4227..eaeb5d8e0 100644
--- a/nnunetv2/training/loss/compound_losses.py
+++ b/nnunetv2/training/loss/compound_losses.py
@@ -38,11 +38,10 @@ def forward(self, net_output: torch.Tensor, target: torch.Tensor):
         if self.ignore_label is not None:
             assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \
                                          '(DC_and_CE_loss)'
-            mask = (target != self.ignore_label).bool()
+            mask = target != self.ignore_label
             # remove ignore label from target, replace with one of the known labels. It doesn't matter because we
             # ignore gradients in those areas anyway
-            target_dice = torch.clone(target)
-            target_dice[target == self.ignore_label] = 0
+            target_dice = torch.where(mask, target, 0)
             num_fg = mask.sum()
         else:
             target_dice = target
@@ -50,7 +49,7 @@ def forward(self, net_output: torch.Tensor, target: torch.Tensor):
 
         dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \
             if self.weight_dice != 0 else 0
-        ce_loss = self.ce(net_output, target[:, 0].long()) \
+        ce_loss = self.ce(net_output, target[:, 0]) \
             if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0
 
         result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
diff --git a/nnunetv2/training/loss/deep_supervision.py b/nnunetv2/training/loss/deep_supervision.py
index 03141e809..952e3f715 100644
--- a/nnunetv2/training/loss/deep_supervision.py
+++ b/nnunetv2/training/loss/deep_supervision.py
@@ -1,3 +1,4 @@
+import torch
 from torch import nn
 
 
@@ -11,25 +12,19 @@ def __init__(self, loss, weight_factors=None):
         If weights are None, all w will be 1.
         """
         super(DeepSupervisionWrapper, self).__init__()
-        self.weight_factors = weight_factors
+        assert any([x != 0 for x in weight_factors]), "At least one weight factor should be != 0.0"
+        self.weight_factors = tuple(weight_factors)
         self.loss = loss
 
     def forward(self, *args):
-        for i in args:
-            assert isinstance(i, (tuple, list)), f"all args must be either tuple or list, got {type(i)}"
-            # we could check for equal lengths here as well but we really shouldn't overdo it with checks because
-            # this code is executed a lot of times!
+        assert all([isinstance(i, (tuple, list)) for i in args]), \
+            f"all args must be either tuple or list, got {[type(i) for i in args]}"
+        # we could check for equal lengths here as well, but we really shouldn't overdo it with checks because
+        # this code is executed a lot of times!
 
         if self.weight_factors is None:
-            weights = [1] * len(args[0])
+            weights = (1, ) * len(args[0])
         else:
             weights = self.weight_factors
 
-        # we initialize the loss like this instead of 0 to ensure it sits on the correct device, not sure if that's
-        # really necessary
-        l = weights[0] * self.loss(*[j[0] for j in args])
-        for i, inputs in enumerate(zip(*args)):
-            if i == 0:
-                continue
-            l += weights[i] * self.loss(*inputs)
-        return l
\ No newline at end of file
+        return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0])
diff --git a/nnunetv2/training/loss/dice.py b/nnunetv2/training/loss/dice.py
index af554908b..574435754 100644
--- a/nnunetv2/training/loss/dice.py
+++ b/nnunetv2/training/loss/dice.py
@@ -74,18 +74,18 @@ def forward(self, x, y, loss_mask=None):
             x = self.apply_nonlin(x)
 
         # make everything shape (b, c)
-        axes = list(range(2, len(x.shape)))
+        axes = tuple(range(2, x.ndim))
+
         with torch.no_grad():
-            if len(x.shape) != len(y.shape):
+            if x.ndim != y.ndim:
                 y = y.view((y.shape[0], 1, *y.shape[1:]))
 
             if x.shape == y.shape:
                 # if this is the case then gt is probably already a one hot encoding
                 y_onehot = y
             else:
-                gt = y.long()
                 y_onehot = torch.zeros(x.shape, device=x.device, dtype=torch.bool)
-                y_onehot.scatter_(1, gt, 1)
+                y_onehot.scatter_(1, y.long(), 1)
 
             if not self.do_bg:
                 y_onehot = y_onehot[:, 1:]
@@ -96,15 +96,19 @@ def forward(self, x, y, loss_mask=None):
         if not self.do_bg:
             x = x[:, 1:]
 
-        intersect = (x * y_onehot).sum(axes) if loss_mask is None else (x * y_onehot * loss_mask).sum(axes)
-        sum_pred = x.sum(axes) if loss_mask is None else (x * loss_mask).sum(axes)
-
-        if self.ddp and self.batch_dice:
-            intersect = AllGatherGrad.apply(intersect).sum(0)
-            sum_pred = AllGatherGrad.apply(sum_pred).sum(0)
-            sum_gt = AllGatherGrad.apply(sum_gt).sum(0)
+        if loss_mask is None:
+            intersect = (x * y_onehot).sum(axes)
+            sum_pred = x.sum(axes)
+        else:
+            intersect = (x * y_onehot * loss_mask).sum(axes)
+            sum_pred = (x * loss_mask).sum(axes)
 
         if self.batch_dice:
+            if self.ddp:
+                intersect = AllGatherGrad.apply(intersect).sum(0)
+                sum_pred = AllGatherGrad.apply(sum_pred).sum(0)
+                sum_gt = AllGatherGrad.apply(sum_gt).sum(0)
+
             intersect = intersect.sum(0)
             sum_pred = sum_pred.sum(0)
             sum_gt = sum_gt.sum(0)
@@ -128,22 +132,18 @@ def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
     :return:
    """
     if axes is None:
-        axes = tuple(range(2, len(net_output.size())))
-
-    shp_x = net_output.shape
-    shp_y = gt.shape
+        axes = tuple(range(2, net_output.ndim))
 
     with torch.no_grad():
-        if len(shp_x) != len(shp_y):
-            gt = gt.view((shp_y[0], 1, *shp_y[1:]))
+        if net_output.ndim != gt.ndim:
+            gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))
 
         if net_output.shape == gt.shape:
             # if this is the case then gt is probably already a one hot encoding
             y_onehot = gt
         else:
-            gt = gt.long()
-            y_onehot = torch.zeros(shp_x, device=net_output.device)
-            y_onehot.scatter_(1, gt, 1)
+            y_onehot = torch.zeros(net_output.shape, device=net_output.device)
+            y_onehot.scatter_(1, gt.long(), 1)
 
     tp = net_output * y_onehot
     fp = net_output * (1 - y_onehot)
@@ -152,7 +152,7 @@ def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
 
     if mask is not None:
         with torch.no_grad():
-            mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for i in range(2, len(tp.shape))]))
+            mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for _ in range(2, tp.ndim)]))
         tp *= mask_here
         fp *= mask_here
         fn *= mask_here
diff --git a/nnunetv2/training/loss/robust_ce_loss.py b/nnunetv2/training/loss/robust_ce_loss.py
index ad4665919..3399e3ae9 100644
--- a/nnunetv2/training/loss/robust_ce_loss.py
+++ b/nnunetv2/training/loss/robust_ce_loss.py
@@ -10,7 +10,7 @@ class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
     input must be logits, not probabilities!
     """
     def forward(self, input: Tensor, target: Tensor) -> Tensor:
-        if len(target.shape) == len(input.shape):
+        if target.ndim == input.ndim:
             assert target.shape[1] == 1
             target = target[:, 0]
         return super().forward(input, target.long())
@@ -30,4 +30,3 @@ def forward(self, inp, target):
         num_voxels = np.prod(res.shape, dtype=np.int64)
         res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False)
         return res.mean()
-
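Note: the compound_losses.py change replaces an explicit clone followed by masked assignment with a single torch.where call. A minimal standalone sketch of the idea (not part of the patch; the label map and the ignore label value 2 are made up for illustration):

    import torch

    ignore_label = 2
    target = torch.tensor([[0, 1, 2], [2, 1, 0]])  # hypothetical label map; 2 marks ignored voxels

    # previous approach: copy the tensor, then overwrite ignored positions in a second pass
    target_dice_old = torch.clone(target)
    target_dice_old[target == ignore_label] = 0

    # patched approach: build the mask once and let torch.where do the replacement in one pass
    mask = target != ignore_label
    target_dice_new = torch.where(mask, target, 0)

    assert torch.equal(target_dice_old, target_dice_new)

Both variants produce the same tensor; torch.where simply avoids the extra copy-and-index step, which is the kind of saving the commit title refers to.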