diff --git a/halfsqueezenet/training/halfsqueezenet_objdetect.py b/halfsqueezenet/training/halfsqueezenet_objdetect.py index 4acfa62..684052c 100644 --- a/halfsqueezenet/training/halfsqueezenet_objdetect.py +++ b/halfsqueezenet/training/halfsqueezenet_objdetect.py @@ -477,15 +477,9 @@ def run_image(model, sess_init, image_dir): objdetect = outputs[0] bndboxes = outputs[1] - max_pred = -100 - max_h = -1 - max_w = -1 - for h in range(0, objdetect.shape[1]): - for w in range(0, objdetect.shape[2]): - if objdetect[0, h, w] > max_pred: - max_pred = objdetect[0, h, w] - max_h = h - max_w = w + max_pred = objdetect[0].max() + argmaxs = np.where(objdetect[0] == max_pred) + max_h, max_w, _ = argmaxs sum_labels= 0; bndbox = {} @@ -581,15 +575,10 @@ def run_single_image(model, sess_init, image): objdetect = outputs[0] bndboxes = outputs[1] - max_pred = -100 - max_h = -1 - max_w = -1 - for h in range(0, objdetect.shape[1]): - for w in range(0, objdetect.shape[2]): - if objdetect[0, h, w] > max_pred: - max_pred = objdetect[0, h, w] - max_h = h - max_w = w + max_pred = objdetect[0].max() + argmaxs = np.where(objdetect[0] == max_pred) + max_h, max_w, _ = argmaxs + bndbox2 = {} bndbox2['xmin'] = int( bndboxes[0,max_h,max_w,0] + max_w*grid_size) bndbox2['ymin'] = int( bndboxes[0,max_h,max_w,1] + max_h*grid_size) @@ -751,4 +740,4 @@ def dump_weights(meta, model, output): sys.exit() assert args.weights.endswith('.npy') - run_single_image(Model(), DictRestore(np.load(args.weights, encoding='latin1').item()), args.run) \ No newline at end of file + run_single_image(Model(), DictRestore(np.load(args.weights, encoding='latin1').item()), args.run)