diff --git a/hawp/detector.py b/hawp/detector.py
index c2d7d9c..de37b09 100644
--- a/hawp/detector.py
+++ b/hawp/detector.py
@@ -21,7 +21,10 @@ def get_junctions(jloc, joff, topk = 300, th=0):
     jloc = jloc.reshape(-1)
     joff = joff.reshape(2, -1)
 
-    scores, index = torch.topk(jloc, k=topk)
+    #scores, index = torch.topk(jloc, k=topk)
+    jloc_sorted, jloc_index = jloc.sort(descending=True)
+    scores = jloc_sorted[:topk]
+    index = jloc_index[:topk]
     y = (index // width).float() + torch.gather(joff[1], 0, index) + 0.5
     x = (index % width).float() + torch.gather(joff[0], 0, index) + 0.5
 
@@ -61,6 +64,7 @@ def __init__(self, cfg):
             nn.Linear(self.dim_fc, 1),
         )
         self.train_step = 0
+        self.export_mode = False
 
     def pooling(self, features_per_image, lines_per_im):
         h,w = features_per_image.size(1), features_per_image.size(2)
@@ -107,27 +111,24 @@ def forward(self, images, annotations = None):
         joff_pred= output[:,7:9].sigmoid() - 0.5
 
         extra_info['time_backbone'] = time.time() - extra_info['time_backbone']
-
-        batch_size = md_pred.size(0)
-        assert batch_size == 1
-
         extra_info['time_proposal'] = time.time()
         if self.use_residual:
             lines_pred = self.proposal_lines_new(md_pred[0],dis_pred[0],res_pred[0]).view(-1,4)
         else:
             lines_pred = self.proposal_lines_new(md_pred[0], dis_pred[0], None).view(-1, 4)
 
-        jloc_pred_nms = non_maximum_suppression(jloc_pred[0])
-        topK = min(300, int((jloc_pred_nms>0.008).float().sum().item()))
-
-        juncs_pred, _ = get_junctions(non_maximum_suppression(jloc_pred[0]),joff_pred[0], topk=topK)
+        nms_jloc_pred = non_maximum_suppression(jloc_pred)[0]
+        juncs_pred, _ = get_junctions(nms_jloc_pred, joff_pred[0], topk=300, th=0.008)
         extra_info['time_proposal'] = time.time() - extra_info['time_proposal']
         extra_info['time_matching'] = time.time()
         dis_junc_to_end1, idx_junc_to_end1 = torch.sum((lines_pred[:,:2]-juncs_pred[:,None])**2,dim=-1).min(0)
         dis_junc_to_end2, idx_junc_to_end2 = torch.sum((lines_pred[:,2:] - juncs_pred[:, None]) ** 2, dim=-1).min(0)
 
-        idx_junc_to_end_min = torch.min(idx_junc_to_end1,idx_junc_to_end2)
-        idx_junc_to_end_max = torch.max(idx_junc_to_end1,idx_junc_to_end2)
+        # idx_junc_to_end_min = torch.min(idx_junc_to_end1,idx_junc_to_end2)
+        # idx_junc_to_end_max = torch.max(idx_junc_to_end1,idx_junc_to_end2)
+        idx_junc_to_end_stacked = torch.stack((idx_junc_to_end1, idx_junc_to_end2))
+        idx_junc_to_end_min = idx_junc_to_end_stacked.min(dim=0)[0]
+        idx_junc_to_end_max = idx_junc_to_end_stacked.max(dim=0)[0]
 
         iskeep = (idx_junc_to_end_min < idx_junc_to_end_max)# * (dis_junc_to_end1< 10*10)*(dis_junc_to_end2<10*10) # *(dis_junc_to_end2<100)
 
@@ -155,10 +156,13 @@ def forward(self, images, annotations = None):
         v2e_idx1 = v2e_dis1.argmin(dim=0)
         v2e_idx2 = v2e_dis2.argmin(dim=0)
         edge_indices = torch.stack((v2e_idx1,v2e_idx2)).t()
-        wireframe = WireframeGraph(juncs_final,juncs_score,edge_indices,score_final,output.size(3),output.size(2))
-        wireframe.rescale(annotations[0]['width'],annotations[0]['height'])
-
+        if not self.export_mode:
+            wireframe = WireframeGraph(juncs_final,juncs_score,edge_indices,score_final,output.size(3),output.size(2))
+            wireframe.rescale(annotations[0]['width'],annotations[0]['height'])
+        else:
+            return juncs_final,juncs_score,edge_indices,score_final,output.size(3),output.size(2)
+
         extra_info['time_verification'] = time.time() - extra_info['time_verification']
 
         return wireframe, extra_info
@@ -189,7 +193,7 @@ def proposal_lines(self, md_maps, dis_maps, scale=5.0):
         cs_ed = torch.cos(ed_).clamp(min=1e-3)
         ss_ed = torch.sin(ed_).clamp(max=-1e-3)
 
-        x_standard = torch.ones_like(cs_st)
+        #x_standard = torch.ones_like(cs_st)
 
         y_st = ss_st/cs_st
         y_ed = ss_ed/cs_ed
diff --git a/hawp/export.py b/hawp/export.py
new file mode 100644
index 0000000..361d68f
--- /dev/null
+++ b/hawp/export.py
@@ -0,0 +1,77 @@
+import argparse
+
+import numpy as np
+import onnx
+import onnxruntime as ort
+import torch.onnx
+
+from .config import cfg
+from .predicting import WireframeParser
+
+
+def cli():
+    parser = argparse.ArgumentParser(
+        prog="python -m hawp.export",
+        usage="%(prog)s [options] image",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument("image", help="input image")
+    parser.add_argument(
+        "-o", "--output", default="hawp.onnx", nargs="?", help="Path at which to write the exported model"
+    )
+
+    return parser.parse_args()
+
+
+def to_numpy(tensor):
+    if isinstance(tensor, int):
+        return tensor
+    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+
+def check_model(model_file):
+    model = onnx.load(model_file)
+    onnx.checker.check_model(model)
+
+
+def verify_model(model_file, input, output):
+    ort_session = ort.InferenceSession(model_file, providers=["CUDAExecutionProvider"])
+    ort_inputs = {"image": to_numpy(input)}
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    # compare ONNX Runtime and PyTorch results
+    for obs, exp in zip(ort_outs, output):
+        np.testing.assert_allclose(to_numpy(exp), obs, rtol=1e-03, atol=1e-05)
+
+
+def export(input_file, output_file):
+    wireframe_parser = WireframeParser(export_mode=True)
+    for _ in wireframe_parser.images([input_file]):
+        pass
+
+    model = wireframe_parser.model
+    [input] = wireframe_parser.inputs
+    [output] = wireframe_parser.outputs
+    output_names = [
+        "vertices",
+        "v_confidences",
+        "edges",
+        "edge_weights",
+        "frame_width",
+        "frame_height",
+    ]
+    torch.onnx.export(
+        model, input, output_file, opset_version=11, input_names=["image"], output_names=output_names
+    )
+
+    return input, output
+
+
+if __name__ == "__main__":
+    cfg.freeze()
+    args = cli()
+
+    input, output = export(args.image, args.output)
+    check_model(args.output)
+    img, _ = input
+    verify_model(args.output, img, output)
diff --git a/hawp/predicting.py b/hawp/predicting.py
index 6f6684e..edb087f 100644
--- a/hawp/predicting.py
+++ b/hawp/predicting.py
@@ -2,6 +2,7 @@
 import torch
 from hawp import show
 from hawp.config import cfg
+from hawp.graph import WireframeGraph
 from hawp.utils.comm import to_device
 from hawp.dataset.build import build_transform
 
@@ -18,14 +19,18 @@ class WireframeParser(object):
     loader_workers = None
     device = 'cuda'
 
-    def __init__(self, json_data = False,
+    def __init__(self, json_data = False,
                  visualize_image = False,
-                 visualize_processed_image = False):
+                 visualize_processed_image = False,
+                 export_mode = False):
         self.model = get_hawp_model(pretrained=True).eval()
         self.model = self.model.to(self.device)
+        self.model.export_mode = export_mode
         self.preprocessor_transform = build_transform(cfg)
         self.visualize_image = visualize_image
         self.visualize_processed_image = visualize_processed_image
+        self.inputs = []
+        self.outputs = []
 
     def dataset(self, data):
         loader_workers = self.loader_workers
@@ -51,7 +56,13 @@ def dataloader(self, dataloader):
             visualizer.Base.image(image_batch[0])
         processed_image_batch = processed_image_batch.to(self.device)
         with torch.no_grad():
-            wireframe, _ = self.model(processed_image_batch, meta_batch)
+            if self.model.export_mode:
+                self.inputs.append((processed_image_batch, meta_batch))
+                output = self.model(processed_image_batch, meta_batch)
+                self.outputs.append(output)
+                wireframe = WireframeGraph(*output)
+            else:
+                wireframe, _ = self.model(processed_image_batch, meta_batch)
         yield wireframe, gt_anns_batch[0], meta_batch[0]
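
For reference, a minimal sketch of how the exported graph might be consumed from ONNX Runtime alone, once this patch is applied and the model has been exported with `python -m hawp.export image.jpg -o hawp.onnx`. Since `torch.onnx.export` is called without `dynamic_axes`, the input shape is frozen at trace time, so the image fed here must match the traced resolution; the resize-and-scale preprocessing, the 512x512 size, and the file names below are placeholder assumptions standing in for the repo's `build_transform(cfg)` pipeline, not the patch's own code.

# Standalone consumer sketch (assumptions: model traced on a 1x3xHxW float
# input named "image"; size below must match the resolution used at export).
import numpy as np
import onnxruntime as ort
from PIL import Image

def run_hawp_onnx(model_file="hawp.onnx", image_file="image.jpg", size=512):
    session = ort.InferenceSession(model_file, providers=["CPUExecutionProvider"])
    img = Image.open(image_file).convert("RGB").resize((size, size))
    # HWC uint8 -> 1x3xHxW float32; the real pipeline uses build_transform(cfg).
    x = np.asarray(img, dtype=np.float32).transpose(2, 0, 1)[None] / 255.0
    # Outputs arrive in the order given to output_names in hawp/export.py.
    vertices, v_conf, edges, edge_weights, width, height = session.run(
        None, {"image": x}
    )
    # Each row of `edges` holds two indices into `vertices`;
    # `edge_weights` scores the corresponding line segments.
    return vertices, v_conf, edges, edge_weights

vertices, v_conf, edges, edge_weights = run_hawp_onnx()
print(len(vertices), "junctions,", len(edges), "candidate segments")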