Skip to content

Commit

Permalink
fix predict bug
Browse files Browse the repository at this point in the history
  • Loading branch information
WongGawa committed Jan 10, 2025
1 parent bc7509f commit e195db4
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions demo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def get_parser_infer(parents=None):
parser.add_argument(
"--single_cls", type=ast.literal_eval, default=False, help="train multi-class data as single-class"
)
parser.add_argument("--exec_nms", type=ast.literal_eval, default=True, help="whether to execute NMS or not")
parser.add_argument("--nms_time_limit", type=float, default=60.0, help="time limit for NMS")
parser.add_argument("--conf_thres", type=float, default=0.25, help="object confidence threshold")
parser.add_argument("--iou_thres", type=float, default=0.65, help="IOU threshold for NMS")
Expand Down Expand Up @@ -94,6 +95,7 @@ def detect(
conf_thres: float = 0.25,
iou_thres: float = 0.65,
conf_free: bool = False,
exec_nms: bool = True,
nms_time_limit: float = 60.0,
img_size: int = 640,
stride: int = 32,
Expand Down Expand Up @@ -129,14 +131,15 @@ def detect(
# Run NMS
t = time.time()
out = out.asnumpy()
out = non_max_suppression(
out,
conf_thres=conf_thres,
iou_thres=iou_thres,
conf_free=conf_free,
multi_label=True,
time_limit=nms_time_limit,
)
if exec_nms:
out = non_max_suppression(
out,
conf_thres=conf_thres,
iou_thres=iou_thres,
conf_free=conf_free,
multi_label=True,
time_limit=nms_time_limit,
)
nms_times = time.time() - t

result_dict = {"category_id": [], "bbox": [], "score": []}
Expand Down Expand Up @@ -305,6 +308,7 @@ def infer(args):
conf_thres=args.conf_thres,
iou_thres=args.iou_thres,
conf_free=args.conf_free,
exec_nms=args.exec_nms,
nms_time_limit=args.nms_time_limit,
img_size=args.img_size,
stride=max(max(args.network.stride), 32),
Expand Down

0 comments on commit e195db4

Please sign in to comment.