-
Notifications
You must be signed in to change notification settings - Fork 43
/
predict.py
52 lines (45 loc) · 1.32 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# -*- coding:UTF-8 -*-
"""
predict
@Cai Yichao 2020_09_23
"""
import torch
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from models.resnext import *
from PIL import Image
import time
import argparse
from utils.arg_utils import *
# Command-line interface: the image to classify is passed with --file / -f.
# `required=True` makes a missing path a clear argparse usage error instead of
# an opaque failure inside Image.open(None).
parser = argparse.ArgumentParser()
parser.add_argument('--file', '-f', type=str, required=True,
                    help='path of the image to classify')
parse_args = parser.parse_args()
# Project configuration (checkpoint directory, input size, ...) comes from
# utils.arg_utils.fetch_args(); keys used below: 'ckpt_path', 'input_size'.
args = fetch_args()

# Output index i of the network maps to classes[i].
# NOTE(review): order must match the label order used at training time — confirm.
classes = ['infrared', 'normal']
num_classes = len(classes)
# NOTE(review): the checkpoint file name is hard-coded.
ckpt_file = args['ckpt_path']+'/ckpt_3_acc100.00.pt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Preprocessing: resize, convert to a float tensor, normalize per channel.
# The mean/std triples have three entries, so the input must be 3-channel;
# the image is converted to RGB before this transform is applied.
transform = transforms.Compose([transforms.Resize(args['input_size']),
                                transforms.ToTensor(),
                                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

# Loading model.
net = resNeXt50_32x4d_SE(num_classes=num_classes)
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint)
# onto the device chosen above, so CPU-only machines can load it too.
ckpt = torch.load(ckpt_file, map_location=device)
net.to(device)
net.load_state_dict(ckpt['net'])
net.eval()

# Timing covers preprocessing + forward pass, not model loading.
start_time = time.time()
# The context manager releases the file handle. convert('RGB') makes grayscale,
# palette and RGBA images 3-channel so they match the Normalize statistics.
with Image.open(parse_args.file) as img:
    image = transform(img.convert('RGB'))
image = image.unsqueeze(0).to(device)  # add a batch dimension of size 1
with torch.no_grad():
    out = net(image)
# Index of the highest-scoring class for each image in the (size-1) batch.
predict = out.argmax(dim=1)
print(classes[predict[0].item()])
print("cost time: %.2f"%(time.time()-start_time))