generated from yakhyo/project-template
-
Notifications
You must be signed in to change notification settings - Fork 4
/
losses.py
122 lines (102 loc) · 3.36 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import torch.nn as nn
from typing import Optional
from crackseg.utils.functional import cross_entropy, dice_loss, sigmoid_focal_loss
# Public API of this module: the loss classes exported by ``from ... import *``.
__all__ = [
    "DiceLoss",
    "DiceCELoss",
    "CrossEntropyLoss",
    "FocalLoss",
]
class CrossEntropyLoss(nn.Module):
    """Cross Entropy Loss scaled by a constant factor.

    Thin ``nn.Module`` wrapper around ``crackseg.utils.functional.cross_entropy``.

    Args:
        class_weights: Optional per-class weights, forwarded to ``cross_entropy``
            as ``class_weight``. A tensor (or ``None``) is registered as a
            non-persistent buffer. This means ``module.to(device)``, ``.cuda()``
            and ``.half()`` move and cast it together with the module, while
            ``state_dict`` keys stay unchanged, so existing checkpoints still
            load. Any other value is stored as a plain attribute, as before.
        reduction: Reduction mode forwarded to ``cross_entropy``.
        loss_weight: Multiplier applied to the computed loss.
    """

    def __init__(
        self,
        class_weights: Optional[torch.Tensor] = None,
        reduction: str = "mean",
        loss_weight: float = 1.0,
    ) -> None:
        super().__init__()
        if class_weights is None or isinstance(class_weights, torch.Tensor):
            # A plain tensor attribute is ignored by ``Module.to()``; a buffer is not.
            # ``persistent=False`` keeps it out of ``state_dict`` for checkpoint compatibility.
            self.register_buffer("class_weight", class_weights, persistent=False)
        else:
            # Non-tensor input cannot be a buffer; keep the previous behaviour.
            self.class_weight = class_weights
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """Compute ``loss_weight * cross_entropy(...)``.

        Args:
            inputs: Model predictions, forwarded to ``cross_entropy``
                (presumably raw logits; confirm against the functional helper).
            targets: Ground truth, forwarded to ``cross_entropy``.
            weight: Optional extra weight, passed positionally as the third
                argument of ``cross_entropy``.
            ignore_index: Forwarded to ``cross_entropy`` as ``ignore_index``.

        Returns:
            The scaled loss returned by ``cross_entropy``.
        """
        loss = self.loss_weight * cross_entropy(
            inputs,
            targets,
            weight,
            class_weight=self.class_weight,
            reduction=self.reduction,
            ignore_index=ignore_index,
        )
        return loss
class DiceLoss(nn.Module):
    """Dice Loss scaled by a constant factor.

    Thin ``nn.Module`` wrapper around ``crackseg.utils.functional.dice_loss``.

    Args:
        reduction: Reduction mode forwarded to ``dice_loss``.
        loss_weight: Multiplier applied to the computed loss. ``None`` is
            treated as ``1.0`` (no scaling). Previously ``None`` was accepted
            by the signature but crashed in ``forward``.
        eps: Forwarded to ``dice_loss`` as ``eps`` (presumably a small
            constant guarding the division; confirm in the functional helper).
    """

    def __init__(
        self,
        reduction: str = "mean",
        loss_weight: Optional[float] = 1.0,
        eps: float = 1e-5,
    ) -> None:
        super().__init__()
        self.reduction = reduction
        # ``None * tensor`` raises TypeError, so map the Optional "no weight" case to 1.0.
        self.loss_weight = 1.0 if loss_weight is None else loss_weight
        self.eps = eps

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute ``loss_weight * dice_loss(...)``.

        Args:
            inputs: Model predictions, forwarded to ``dice_loss``.
            targets: Ground truth, forwarded to ``dice_loss``.
            weight: Optional weight, forwarded to ``dice_loss`` as ``weight``.

        Returns:
            The scaled loss returned by ``dice_loss``.
        """
        loss = self.loss_weight * dice_loss(inputs, targets, weight=weight, reduction=self.reduction, eps=self.eps)
        return loss
class DiceCELoss(nn.Module):
    """Weighted sum of a Dice term and a cross-entropy term.

    The result is ``dice_weight * dice_loss(...) + ce_weight * cross_entropy(...)``,
    using the helpers from ``crackseg.utils.functional``.

    Args:
        reduction: Reduction mode forwarded to both loss helpers.
        dice_weight: Multiplier for the Dice term.
        ce_weight: Multiplier for the cross-entropy term.
        eps: Forwarded to ``dice_loss`` as ``eps``.
    """

    def __init__(
        self,
        reduction: str = "mean",
        dice_weight: float = 1.0,
        ce_weight: float = 1.0,
        eps: float = 1e-5,
    ) -> None:
        super().__init__()
        # Term weights first, then options shared with the helpers.
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.reduction = reduction
        self.eps = eps

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        weight: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return the weighted combination of the Dice and cross-entropy losses.

        ``weight`` is forwarded unchanged, as the ``weight`` keyword, to both helpers.
        """
        mode = self.reduction
        dice_term = self.dice_weight * dice_loss(inputs, targets, weight=weight, reduction=mode, eps=self.eps)
        ce_term = self.ce_weight * cross_entropy(inputs, targets, weight=weight, reduction=mode)
        return dice_term + ce_term
class FocalLoss(nn.Module):
    """Sigmoid Focal Loss scaled by a constant factor.

    Thin ``nn.Module`` wrapper around ``crackseg.utils.functional.sigmoid_focal_loss``.

    Args:
        gamma: Forwarded to ``sigmoid_focal_loss`` as ``gamma``.
        alpha: Forwarded to ``sigmoid_focal_loss`` as ``alpha``.
        reduction: Reduction mode forwarded to ``sigmoid_focal_loss``.
        loss_weight: Multiplier applied to the computed loss.
    """

    def __init__(
        self,
        gamma: float = 2.0,
        alpha: float = 0.25,
        reduction: str = "mean",
        loss_weight: float = 1.0
    ) -> None:
        super().__init__()
        self.gamma, self.alpha = gamma, alpha
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return ``loss_weight * sigmoid_focal_loss(...)``.

        ``weight`` is passed positionally as the third argument of the helper.
        """
        focal = sigmoid_focal_loss(
            inputs,
            targets,
            weight,
            gamma=self.gamma,
            alpha=self.alpha,
            reduction=self.reduction,
        )
        return self.loss_weight * focal