# metrics.py (forked from kevinzakka/pytorch-goodies)
"""Common image segmentation metrics.
"""
import torch
from utils import nanmean
EPS = 1e-10

def _fast_hist(true, pred, num_classes):
    """Builds the confusion matrix between true and predicted label maps.

    Rows index the true class, columns the predicted class. Pixels whose
    true label falls outside [0, num_classes) are ignored.
    """
    mask = (true >= 0) & (true < num_classes)
    hist = torch.bincount(
        num_classes * true[mask] + pred[mask],
        minlength=num_classes ** 2,
    ).reshape(num_classes, num_classes).float()
    return hist
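
# A quick worked example with toy labels (values made up for illustration):
#
#   true = torch.tensor([0, 0, 1, 2])
#   pred = torch.tensor([0, 1, 1, 2])
#   _fast_hist(true, pred, num_classes=3)
#   # tensor([[1., 1., 0.],
#   #         [0., 1., 0.],
#   #         [0., 0., 1.]])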

def overall_pixel_accuracy(hist):
    """Computes the overall pixel accuracy.

    The overall pixel accuracy is the fraction of all pixels that are
    correctly labeled. It tracks how well the prediction matches the
    ground truth in its overall shape, but it is insensitive to fine
    details and can be dominated by large or frequent classes.

    Args:
        hist: confusion matrix.

    Returns:
        overall_acc: the overall pixel accuracy.
    """
    correct = torch.diag(hist).sum()
    total = hist.sum()
    overall_acc = correct / (total + EPS)
    return overall_acc
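
# On the toy histogram above, 3 of the 4 pixels lie on the diagonal, so
# overall_pixel_accuracy returns ~3 / 4 = 0.75 (EPS nudges it down negligibly).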

def per_class_pixel_accuracy(hist):
    """Computes the average per-class pixel accuracy.

    The per-class pixel accuracy is a more fine-grained version of the
    overall pixel accuracy. A model can score a high overall pixel
    accuracy by correctly predicting the dominant labels or areas in the
    image while mislabeling possibly more important or rare labels; such
    a model will score a low per-class pixel accuracy.

    Args:
        hist: confusion matrix.

    Returns:
        avg_per_class_acc: the average per-class pixel accuracy.
    """
    correct_per_class = torch.diag(hist)
    total_per_class = hist.sum(dim=1)
    per_class_acc = correct_per_class / (total_per_class + EPS)
    avg_per_class_acc = nanmean(per_class_acc)
    return avg_per_class_acc
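
# Continuing the toy example: diagonals [1, 1, 1] over row sums [2, 1, 1]
# give per-class accuracies ~[0.5, 1.0, 1.0], averaging ~0.8333.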

def jaccard_index(hist):
    """Computes the Jaccard index, a.k.a. the Intersection over Union (IoU).

    Args:
        hist: confusion matrix.

    Returns:
        avg_jacc: the average per-class Jaccard index.
    """
    A_inter_B = torch.diag(hist)
    A = hist.sum(dim=1)
    B = hist.sum(dim=0)
    jaccard = A_inter_B / (A + B - A_inter_B + EPS)
    avg_jacc = nanmean(jaccard)
    return avg_jacc
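
# Continuing the toy example: with row sums [2, 1, 1], column sums [1, 2, 1],
# and intersections [1, 1, 1], the per-class IoUs are ~[0.5, 0.5, 1.0],
# averaging ~0.6667.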

def dice_coefficient(hist):
    """Computes the Sørensen–Dice coefficient, a.k.a. the F1 score.

    Args:
        hist: confusion matrix.

    Returns:
        avg_dice: the average per-class Dice coefficient.
    """
    A_inter_B = torch.diag(hist)
    A = hist.sum(dim=1)
    B = hist.sum(dim=0)
    dice = (2 * A_inter_B) / (A + B + EPS)
    avg_dice = nanmean(dice)
    return avg_dice
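
# Continuing the toy example: dice = 2 * [1, 1, 1] / ([2, 1, 1] + [1, 2, 1])
# = ~[0.6667, 0.6667, 1.0], averaging ~0.7778. Per class, Dice and IoU are
# related by dice = 2 * iou / (1 + iou).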

def eval_metrics(true, pred, num_classes):
    """Computes various segmentation metrics on 2D feature maps.

    Args:
        true: a tensor of shape [B, H, W] or [B, 1, H, W].
        pred: a tensor of shape [B, H, W] or [B, 1, H, W].
        num_classes: the number of classes to segment. This number
            should be less than the ID of the ignored class.

    Returns:
        overall_acc: the overall pixel accuracy.
        avg_per_class_acc: the average per-class pixel accuracy.
        avg_jacc: the average per-class Jaccard index.
        avg_dice: the average per-class Dice coefficient.
    """
    # Accumulate one confusion matrix over the whole batch; keep it on
    # the same device as the inputs so GPU tensors work too.
    hist = torch.zeros((num_classes, num_classes), device=true.device)
    for t, p in zip(true, pred):
        hist += _fast_hist(t.flatten(), p.flatten(), num_classes)
    overall_acc = overall_pixel_accuracy(hist)
    avg_per_class_acc = per_class_pixel_accuracy(hist)
    avg_jacc = jaccard_index(hist)
    avg_dice = dice_coefficient(hist)
    return overall_acc, avg_per_class_acc, avg_jacc, avg_dice

class AverageMeter(object):
    """Computes and stores a running average of a scalar value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Records `val` averaged over `n` samples (e.g. a mini-batch)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
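
# A minimal end-to-end sketch, assuming random integer label maps; the
# shapes and num_classes below are illustrative only.
if __name__ == "__main__":
    num_classes = 3
    true = torch.randint(0, num_classes, (4, 32, 32))  # [B, H, W] ground truth
    pred = torch.randint(0, num_classes, (4, 32, 32))  # [B, H, W] predictions

    overall_acc, avg_per_class_acc, avg_jacc, avg_dice = eval_metrics(
        true, pred, num_classes
    )
    print(f"overall acc:   {overall_acc:.4f}")
    print(f"per-class acc: {avg_per_class_acc:.4f}")
    print(f"jaccard (IoU): {avg_jacc:.4f}")
    print(f"dice (F1):     {avg_dice:.4f}")

    # AverageMeter tracks a weighted running average, e.g. a loss over
    # mini-batches of different sizes.
    meter = AverageMeter()
    for batch_loss, batch_size in [(0.9, 16), (0.7, 16), (0.5, 8)]:
        meter.update(batch_loss, n=batch_size)
    print(f"avg loss: {meter.avg:.4f}")  # (0.9*16 + 0.7*16 + 0.5*8) / 40 = 0.74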