-
Notifications
You must be signed in to change notification settings - Fork 3
/
eval.py
101 lines (80 loc) · 3.64 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import logging
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset
import importlib
from model.metrics import *
from model.label_shift_est import LSC
from model.model_init import model_init
from utilis.test_ft import test_ft
from dataloader.Custom_Dataloader import FeatureDataset
# Mapping from model-type keys to the module path implementing each
# backbone variant (CIFAR vs. full-resolution, with/without tail ensemble).
model_paths = dict(
    stage2='networks.stage_2',
    none_cifar='networks.resnet_cifar',
    none='networks.resnet',
    tail_cifar='networks.resnet_cifar_ensemble',
    tail='networks.resnet_ensemble',
)
def get_metrics(probs, labels, cls_num_list):
    """Compute top-1 accuracy and many/medium/few-shot group accuracies.

    Args:
        probs: per-sample class probabilities, shape (num_samples,
            num_classes), as consumed by acc_cal / mmf_acc_cal.
        labels: iterable of 0-dim label tensors (e.g. a concatenated 1-D
            tensor of per-sample labels).
        cls_num_list: per-class training image counts, used to split classes
            into many/medium/few-shot groups.

    Returns:
        Tuple (top-1 accuracy, [many, medium, few] accuracy list).
    """
    # Move labels to host memory and unwrap to plain Python ints,
    # which is what the metric helpers expect.
    labels = [tensor.cpu().item() for tensor in labels]
    acc = acc_cal(probs, labels, method='top1')
    mmf_acc = list(mmf_acc_cal(probs, labels, cls_num_list))
    # Lazy %-args: the string is only built if INFO logging is enabled.
    logging.info('Many Medium Few shot Top1 Acc: %s', mmf_acc)
    print('Many Medium Few shot Top1 Acc: ' + str(mmf_acc))
    return acc, mmf_acc
# Read from main.py directly: test_set, dset_info, dataset_info, args
def evaluation(test_set, dset_info, dataset_info, args, cfg):
    """Evaluate a trained checkpoint on test_set, with and without
    label-shift post-compensation (LSC).

    Args:
        test_set: iterable/DataLoader yielding (image_batch, label_batch).
        dset_info: dict; 'per_class_img_num' holds per-class train counts.
        dataset_info: dict; 'path' is the dataset location on disk.
        args: parsed CLI args; args.eval is the checkpoint path to load.
        cfg: experiment config (accepted for interface parity; unused here).

    Returns:
        Tuple of (avg test loss, top-1 accuracy, label-shift-corrected
        accuracy, [many, medium, few] acc list, post-compensated
        [many, medium, few] acc list).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load the model checkpoint to evaluate.
    model = test_ft(datapath=dataset_info["path"],
                    args=args,
                    modelpath=args.eval,
                    crt_modelpath=None, test_cfg=None)
    model.to(device)
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    #### --------------- evaluate ---------------
    num_batches = len(test_set)
    test_loss, correct, total = 0, 0, 0
    probs, labels = [], []
    # No gradients are needed at evaluation time.
    with torch.no_grad():
        for img, label in test_set:
            img = img.to(device)
            label = label.to(device)
            labels.append(label)
            # Forward pass and loss for this batch.
            logit = model(img)
            test_loss += loss_fn(logit, label).item()
            prob = F.softmax(logit, dim=1)
            # BUGFIX: the original called prob.squeeze(), which collapses the
            # batch dimension when a batch holds a single sample and corrupts
            # the accumulated probs list. prob is already (batch, classes).
            probs.extend(list(prob.cpu().numpy()))
            # Predicted label = index of the maximum probability.
            pred = prob.argmax(dim=1)
            correct += (pred == label).type(torch.float).sum().item()
            total += label.size(0)
    # ----------------- Plain accuracy -----------------
    probs = np.array(probs)
    labels = torch.cat(labels)
    _, mmf_acc = get_metrics(probs, labels, dset_info['per_class_img_num'])
    # Gather data and report.
    test_loss /= num_batches
    accuracy = correct / total
    logging.info("Test Error: Accuracy: {:.2f}, Avg loss: {:.4f} ".format(100 * accuracy, test_loss))
    print("Test Error: Accuracy: {:.2f}, Avg loss: {:.4f} ".format(100 * accuracy, test_loss))
    # ----------------- Post Compensation (label shift) accuracy -----------------
    pc_probs = LSC(probs, cls_num_list=dset_info['per_class_img_num'])
    label_shift_acc, mmf_acc_pc = get_metrics(pc_probs, labels, dset_info['per_class_img_num'])
    # BUGFIX: the original repeated the plain-accuracy log line here instead
    # of reporting the label-shift result; also fixed the "Accracy" typo.
    logging.info("Label Shift Accuracy is: {}".format(label_shift_acc))
    print("Label Shift Accuracy is:", label_shift_acc)
    logging.info("\n")
    print("\n\n")
    return test_loss, accuracy, label_shift_acc, mmf_acc, mmf_acc_pc