From 2d2ae698cd724dcdb8557f717bcb1bf5195c4556 Mon Sep 17 00:00:00 2001 From: Christian Kanesan Date: Thu, 6 Jan 2022 16:41:06 +0100 Subject: [PATCH 1/4] Add export module --- hawp/detector.py | 27 ++++++++++++++----------- hawp/export.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++ hawp/predicting.py | 17 +++++++++++++--- 3 files changed, 79 insertions(+), 15 deletions(-) create mode 100644 hawp/export.py diff --git a/hawp/detector.py b/hawp/detector.py index c2d7d9c..4f633ee 100644 --- a/hawp/detector.py +++ b/hawp/detector.py @@ -12,7 +12,7 @@ } def non_maximum_suppression(a): - ap = F.max_pool2d(a, 3, stride=1, padding=1) + ap = F.max_pool2d(a, 3, stride=(1,1), padding=(1,1)) mask = (a == ap).float().clamp(min=0.0) return a * mask @@ -21,7 +21,10 @@ def get_junctions(jloc, joff, topk = 300, th=0): jloc = jloc.reshape(-1) joff = joff.reshape(2, -1) - scores, index = torch.topk(jloc, k=topk) + #scores, index = torch.topk(jloc, k=topk) + jloc_sorted, jloc_index = jloc.sort(descending=True) + scores = jloc_sorted[:topk] + index = jloc_index[:topk] y = (index // width).float() + torch.gather(joff[1], 0, index) + 0.5 x = (index % width).float() + torch.gather(joff[0], 0, index) + 0.5 @@ -61,6 +64,7 @@ def __init__(self, cfg): nn.Linear(self.dim_fc, 1), ) self.train_step = 0 + self.export_mode = False def pooling(self, features_per_image, lines_per_im): h,w = features_per_image.size(1), features_per_image.size(2) @@ -107,20 +111,16 @@ def forward(self, images, annotations = None): joff_pred= output[:,7:9].sigmoid() - 0.5 extra_info['time_backbone'] = time.time() - extra_info['time_backbone'] - - batch_size = md_pred.size(0) - assert batch_size == 1 - extra_info['time_proposal'] = time.time() if self.use_residual: lines_pred = self.proposal_lines_new(md_pred[0],dis_pred[0],res_pred[0]).view(-1,4) else: lines_pred = self.proposal_lines_new(md_pred[0], dis_pred[0], None).view(-1, 4) - jloc_pred_nms = non_maximum_suppression(jloc_pred[0]) - topK = min(300, int((jloc_pred_nms>0.008).float().sum().item())) + #jloc_pred_nms = non_maximum_suppression(jloc_pred[0]) + #topK = torch.clamp((jloc_pred_nms > 0.008).count_nonzero(), max=300) - juncs_pred, _ = get_junctions(non_maximum_suppression(jloc_pred[0]),joff_pred[0], topk=topK) + juncs_pred, _ = get_junctions(non_maximum_suppression(jloc_pred[0]),joff_pred[0], topk=300, th=0.008) extra_info['time_proposal'] = time.time() - extra_info['time_proposal'] extra_info['time_matching'] = time.time() dis_junc_to_end1, idx_junc_to_end1 = torch.sum((lines_pred[:,:2]-juncs_pred[:,None])**2,dim=-1).min(0) @@ -155,10 +155,13 @@ def forward(self, images, annotations = None): v2e_idx1 = v2e_dis1.argmin(dim=0) v2e_idx2 = v2e_dis2.argmin(dim=0) edge_indices = torch.stack((v2e_idx1,v2e_idx2)).t() - wireframe = WireframeGraph(juncs_final,juncs_score,edge_indices,score_final,output.size(3),output.size(2)) - wireframe.rescale(annotations[0]['width'],annotations[0]['height']) - + if not self.export_mode: + wireframe = WireframeGraph(juncs_final,juncs_score,edge_indices,score_final,output.size(3),output.size(2)) + wireframe.rescale(annotations[0]['width'],annotations[0]['height']) + else: + return juncs_final,juncs_score,edge_indices,score_final,output.size(3),output.size(2) + extra_info['time_verification'] = time.time() - extra_info['time_verification'] return wireframe, extra_info diff --git a/hawp/export.py b/hawp/export.py new file mode 100644 index 0000000..10bce3c --- /dev/null +++ b/hawp/export.py @@ -0,0 +1,50 @@ +import numpy as np +import onnx +import onnxruntime as ort +import torch.onnx + +from .config import cfg +from .predicting import WireframeParser + + +INPUT_FILE = "/home/ckanesan/Data/semantic-keypoints/826840/Image_000001.jpg" +OUTPUT_FILE = "hawp.onnx" + +def to_numpy(tensor): + return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() + +def check_model(): + model = onnx.load(OUTPUT_FILE) + onnx.checker.check_model(model) + + inferred_model = onnx.shape_inference.infer_shapes(model) + onnx.save(inferred_model, "hawp_inf.onnx") + +def verify_model(input, output): + ort_session = ort.InferenceSession(OUTPUT_FILE, providers=["CUDAExecutionProvider"]) + ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input)} + ort_outs = ort_session.run(None, ort_inputs) + + # compare ONNX Runtime and PyTorch results + np.testing.assert_allclose(to_numpy(output), ort_outs[0], rtol=1e-03, atol=1e-05) + + +def export(cfg): + wireframe_parser = WireframeParser(export_mode=True) + for _ in wireframe_parser.images([INPUT_FILE]): + pass + + model = wireframe_parser.model + [input] = wireframe_parser.inputs + [output] = wireframe_parser.outputs + torch.onnx.export(model, input, OUTPUT_FILE, opset_version=11, ) + + return input, output + + +if __name__ == "__main__": + cfg.freeze() + input, output = export(cfg) + check_model() + #verify_model(input, output) + diff --git a/hawp/predicting.py b/hawp/predicting.py index 6f6684e..edb087f 100644 --- a/hawp/predicting.py +++ b/hawp/predicting.py @@ -2,6 +2,7 @@ import torch from hawp import show from hawp.config import cfg +from hawp.graph import WireframeGraph from hawp.utils.comm import to_device from hawp.dataset.build import build_transform @@ -18,14 +19,18 @@ class WireframeParser(object): loader_workers = None device = 'cuda' - def __init__(self, json_data = False, + def __init__(self, json_data = False, visualize_image = False, - visualize_processed_image = False): + visualize_processed_image = False, + export_mode = False): self.model = get_hawp_model(pretrained=True).eval() self.model = self.model.to(self.device) + self.model.export_mode = export_mode self.preprocessor_transform = build_transform(cfg) self.visualize_image = visualize_image self.visualize_processed_image = visualize_processed_image + self.inputs = [] + self.outputs = [] def dataset(self, data): loader_workers = self.loader_workers @@ -51,7 +56,13 @@ def dataloader(self, dataloader): visualizer.Base.image(image_batch[0]) processed_image_batch = processed_image_batch.to(self.device) with torch.no_grad(): - wireframe, _ = self.model(processed_image_batch, meta_batch) + if self.model.export_mode: + self.inputs.append((processed_image_batch, meta_batch)) + output = self.model(processed_image_batch, meta_batch) + self.outputs.append(output) + wireframe = WireframeGraph(*output) + else: + wireframe, _ = self.model(processed_image_batch, meta_batch) yield wireframe, gt_anns_batch[0], meta_batch[0] From 6c6bb27d81c44e50358efb4521dce5f65534be29 Mon Sep 17 00:00:00 2001 From: flahoud Date: Fri, 7 Jan 2022 10:29:29 +0100 Subject: [PATCH 2/4] Update detector --- hawp/detector.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/hawp/detector.py b/hawp/detector.py index 4f633ee..55fa2a4 100644 --- a/hawp/detector.py +++ b/hawp/detector.py @@ -120,14 +120,18 @@ def forward(self, images, annotations = None): #jloc_pred_nms = non_maximum_suppression(jloc_pred[0]) #topK = torch.clamp((jloc_pred_nms > 0.008).count_nonzero(), max=300) - juncs_pred, _ = get_junctions(non_maximum_suppression(jloc_pred[0]),joff_pred[0], topk=300, th=0.008) + nms_jloc_pred = non_maximum_suppression(jloc_pred)[0] + juncs_pred, _ = get_junctions(nms_jloc_pred,joff_pred[0], topk=300, th=0.008) extra_info['time_proposal'] = time.time() - extra_info['time_proposal'] extra_info['time_matching'] = time.time() dis_junc_to_end1, idx_junc_to_end1 = torch.sum((lines_pred[:,:2]-juncs_pred[:,None])**2,dim=-1).min(0) dis_junc_to_end2, idx_junc_to_end2 = torch.sum((lines_pred[:,2:] - juncs_pred[:, None]) ** 2, dim=-1).min(0) - idx_junc_to_end_min = torch.min(idx_junc_to_end1,idx_junc_to_end2) - idx_junc_to_end_max = torch.max(idx_junc_to_end1,idx_junc_to_end2) + # idx_junc_to_end_min = torch.min(idx_junc_to_end1,idx_junc_to_end2) + # idx_junc_to_end_max = torch.max(idx_junc_to_end1,idx_junc_to_end2) + idx_junc_to_end_stacked = torch.stack((idx_junc_to_end1, idx_junc_to_end2)) + idx_junc_to_end_min = idx_junc_to_end_stacked.min(dim=0)[0] + idx_junc_to_end_max = idx_junc_to_end_stacked.max(dim=0)[0] iskeep = (idx_junc_to_end_min < idx_junc_to_end_max)# * (dis_junc_to_end1< 10*10)*(dis_junc_to_end2<10*10) # *(dis_junc_to_end2<100) From e3c54c677b7a7d67be29c18a7b7dc01f0ac9bb5d Mon Sep 17 00:00:00 2001 From: Christian Kanesan Date: Fri, 7 Jan 2022 15:21:15 +0100 Subject: [PATCH 3/4] Parse arguments --- hawp/export.py | 59 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/hawp/export.py b/hawp/export.py index 10bce3c..361d68f 100644 --- a/hawp/export.py +++ b/hawp/export.py @@ -1,3 +1,5 @@ +import argparse + import numpy as np import onnx import onnxruntime as ort @@ -7,44 +9,69 @@ from .predicting import WireframeParser -INPUT_FILE = "/home/ckanesan/Data/semantic-keypoints/826840/Image_000001.jpg" -OUTPUT_FILE = "hawp.onnx" +def cli(): + parser = argparse.ArgumentParser( + prog="python -m hawp.export", + usage="%(prog)s [options] image", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("image", help="input image") + parser.add_argument( + "-o", "--output", default="hawp.onnx", nargs="?", help="Path at which to write the exported model" + ) + + return parser.parse_args() + def to_numpy(tensor): + if isinstance(tensor, int): + return tensor return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() -def check_model(): - model = onnx.load(OUTPUT_FILE) + +def check_model(model_file): + model = onnx.load(model_file) onnx.checker.check_model(model) - inferred_model = onnx.shape_inference.infer_shapes(model) - onnx.save(inferred_model, "hawp_inf.onnx") -def verify_model(input, output): - ort_session = ort.InferenceSession(OUTPUT_FILE, providers=["CUDAExecutionProvider"]) - ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input)} +def verify_model(model_file, input, output): + ort_session = ort.InferenceSession(model_file, providers=["CUDAExecutionProvider"]) + ort_inputs = {"image": to_numpy(input)} ort_outs = ort_session.run(None, ort_inputs) # compare ONNX Runtime and PyTorch results - np.testing.assert_allclose(to_numpy(output), ort_outs[0], rtol=1e-03, atol=1e-05) + for obs, exp in zip(ort_outs, output): + np.testing.assert_allclose(to_numpy(exp), obs, rtol=1e-03, atol=1e-05) -def export(cfg): +def export(input_file, output_file): wireframe_parser = WireframeParser(export_mode=True) - for _ in wireframe_parser.images([INPUT_FILE]): + for _ in wireframe_parser.images([input_file]): pass model = wireframe_parser.model [input] = wireframe_parser.inputs [output] = wireframe_parser.outputs - torch.onnx.export(model, input, OUTPUT_FILE, opset_version=11, ) + output_names = [ + "vertices", + "v_confidences", + "edges", + "edge_weights", + "frame_width", + "frame_height", + ] + torch.onnx.export( + model, input, output_file, opset_version=11, input_names=["image"], output_names=output_names + ) return input, output if __name__ == "__main__": cfg.freeze() - input, output = export(cfg) - check_model() - #verify_model(input, output) + args = cli() + input, output = export(args.image, args.output) + check_model(args.output) + img, _ = input + verify_model(args.output, img, output) From 73954b062844ee341b23f2fe120c48f5623b7109 Mon Sep 17 00:00:00 2001 From: Christian Kanesan Date: Fri, 7 Jan 2022 15:27:49 +0100 Subject: [PATCH 4/4] Minor fixes --- hawp/detector.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/hawp/detector.py b/hawp/detector.py index 55fa2a4..de37b09 100644 --- a/hawp/detector.py +++ b/hawp/detector.py @@ -12,7 +12,7 @@ } def non_maximum_suppression(a): - ap = F.max_pool2d(a, 3, stride=(1,1), padding=(1,1)) + ap = F.max_pool2d(a, 3, stride=1, padding=1) mask = (a == ap).float().clamp(min=0.0) return a * mask @@ -117,11 +117,8 @@ def forward(self, images, annotations = None): else: lines_pred = self.proposal_lines_new(md_pred[0], dis_pred[0], None).view(-1, 4) - #jloc_pred_nms = non_maximum_suppression(jloc_pred[0]) - #topK = torch.clamp((jloc_pred_nms > 0.008).count_nonzero(), max=300) - nms_jloc_pred = non_maximum_suppression(jloc_pred)[0] - juncs_pred, _ = get_junctions(nms_jloc_pred,joff_pred[0], topk=300, th=0.008) + juncs_pred, _ = get_junctions(nms_jloc_pred, joff_pred[0], topk=300, th=0.008) extra_info['time_proposal'] = time.time() - extra_info['time_proposal'] extra_info['time_matching'] = time.time() dis_junc_to_end1, idx_junc_to_end1 = torch.sum((lines_pred[:,:2]-juncs_pred[:,None])**2,dim=-1).min(0) @@ -196,7 +193,7 @@ def proposal_lines(self, md_maps, dis_maps, scale=5.0): cs_ed = torch.cos(ed_).clamp(min=1e-3) ss_ed = torch.sin(ed_).clamp(max=-1e-3) - x_standard = torch.ones_like(cs_st) + #x_standard = torch.ones_like(cs_st) y_st = ss_st/cs_st y_ed = ss_ed/cs_ed