forked from bamps53/SeesawLoss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
seesaw_loss.py
89 lines (71 loc) · 3.15 KB
/
seesaw_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import numpy as np
import torch
import torch.nn as nn
from typing import Union
class SeesawLossWithLogits(nn.Module):
    """Seesaw loss computed from raw logits with fixed per-class sample counts.

    This is an unofficial implementation of Seesaw loss, which was proposed
    in the technical report for the LVIS workshop at ECCV 2020.
    For more detail, please refer to https://arxiv.org/pdf/2008.10032.pdf.

    For every class pair ``(i, j)`` a weight is precomputed from the counts::

        s[i, j] = (N_j / N_i) ** p   if N_i > N_j
        s[i, j] = 1                  otherwise

    where ``N_k`` is ``class_counts[k]``.

    Args:
        class_counts: Number of samples for each class.
            Should have the same length as num_classes.
        p: Scale parameter which adjusts the strength of the punishment.
            Set to 0.8 by default, following the original paper.
    """

    def __init__(self, class_counts: Union[list, np.ndarray], p: float = 0.8):
        super().__init__()
        counts = torch.as_tensor(class_counts, dtype=torch.float32)
        ratio = (counts[None, :] / counts[:, None]) ** p
        is_more_frequent = counts[:, None] > counts[None, :]
        weights = torch.where(is_more_frequent, ratio, torch.ones_like(ratio))
        # Non-persistent buffer: follows ``module.to(...)`` but is kept out of
        # ``state_dict`` so existing checkpoints keep loading unchanged.
        self.register_buffer("s", weights, persistent=False)
        self.eps = 1.0e-6

    def forward(self, logits, targets):
        """Return the mean Seesaw loss over the batch.

        Args:
            logits: Raw scores whose last dimension indexes the classes.
            targets: Tensor with the same shape as ``logits``. ``1 - targets``
                is used as the mask of the other classes, so it is expected
                to be a one-hot style encoding.

        Returns:
            Scalar tensor: the loss summed over classes and averaged over all
            remaining dimensions.
        """
        # Callers that never moved the module to the logits' device still work.
        weights = self.s.to(logits.device)

        # Shift by the row maximum so that exp() cannot overflow.
        shifted = logits - logits.max(dim=-1, keepdim=True).values
        exp_shifted = torch.exp(shifted)

        # denominator[..., i] = sum_j (1 - t_j) * s[i, j] * exp(z_j) + exp(z_i)
        others = (1 - targets) * exp_shifted
        denominator = (others[..., None, :] * weights).sum(dim=-1) + exp_shifted

        # Evaluate log(sigma) in log space. Taking log(exp(z) / denominator)
        # would underflow to log(eps) and lose the gradient once the target
        # logit is far below the maximum. eps keeps the log argument positive
        # for rows in which every term has underflowed.
        log_sigma = shifted - torch.log(denominator + self.eps)
        loss = -(targets * log_sigma).sum(dim=-1)
        return loss.mean()
class DistibutionAgnosticSeesawLossWithLogits(nn.Module):
    """Seesaw loss from raw logits with class counts accumulated online.

    This is an unofficial implementation of Seesaw loss, which was proposed
    in the technical report for the LVIS workshop at ECCV 2020.
    For more detail, please refer to https://arxiv.org/pdf/2008.10032.pdf.

    Instead of taking the class distribution as an argument, the per-class
    counts are accumulated from the ``targets`` passed to ``forward`` and the
    pairwise weights are rebuilt from them on every call::

        s[i, j] = (N_j / N_i) ** p   if N_i > N_j
        s[i, j] = 1                  otherwise

    Args:
        p: Scale parameter which adjusts the strength of the punishment.
            Set to 0.8 by default, following the paper.
    """

    def __init__(self, p: float = 0.8):
        super().__init__()
        self.eps = 1.0e-6
        self.p = p
        self.s = None  # pairwise weight matrix, rebuilt in every forward call
        self.class_counts = None  # running per-class counts, set on first call

    def forward(self, logits, targets):
        """Update the running class counts and return the mean Seesaw loss.

        Args:
            logits: Raw scores whose last dimension indexes the classes.
            targets: Tensor with the same shape as ``logits``. ``1 - targets``
                is used as the mask of the other classes, so it is expected
                to be a one-hot style encoding.

        Returns:
            Scalar tensor: the loss summed over classes and averaged over all
            remaining dimensions.
        """
        # NOTE(review): the counts are updated on every call, whatever the
        # module's train/eval mode is -- confirm this is intended.
        # detach() keeps the running statistics out of the autograd graph so
        # no graph is retained from one iteration to the next.
        batch_counts = targets.detach().reshape(-1, targets.shape[-1]).sum(dim=0)
        if self.class_counts is None:
            # +1 so that no class starts at zero (prevents division by zero).
            self.class_counts = batch_counts + 1
        else:
            self.class_counts = (
                self.class_counts.to(batch_counts.device) + batch_counts
            )

        counts = self.class_counts
        ratio = (counts[None, :] / counts[:, None]) ** self.p
        is_more_frequent = counts[:, None] > counts[None, :]
        # ones_like keeps the fallback on the same device and dtype as ratio;
        # a bare torch.ones(...) is always created on the CPU.
        self.s = torch.where(is_more_frequent, ratio, torch.ones_like(ratio))

        # Shift by the row maximum so that exp() cannot overflow.
        shifted = logits - logits.max(dim=-1, keepdim=True).values
        exp_shifted = torch.exp(shifted)

        # denominator[..., i] = sum_j (1 - t_j) * s[i, j] * exp(z_j) + exp(z_i)
        others = (1 - targets) * exp_shifted
        denominator = (others[..., None, :] * self.s).sum(dim=-1) + exp_shifted

        # Evaluate log(sigma) in log space. Taking log(exp(z) / denominator)
        # would underflow to log(eps) and lose the gradient once the target
        # logit is far below the maximum. eps keeps the log argument positive
        # for rows in which every term has underflowed.
        log_sigma = shifted - torch.log(denominator + self.eps)
        loss = -(targets * log_sigma).sum(dim=-1)
        return loss.mean()