-
Notifications
You must be signed in to change notification settings - Fork 7
/
test.py
76 lines (63 loc) · 2.74 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import argparse
from utils.util import *
from trainers.eval import meta_test
def test_parser(argv=None):
    """Build the evaluation command-line parser and parse arguments.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets tests or other scripts reuse this parser.

    Returns:
        argparse.Namespace holding the parsed options. ``--model`` and
        ``--dataset`` declare no default, so they are None when omitted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_way", type=int, default=5)
    parser.add_argument("--train_shot", type=int, default=1)
    parser.add_argument("--train_query_shot", type=int, default=1)
    parser.add_argument("--gpu_num", help="gpu device", type=int, default=1)
    parser.add_argument("--resnet", action="store_true")
    parser.add_argument("--model", choices=['Proto', 'FRN'])
    parser.add_argument("--dataset", choices=['cub_cropped', 'cub_raw',
                                              'aircraft',
                                              'meta_iNat', 'tiered_meta_iNat',
                                              'stanford_car', 'stanford_dog'])
    parser.add_argument("--TDM", action="store_true")
    args = parser.parse_args(argv)
    return args


# The ``test_`` prefix matches pytest's default test-function pattern; opt out
# so a test module that imports this name does not collect and run it (it
# would try to parse pytest's own command line and exit).
test_parser.__test__ = False
args = test_parser()
# Explicit check rather than `assert`: assertions are stripped when Python
# runs with -O, which would silently skip this validation.
if args.gpu_num <= 0:
    raise SystemExit("TDM is only tested with GPU setting")

# Evaluation images are read from the 'test_pre' subdirectory of the dataset
# root returned by dataset_path().
test_path = os.path.join(dataset_path(args), 'test_pre')

save_path = get_save_path(args)
args.save_path = save_path
logger_path = os.path.join(args.save_path, 'test.log')

# If test.log already exists, read its lines *before* calling get_logger and
# re-emit them through the new logger afterwards.
# NOTE(review): this ordering suggests get_logger recreates/truncates the file,
# so replaying preserves earlier results — confirm in utils.util.get_logger.
previous_lines = []
if os.path.isfile(logger_path):
    with open(logger_path, 'r') as log_file:
        previous_lines = log_file.read().splitlines()
logger = get_logger(logger_path)
for line in previous_lines:
    # Drops the first 17 characters of each old line — presumably the
    # timestamp prefix added by get_logger's formatter, so replayed lines are
    # not double-stamped. TODO confirm the prefix width.
    logger.info(line[17:])
# Load the pretrained network, switch to inference mode and move it to GPU.
model = load_pretrained_model(args)
model.eval()
model.cuda()
# Spread batches across devices 0..gpu_num-1 when more than one GPU is used.
if args.gpu_num > 1:
    model = torch.nn.DataParallel(model, device_ids=list(range(args.gpu_num)))

# Fixed evaluation settings forwarded to meta_test for every run.
pre = True
transform_type = None

# FRN is evaluated at both 1-shot and 5-shot; any other model is evaluated
# only at the shot count given by --train_shot.
shots = [1, 5] if args.model == 'FRN' else [args.train_shot]

with torch.no_grad():
    for shot in shots:
        mean, interval = meta_test(data_path=test_path,
                                   model=model,
                                   way=args.train_way,
                                   shot=shot,
                                   pre=pre,
                                   transform_type=transform_type,
                                   gpu_num=args.gpu_num)
        logger.info('%d-way-%d-shot acc: %.3f\t%.3f' % (args.train_way, shot, mean, interval))