-
Notifications
You must be signed in to change notification settings - Fork 4
/
evaluate.py
127 lines (109 loc) · 5.76 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# TensorBoxPy3 https://github.com/SMH17/TensorBoxPy3
#This script evaluates the model and outputs the results, marking the detected targets.
#Run it only after train.py has completed enough training iterations to reduce the error,
#so that meaningful weights are present in the output folder generated by train.py.
#e.g. python3 evaluate.py --weights output/inception_rezoom_2017_01_17_15.20/save.ckpt-130000 --test_boxes data/images/
#will use the train checkpoints saved by train.py in file save.ckpt-130000 located in
#the output folder inception_rezoom_2017_01_17_15.20 after 130000 iterations.
#Note that the predictions come in a 15x20 grid. For each cell in the grid
#you get a confidence in [0,1] and 4 numbers corresponding to center_y, center_x, height, and width.
import tensorflow as tf
import os
import json
import subprocess
from scipy.misc import imread, imresize
from scipy import misc
from train import build_forward
from utils.annolist import AnnotationLib as al
from utils.train_utils import add_rectangles, rescale_boxes
import cv2
import argparse
# Banner shown whenever this module is loaded (including on import).
_BANNER = "# TensorBoxPy3: evaluating result"
print(_BANNER)
def get_image_dir(args):
    """Return the directory where the annotated output images are written.

    The directory sits next to the checkpoint. Its name combines the test
    annotation file name (without extension), the checkpoint iteration and
    the optional experiment name, e.g. ``output/run/images_val_boxes_130000_exp``.

    Args:
        args: parsed command-line arguments. Reads ``weights`` (a checkpoint
            path ending in ``-<iteration>``), ``test_boxes`` and ``expname``.

    Returns:
        The output image directory path as a string.

    Raises:
        ValueError: if the text after the last '-' in ``args.weights`` is not
            an integer.
    """
    weights_iteration = int(args.weights.split('-')[-1])
    expname = '_' + args.expname if args.expname else ''
    # Strip the actual extension rather than assuming a 5-character '.json'
    # suffix, so e.g. 'val.idl' yields 'val' instead of 'va'.
    test_name = os.path.splitext(os.path.basename(args.test_boxes))[0]
    return '%s/images_%s_%d%s' % (os.path.dirname(args.weights), test_name, weights_iteration, expname)
def _build_prediction_graph(H):
    """Build the single-image inference graph described by hypes ``H``.

    Resets the default TensorFlow graph first.

    Returns:
        ``(x_in, pred_boxes, pred_confidences)``: the input placeholder of
        shape ``[image_height, image_width, 3]`` and the box and confidence
        tensors to evaluate for it.
    """
    tf.reset_default_graph()
    x_in = tf.placeholder(tf.float32, name='x_in', shape=[H['image_height'], H['image_width'], 3])
    # build_forward is fed a batch, so add a leading batch axis of size 1.
    batch = tf.expand_dims(x_in, 0)
    if H['use_rezoom']:
        pred_boxes, _, _, pred_confs_deltas, pred_boxes_deltas = build_forward(H, batch, 'test', reuse=None)
        grid_area = H['grid_height'] * H['grid_width']
        # With rezoom, confidences are a 2-way softmax over pred_confs_deltas,
        # reshaped to [grid_area, rnn_len, 2].
        pred_confidences = tf.reshape(
            tf.nn.softmax(tf.reshape(pred_confs_deltas, [grid_area * H['rnn_len'], 2])),
            [grid_area, H['rnn_len'], 2])
        if H['reregress']:
            # Apply the rezoom layer's box corrections.
            pred_boxes = pred_boxes + pred_boxes_deltas
    else:
        pred_boxes, _, pred_confidences = build_forward(H, batch, 'test', reuse=None)
    return x_in, pred_boxes, pred_confidences


def get_results(args, H):
    """Run the checkpointed model over every image listed in ``args.test_boxes``.

    For each annotated image, the function:
      * resizes it to the network input size and runs the detector;
      * saves the image returned by ``add_rectangles`` under ``get_image_dir(args)``;
      * records the predicted rectangles, rescaled with ``rescale_boxes``.

    Args:
        args: parsed command-line arguments. Reads ``weights``, ``test_boxes``,
            ``expname``, ``min_conf``, ``tau`` and ``show_suppressed``.
        H: hyperparameter dict loaded from the checkpoint's ``hypes.json``.

    Returns:
        ``(pred_annolist, true_annolist)``. Images that cannot be found on disk
        are skipped with a warning. They get no entry in ``pred_annolist`` but
        remain in ``true_annolist``.
    """
    x_in, pred_boxes, pred_confidences = _build_prediction_graph(H)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, args.weights)
        pred_annolist = al.AnnoList()
        true_annolist = al.parse(args.test_boxes)
        # Image names in the annotation file are resolved relative to its directory.
        data_dir = os.path.dirname(args.test_boxes)
        image_dir = get_image_dir(args)
        os.makedirs(image_dir, exist_ok=True)
        print('Outputs will be stored in {}'.format(image_dir))
        for i in range(len(true_annolist)):
            true_anno = true_annolist[i]
            image_path = '%s/%s' % (data_dir, true_anno.imageName)
            try:
                # Keep only the first three channels (e.g. drop an alpha channel).
                orig_img = imread(image_path)[:, :, :3]
            except FileNotFoundError:
                # Best effort: skip missing images, but report them instead of
                # silently dropping them from the predictions.
                print('Warning: image not found, skipping: {}'.format(image_path))
            else:
                img = imresize(orig_img, (H["image_height"], H["image_width"]), interp='cubic')
                feed = {x_in: img}
                np_pred_boxes, np_pred_confidences = sess.run([pred_boxes, pred_confidences], feed_dict=feed)
                pred_anno = al.Annotation()
                pred_anno.imageName = true_anno.imageName
                new_img, rects = add_rectangles(H, [img], np_pred_confidences, np_pred_boxes,
                                                use_stitching=True, rnn_len=H['rnn_len'], min_conf=args.min_conf,
                                                tau=args.tau, show_suppressed=args.show_suppressed)
                # Drop degenerate rectangles with zero or negative width/height.
                pred_anno.rects = [r for r in rects if r.x1 < r.x2 and r.y1 < r.y2]
                pred_anno.imagePath = os.path.abspath(data_dir)
                # rescale_boxes receives the network input size and the original
                # image size; presumably it maps rects back to original-image
                # coordinates (TODO confirm against utils.train_utils).
                pred_anno = rescale_boxes((H["image_height"], H["image_width"]), pred_anno,
                                          orig_img.shape[0], orig_img.shape[1], test=True)
                pred_annolist.append(pred_anno)
                imname = '%s/%s' % (image_dir, os.path.basename(true_anno.imageName))
                misc.imsave(imname, new_img)
            if i % 25 == 0:
                print(i)
    return pred_annolist, true_annolist
def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

    argparse's ``type=bool`` treats every non-empty string (including 'False')
    as True, so a dedicated converter is needed.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean.
    """
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def _run_rpc_evaluation(args, true_boxes, pred_boxes):
    """Compute and plot a recall-precision curve with the annolist helper scripts.

    Runs ``./utils/annolist/doRPC.py`` on the saved ground-truth and predicted
    annotation files, then ``./utils/annolist/plotSimple.py`` to write
    ``results.png`` into ``get_image_dir(args)``.

    This is best effort: failures (the scripts missing or not executable, a
    non-zero exit status, or empty output) are printed rather than raised.
    """
    try:
        rpc_cmd = ['./utils/annolist/doRPC.py', '--minOverlap', '%f' % args.iou_threshold,
                   true_boxes, pred_boxes]
        print('$ %s' % ' '.join(rpc_cmd))
        # Decode output as text; on Python 3 check_output otherwise returns bytes.
        rpc_output = subprocess.check_output(rpc_cmd, universal_newlines=True)
        print(rpc_output)
        # The last non-empty line of doRPC's output is used as the RPC data file path.
        txt_file = [line for line in rpc_output.split('\n') if line.strip()][-1].strip()
        output_png = '%s/results.png' % get_image_dir(args)
        plot_cmd = ['./utils/annolist/plotSimple.py', txt_file, '--output', output_png]
        print('$ %s' % ' '.join(plot_cmd))
        plot_output = subprocess.check_output(plot_cmd, universal_newlines=True)
        print('output results at: %s' % plot_output)
    except (subprocess.CalledProcessError, OSError, IndexError) as e:
        print(e)


def main():
    """Command-line entry point: evaluate a checkpoint on a set of test boxes.

    Loads ``hypes.json`` from the checkpoint directory and runs ``get_results``.
    The predicted and ground-truth annotation lists are saved next to the
    checkpoint as ``<weights>.<expname_><test file>`` and
    ``<weights>.gt_<expname_><test file>``. With ``--rpc`` it additionally runs
    the annolist recall-precision scripts on the two saved files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', required=True)
    parser.add_argument('--expname', default='')
    parser.add_argument('--test_boxes', required=True)
    parser.add_argument('--gpu', default=0)
    parser.add_argument('--logdir', default='output')
    parser.add_argument('--iou_threshold', default=0.5, type=float)
    parser.add_argument('--tau', default=0.25, type=float)
    parser.add_argument('--min_conf', default=0.2, type=float)
    parser.add_argument('--show_suppressed', default=False, type=_str2bool,
                        help='draw suppressed boxes too (true/false)')
    parser.add_argument('--rpc', action='store_true',
                        help='also run utils/annolist/doRPC.py and plotSimple.py '
                             'to compute and plot a recall-precision curve')
    args = parser.parse_args()
    # Select the GPU before get_results creates the TensorFlow session.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    hypes_file = '%s/hypes.json' % os.path.dirname(args.weights)
    with open(hypes_file, 'r') as f:
        H = json.load(f)
    expname = args.expname + '_' if args.expname else ''
    pred_boxes = '%s.%s%s' % (args.weights, expname, os.path.basename(args.test_boxes))
    true_boxes = '%s.gt_%s%s' % (args.weights, expname, os.path.basename(args.test_boxes))
    pred_annolist, true_annolist = get_results(args, H)
    pred_annolist.save(pred_boxes)
    true_annolist.save(true_boxes)
    if args.rpc:
        _run_rpc_evaluation(args, true_boxes, pred_boxes)
if "__main__" == __name__:
    # Run the evaluation only when executed as a script, not when imported.
    main()