-
Notifications
You must be signed in to change notification settings - Fork 2
/
metrics.py
60 lines (50 loc) · 1.89 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from sklearn.metrics import confusion_matrix
import numpy as np
class AverageMeter(object):
    """Tracks the latest value and the running average of a stream of numbers."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out all accumulated statistics."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running average.

        Args:
            val: The most recent measurement.
            n: How many samples `val` represents (default 1).
        """
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
class MetricComputer(object):
    """Computes a configurable set of classification metrics.

    Args:
        metric_names (list[str] | None): Names of the metrics to compute.
            Defaults to all registered metrics ("accuracy",
            "class_wise_acc"). Unknown names raise ValueError.
        n_classes (int | None): Number of target classes. Required by the
            per-class helpers (`_class_level_accuracy`, `_confusion_matrix`).
            Previously this attribute was never assigned (the line was
            commented out), so those methods always raised AttributeError.
    """

    def __init__(self, metric_names=None, n_classes=None):
        __all_metrics__ = {"accuracy": self._accuracy,
                           "class_wise_acc": self._class_level_accuracy}
        all_names = list(__all_metrics__.keys())
        if metric_names is None:
            metric_names = all_names
        for name in metric_names:
            # Explicit validation instead of `assert`, which is stripped
            # when Python runs with -O.
            if name not in all_names:
                raise ValueError("unknown metric name: %r" % (name,))
        self.metrics = {m: __all_metrics__[m] for m in metric_names}
        # Bug fix: previously `# self.n_classes = n_classes` was commented
        # out, leaving the per-class metrics broken.
        self.n_classes = n_classes

    def __call__(self, out, target):
        """Compute every configured metric.

        Args:
            out (torch.Tensor): Model output. NOTE(review): treated as hard
                label predictions directly — the argmax over logits is
                deliberately commented out, so callers must pass predicted
                labels, not raw logits; confirm against call sites.
            target (torch.Tensor): Target labels.

        Returns:
            dict: metric name -> metric value.
        """
        # pred = out.argmax(dim=1)  # intentionally disabled upstream
        pred = out
        return {m: fn(out, pred, target) for m, fn in self.metrics.items()}

    def _accuracy(self, out, pred, target):
        """Overall accuracy (fraction of `pred == target`) as a float."""
        acc = (pred == target).float().detach().mean()
        return acc.item()

    def _class_level_accuracy(self, out, pred, target):
        """Per-class correct/total counts.

        Returns:
            dict: class index -> (num_correct, num_present). Classes that
            do not appear in `target` are omitted.

        Raises:
            ValueError: If `n_classes` was not provided at construction.
        """
        if self.n_classes is None:
            raise ValueError("n_classes must be set to compute class_wise_acc")
        per_class_acc = {}
        for c in range(self.n_classes):
            count = (target == c).sum().detach().item()
            if count == 0:
                continue
            class_true = ((pred == target) * (target == c)).float().sum().item()
            per_class_acc[c] = (class_true, count)
        return per_class_acc

    def _confusion_matrix(self, out, pred, target):
        """Confusion matrix over all `n_classes` labels via sklearn.

        Raises:
            ValueError: If `n_classes` was not provided at construction.
        """
        if self.n_classes is None:
            raise ValueError("n_classes must be set to compute the confusion matrix")
        y_true = target.detach().cpu()
        y_pred = pred.detach().cpu()
        return confusion_matrix(y_true, y_pred, normalize=None,
                                labels=np.arange(self.n_classes))