forked from mahyarnajibi/SNIPER
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_test.py
69 lines (56 loc) · 2.89 KB
/
main_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# --------------------------------------------------------------
# SNIPER: Efficient Multi-Scale Training
# Licensed under The Apache-2.0 License [see LICENSE for details]
# Inference Module
# by Mahyar Najibi and Bharat Singh
# --------------------------------------------------------------
import init
import matplotlib
matplotlib.use('Agg')
from symbols.faster import *
from configs.faster.default_configs import config, update_config, update_config_from_list
from data_utils.load_data import load_proposal_roidb
import mxnet as mx
import argparse
from train_utils.utils import create_logger, load_param
from inference import imdb_detection_wrapper
from inference import imdb_proposal_extraction_wrapper
import os
def parser(argv=None):
    """Build the SNIPER test-module argument parser and parse the command line.

    :param argv: optional list of argument strings. When ``None`` (the
        default) argparse falls back to ``sys.argv[1:]``, which is what the
        script uses when run from the command line.
    :return: ``argparse.Namespace`` with the attributes ``cfg``,
        ``save_prefix``, ``vis`` and ``set_cfg_list``.
    """
    # The first positional argument of ArgumentParser is ``prog``; the text
    # below is a description, so pass it by keyword instead of letting it
    # replace the program name in the usage line.
    arg_parser = argparse.ArgumentParser(description='SNIPER test module')
    arg_parser.add_argument('--cfg', dest='cfg', help='Path to the config file',
                            default='configs/faster/sniper_res101_e2e.yml', type=str)
    arg_parser.add_argument('--save_prefix', dest='save_prefix',
                            help='Prefix used for snapshotting the network',
                            default='SNIPER', type=str)
    arg_parser.add_argument('--vis', dest='vis', help='Whether to visualize the detections',
                            action='store_true')
    # argparse.REMAINDER collects every token after --set verbatim, so --set
    # has to be the last option on the command line. The list is handed to
    # update_config_from_list (presumably alternating key/value pairs).
    arg_parser.add_argument('--set', dest='set_cfg_list',
                            help='Set the configuration fields from command line',
                            default=None, nargs=argparse.REMAINDER)
    return arg_parser.parse_args(argv)
def main():
    """Run SNIPER inference on the configured test image set.

    Steps:
      1. Parse the command line, load the YAML config given by ``--cfg`` and
         apply any ``--set`` overrides.
      2. Build one MXNet GPU context per id listed in ``config.gpus``.
      3. Make sure ``config.output_path`` exists.
      4. Load the test roidb/imdb.
      5. Load the snapshot ``<output_path>/<save_prefix>`` at epoch
         ``config.TEST.TEST_EPOCH``.
      6. Dispatch to proposal extraction when
         ``config.TEST.EXTRACT_PROPOSALS`` is set, otherwise to detection.

    :raises ValueError: if ``config.symbol`` does not name an object
        available in this module's namespace.
    """
    args = parser()
    update_config(args.cfg)
    if args.set_cfg_list:
        update_config_from_list(args.set_cfg_list)

    # One GPU context per entry of the comma-separated ``config.gpus`` string.
    context = [mx.gpu(int(gpu)) for gpu in config.gpus.split(',')]

    # makedirs (unlike mkdir) also creates any missing parent directories.
    # The isdir guard is kept instead of exist_ok=True so the script still
    # runs on Python 2.
    if not os.path.isdir(config.output_path):
        os.makedirs(config.output_path)

    # Create roidb
    roidb, imdb = load_proposal_roidb(config.dataset.dataset, config.dataset.test_image_set,
                                      config.dataset.root_path, config.dataset.dataset_path,
                                      proposal=config.dataset.proposal, only_gt=True, flip=False,
                                      result_path=config.output_path,
                                      proposal_path=config.proposal_path, get_imdb=True)

    # Creating the Logger. The directory is derived from
    # config.dataset.image_set (presumably the training set), not from
    # test_image_set: the weights loaded below are read from it.
    logger, output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)
    print(output_path)

    model_prefix = os.path.join(output_path, args.save_prefix)
    arg_params, aux_params = load_param(model_prefix, config.TEST.TEST_EPOCH,
                                        convert=True, process=True)

    # Resolve ``<symbol>.<symbol>``. The name is expected to have been
    # brought into this namespace by ``from symbols.faster import *``. An
    # explicit lookup is used instead of eval() so that a config or --set
    # value is never executed as Python code.
    symbol_owner = globals().get(config.symbol)
    if symbol_owner is None:
        raise ValueError('Unknown network symbol "{}" in config'.format(config.symbol))
    sym_inst = getattr(symbol_owner, config.symbol)

    if config.TEST.EXTRACT_PROPOSALS:
        imdb_proposal_extraction_wrapper(sym_inst, config, imdb, roidb, context,
                                         arg_params, aux_params, args.vis)
    else:
        imdb_detection_wrapper(sym_inst, config, imdb, roidb, context,
                               arg_params, aux_params, args.vis)
if __name__ == "__main__":
    # Launch inference only when executed as a script, never on import.
    main()