-
Notifications
You must be signed in to change notification settings - Fork 13
/
eps.py
63 lines (47 loc) · 2.26 KB
/
eps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from torch.nn import functional as F
def get_eps_loss(cam, saliency, num_classes, label, tau, lam, intermediate=True):
    """
    Compute the EPS (pseudo-pixel supervision) loss between CAM-derived
    foreground/background maps and an off-the-shelf saliency map.

    Args:
        cam (tensor): response from model with float values (raw logits over
            `num_classes` foreground channels plus one background channel).
        saliency (tensor): saliency map from off-the-shelf saliency model.
        num_classes (int): the number of foreground classes.
        label (tensor): multi-hot image-level label information.
        tau (float): threshold for confidence area (channel is treated as
            foreground only if its overlap ratio with saliency exceeds tau).
        lam (float): blending ratio between foreground map and background map.
        intermediate (bool): if True return (loss, fg_map, bg_map, sal_pred),
            otherwise return only the loss.

    Shape:
        cam: (N, C, H', W') where N is the batch size and C = num_classes + 1.
        saliency: (N, 1, H, W)
        label: (N, num_classes)

    Returns:
        loss (scalar tensor), and when `intermediate` is True also fg_map,
        bg_map, sal_pred, each of shape (N, 1, H', W').
    """
    b, c, h, w = cam.size()
    # All new tensors follow cam's device instead of hard-coded .cuda(),
    # so the loss also runs on CPU and on non-default CUDA devices.
    device = cam.device

    # Resize saliency to the CAM resolution (default 'nearest' interpolation).
    saliency = F.interpolate(saliency, size=(h, w))
    label_map = label.view(b, num_classes, 1, 1).expand(size=(b, num_classes, h, w)).bool()

    # Map selection: allocate boolean masks directly (avoids a float alloc + cast).
    label_map_fg = torch.zeros((b, num_classes + 1, h, w), dtype=torch.bool, device=device)
    label_map_bg = torch.zeros((b, num_classes + 1, h, w), dtype=torch.bool, device=device)
    label_map_bg[:, num_classes] = True  # last channel is always background
    label_map_fg[:, :-1] = label_map.clone()

    sal_pred = F.softmax(cam, dim=1)

    # Per-channel overlap ratio between the binarized CAM and the binarized
    # saliency; 1e-04 guards against division by zero for empty channels.
    iou_saliency = (torch.round(sal_pred[:, :-1].detach()) * torch.round(saliency)).view(b, num_classes, -1).sum(-1) / \
                   (torch.round(sal_pred[:, :-1].detach()) + 1e-04).view(b, num_classes, -1).sum(-1)

    # A labeled channel counts as foreground only when its overlap exceeds tau;
    # otherwise it is pushed to the background map.
    valid_channel = (iou_saliency > tau).view(b, num_classes, 1, 1).expand(size=(b, num_classes, h, w))

    label_fg_valid = label_map & valid_channel
    label_map_fg[:, :-1] = label_fg_valid
    label_map_bg[:, :-1] = label_map & (~valid_channel)

    # Saliency loss: zeros_like already matches sal_pred's device/dtype.
    fg_map = torch.zeros_like(sal_pred)
    bg_map = torch.zeros_like(sal_pred)
    fg_map[label_map_fg] = sal_pred[label_map_fg]
    bg_map[label_map_bg] = sal_pred[label_map_bg]
    fg_map = torch.sum(fg_map, dim=1, keepdim=True)
    bg_map = torch.sum(bg_map, dim=1, keepdim=True)
    bg_map = torch.sub(1, bg_map)  # invert: background probability -> foreground estimate

    # Blend the two foreground estimates and match them to the saliency target.
    sal_pred = fg_map * lam + bg_map * (1 - lam)
    loss = F.mse_loss(sal_pred, saliency)

    if intermediate:
        return loss, fg_map, bg_map, sal_pred
    else:
        return loss