-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
compound_losses.py
156 lines (134 loc) · 6.06 KB
/
compound_losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
from nnunetv2.training.loss.dice import SoftDiceLoss, MemoryEfficientSoftDiceLoss
from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss, TopKLoss
from nnunetv2.utilities.helpers import softmax_helper_dim1
from torch import nn
class DC_and_CE_loss(nn.Module):
    """Weighted sum of a Dice loss and ``RobustCrossEntropyLoss``.

    The Dice term is built from ``dice_class`` with ``softmax_helper_dim1`` as its nonlinearity.
    """

    def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None,
                 dice_class=SoftDiceLoss):
        """
        Weights for CE and Dice do not need to sum to one. You can set whatever you want.

        :param soft_dice_kwargs: keyword arguments forwarded to ``dice_class``
        :param ce_kwargs: keyword arguments forwarded to ``RobustCrossEntropyLoss``. The dict passed in is
            left untouched.
        :param weight_ce: factor applied to the cross-entropy term (0 skips its computation)
        :param weight_dice: factor applied to the Dice term (0 skips its computation)
        :param ignore_label: label value to exclude from the loss, or None if there is no such label
        :param dice_class: class used to instantiate the Dice term
        """
        super(DC_and_CE_loss, self).__init__()
        if ignore_label is not None:
            # Copy instead of writing into the caller's dict: a kwargs dict that is reused for a second loss
            # instance must not silently carry over this instance's ignore_index.
            ce_kwargs = {**ce_kwargs, 'ignore_index': ignore_label}

        self.weight_dice = weight_dice
        self.weight_ce = weight_ce
        self.ignore_label = ignore_label

        self.ce = RobustCrossEntropyLoss(**ce_kwargs)
        self.dc = dice_class(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs)

    def forward(self, net_output: torch.Tensor, target: torch.Tensor):
        """
        target must be b, c, x, y(, z) with c=1

        :param net_output: network output, passed unchanged to both loss terms
        :param target: label map with a single channel
        :return: ``weight_ce * ce_loss + weight_dice * dc_loss``
        """
        if self.ignore_label is not None:
            assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \
                                         '(DC_and_CE_loss)'
            # True wherever the loss may be computed
            mask = target != self.ignore_label
            # remove ignore label from target, replace with one of the known labels. It doesn't matter because we
            # ignore gradients in those areas anyway
            target_dice = torch.where(mask, target, 0)
            # number of positions that are NOT ignored (despite the name this is not a foreground count)
            num_fg = mask.sum()
        else:
            target_dice = target
            mask = None

        dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \
            if self.weight_dice != 0 else 0
        # CE receives the original target (ignore label still present) without its channel axis. It is skipped
        # when every position carries the ignore label. num_fg is only evaluated when ignore_label is set.
        ce_loss = self.ce(net_output, target[:, 0]) \
            if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0

        result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
        return result
class DC_and_BCE_loss(nn.Module):
    """Weighted sum of a Dice loss and ``nn.BCEWithLogitsLoss``.

    The Dice term is built from ``dice_class`` with ``torch.sigmoid`` as its nonlinearity.
    """

    def __init__(self, bce_kwargs, soft_dice_kwargs, weight_ce=1, weight_dice=1, use_ignore_label: bool = False,
                 dice_class=MemoryEfficientSoftDiceLoss):
        """
        DO NOT APPLY NONLINEARITY IN YOUR NETWORK!

        target must be one hot encoded
        IMPORTANT: We assume use_ignore_label is located in target[:, -1]!!!

        :param bce_kwargs: keyword arguments forwarded to ``nn.BCEWithLogitsLoss``. The dict passed in is left
            untouched.
        :param soft_dice_kwargs: keyword arguments forwarded to ``dice_class``
        :param weight_ce: factor applied to the BCE term
        :param weight_dice: factor applied to the Dice term
        :param use_ignore_label: if True, the last target channel is treated as the ignore mask
        :param dice_class: class used to instantiate the Dice term
        """
        super(DC_and_BCE_loss, self).__init__()
        if use_ignore_label:
            # The masked mean in forward() needs the unreduced loss. Copy rather than write into the caller's
            # dict so that a reused kwargs dict does not keep reduction='none' for other instances.
            bce_kwargs = {**bce_kwargs, 'reduction': 'none'}

        self.weight_dice = weight_dice
        self.weight_ce = weight_ce
        self.use_ignore_label = use_ignore_label

        self.ce = nn.BCEWithLogitsLoss(**bce_kwargs)
        self.dc = dice_class(apply_nonlin=torch.sigmoid, **soft_dice_kwargs)

    def forward(self, net_output: torch.Tensor, target: torch.Tensor):
        """
        :param net_output: raw network output (logits), passed unchanged to both loss terms
        :param target: one hot encoded target; with ``use_ignore_label`` the last channel marks ignored positions
        :return: ``weight_ce * ce_loss + weight_dice * dc_loss``
        """
        if self.use_ignore_label:
            # target is one hot encoded here. invert it so that it is True wherever we can compute the loss
            if target.dtype == torch.bool:
                mask = ~target[:, -1:]
            else:
                mask = (1 - target[:, -1:]).bool()
            # remove ignore channel now that we have the mask
            # why did we use clone in the past? Should have documented that...
            # target_regions = torch.clone(target[:, :-1])
            target_regions = target[:, :-1]
        else:
            target_regions = target
            mask = None

        dc_loss = self.dc(net_output, target_regions, loss_mask=mask)
        # the Dice term above gets the target in its original dtype; BCE gets a float copy
        target_regions = target_regions.float()
        if mask is not None:
            # Masked reduction: the mask keeps its single channel and is broadcast over the unreduced loss; the
            # sum is divided by the number of unmasked positions. The clip avoids a division by zero when
            # everything is ignored.
            ce_loss = (self.ce(net_output, target_regions) * mask).sum() / torch.clip(mask.sum(), min=1e-8)
        else:
            ce_loss = self.ce(net_output, target_regions)
        result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
        return result
class DC_and_topk_loss(nn.Module):
    """Weighted sum of ``SoftDiceLoss`` and ``TopKLoss``.

    The Dice term is built with ``softmax_helper_dim1`` as its nonlinearity.
    """

    def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None):
        """
        Weights for CE and Dice do not need to sum to one. You can set whatever you want.

        :param soft_dice_kwargs: keyword arguments forwarded to ``SoftDiceLoss``
        :param ce_kwargs: keyword arguments forwarded to ``TopKLoss``. The dict passed in is left untouched.
        :param weight_ce: factor applied to the top-k term (0 skips its computation)
        :param weight_dice: factor applied to the Dice term (0 skips its computation)
        :param ignore_label: label value to exclude from the loss, or None if there is no such label
        """
        super().__init__()
        if ignore_label is not None:
            # Copy instead of writing into the caller's dict: a kwargs dict that is reused for a second loss
            # instance must not silently carry over this instance's ignore_index.
            ce_kwargs = {**ce_kwargs, 'ignore_index': ignore_label}

        self.weight_dice = weight_dice
        self.weight_ce = weight_ce
        self.ignore_label = ignore_label

        self.ce = TopKLoss(**ce_kwargs)
        self.dc = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs)

    def forward(self, net_output: torch.Tensor, target: torch.Tensor):
        """
        target must be b, c, x, y(, z) with c=1

        :param net_output: network output, passed unchanged to both loss terms
        :param target: label map with a single channel; it is not modified
        :return: ``weight_ce * ce_loss + weight_dice * dc_loss``
        """
        if self.ignore_label is not None:
            assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \
                                         '(DC_and_topk_loss)'
            # True wherever the loss may be computed (the comparison already yields a bool tensor)
            mask = target != self.ignore_label
            # remove ignore label from target, replace with one of the known labels. It doesn't matter because we
            # ignore gradients in those areas anyway
            target_dice = torch.where(mask, target, 0)
            # number of positions that are NOT ignored (despite the name this is not a foreground count)
            num_fg = mask.sum()
        else:
            target_dice = target
            mask = None

        dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \
            if self.weight_dice != 0 else 0
        # The top-k term receives the original target including its channel axis and the ignore label. It is
        # skipped when every position carries the ignore label. num_fg is only evaluated when ignore_label is set.
        ce_loss = self.ce(net_output, target) \
            if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0

        result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
        return result