diff --git a/hawp/detector.py b/hawp/detector.py index 4f633ee..55fa2a4 100644 --- a/hawp/detector.py +++ b/hawp/detector.py @@ -120,14 +120,18 @@ def forward(self, images, annotations = None): #jloc_pred_nms = non_maximum_suppression(jloc_pred[0]) #topK = torch.clamp((jloc_pred_nms > 0.008).count_nonzero(), max=300) - juncs_pred, _ = get_junctions(non_maximum_suppression(jloc_pred[0]),joff_pred[0], topk=300, th=0.008) + nms_jloc_pred = non_maximum_suppression(jloc_pred)[0] + juncs_pred, _ = get_junctions(nms_jloc_pred,joff_pred[0], topk=300, th=0.008) extra_info['time_proposal'] = time.time() - extra_info['time_proposal'] extra_info['time_matching'] = time.time() dis_junc_to_end1, idx_junc_to_end1 = torch.sum((lines_pred[:,:2]-juncs_pred[:,None])**2,dim=-1).min(0) dis_junc_to_end2, idx_junc_to_end2 = torch.sum((lines_pred[:,2:] - juncs_pred[:, None]) ** 2, dim=-1).min(0) - idx_junc_to_end_min = torch.min(idx_junc_to_end1,idx_junc_to_end2) - idx_junc_to_end_max = torch.max(idx_junc_to_end1,idx_junc_to_end2) + # idx_junc_to_end_min = torch.min(idx_junc_to_end1,idx_junc_to_end2) + # idx_junc_to_end_max = torch.max(idx_junc_to_end1,idx_junc_to_end2) + idx_junc_to_end_stacked = torch.stack((idx_junc_to_end1, idx_junc_to_end2)) + idx_junc_to_end_min = idx_junc_to_end_stacked.min(dim=0)[0] + idx_junc_to_end_max = idx_junc_to_end_stacked.max(dim=0)[0] iskeep = (idx_junc_to_end_min < idx_junc_to_end_max)# * (dis_junc_to_end1< 10*10)*(dis_junc_to_end2<10*10) # *(dis_junc_to_end2<100)