Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add export module #1

Draft
wants to merge 4 commits into
base: inference
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions hawp/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ def get_junctions(jloc, joff, topk = 300, th=0):
jloc = jloc.reshape(-1)
joff = joff.reshape(2, -1)

scores, index = torch.topk(jloc, k=topk)
#scores, index = torch.topk(jloc, k=topk)
jloc_sorted, jloc_index = jloc.sort(descending=True)
scores = jloc_sorted[:topk]
index = jloc_index[:topk]
y = (index // width).float() + torch.gather(joff[1], 0, index) + 0.5
x = (index % width).float() + torch.gather(joff[0], 0, index) + 0.5

Expand Down Expand Up @@ -61,6 +64,7 @@ def __init__(self, cfg):
nn.Linear(self.dim_fc, 1),
)
self.train_step = 0
self.export_mode = False

def pooling(self, features_per_image, lines_per_im):
h,w = features_per_image.size(1), features_per_image.size(2)
Expand Down Expand Up @@ -107,27 +111,24 @@ def forward(self, images, annotations = None):
joff_pred= output[:,7:9].sigmoid() - 0.5
extra_info['time_backbone'] = time.time() - extra_info['time_backbone']


batch_size = md_pred.size(0)
assert batch_size == 1

extra_info['time_proposal'] = time.time()
if self.use_residual:
lines_pred = self.proposal_lines_new(md_pred[0],dis_pred[0],res_pred[0]).view(-1,4)
else:
lines_pred = self.proposal_lines_new(md_pred[0], dis_pred[0], None).view(-1, 4)

jloc_pred_nms = non_maximum_suppression(jloc_pred[0])
topK = min(300, int((jloc_pred_nms>0.008).float().sum().item()))

juncs_pred, _ = get_junctions(non_maximum_suppression(jloc_pred[0]),joff_pred[0], topk=topK)
nms_jloc_pred = non_maximum_suppression(jloc_pred)[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the part where I pass the 4 dimensional tensor, and then take out the batch dim after max_pool2d

juncs_pred, _ = get_junctions(nms_jloc_pred, joff_pred[0], topk=300, th=0.008)
extra_info['time_proposal'] = time.time() - extra_info['time_proposal']
extra_info['time_matching'] = time.time()
dis_junc_to_end1, idx_junc_to_end1 = torch.sum((lines_pred[:,:2]-juncs_pred[:,None])**2,dim=-1).min(0)
dis_junc_to_end2, idx_junc_to_end2 = torch.sum((lines_pred[:,2:] - juncs_pred[:, None]) ** 2, dim=-1).min(0)

idx_junc_to_end_min = torch.min(idx_junc_to_end1,idx_junc_to_end2)
idx_junc_to_end_max = torch.max(idx_junc_to_end1,idx_junc_to_end2)
# idx_junc_to_end_min = torch.min(idx_junc_to_end1,idx_junc_to_end2)
# idx_junc_to_end_max = torch.max(idx_junc_to_end1,idx_junc_to_end2)
Comment on lines +127 to +128
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.min and torch.max signatures are (tensor, dim, keep_dims, ..)
They are used here are torch.min/max(tensor1, tensor2) to do element-wise minimum/maximum operations.
This does not work for onnx, as it sees min and max, and expects the second argument to be an integer, and not a tensor.

The fix is to stack the two tensors together, then take the min/max along the stacked dimension.

idx_junc_to_end_stacked = torch.stack((idx_junc_to_end1, idx_junc_to_end2))
idx_junc_to_end_min = idx_junc_to_end_stacked.min(dim=0)[0]
idx_junc_to_end_max = idx_junc_to_end_stacked.max(dim=0)[0]

iskeep = (idx_junc_to_end_min < idx_junc_to_end_max)# * (dis_junc_to_end1< 10*10)*(dis_junc_to_end2<10*10) # *(dis_junc_to_end2<100)

Expand Down Expand Up @@ -155,10 +156,13 @@ def forward(self, images, annotations = None):
v2e_idx1 = v2e_dis1.argmin(dim=0)
v2e_idx2 = v2e_dis2.argmin(dim=0)
edge_indices = torch.stack((v2e_idx1,v2e_idx2)).t()
wireframe = WireframeGraph(juncs_final,juncs_score,edge_indices,score_final,output.size(3),output.size(2))

wireframe.rescale(annotations[0]['width'],annotations[0]['height'])

if not self.export_mode:
wireframe = WireframeGraph(juncs_final,juncs_score,edge_indices,score_final,output.size(3),output.size(2))
wireframe.rescale(annotations[0]['width'],annotations[0]['height'])
else:
return juncs_final,juncs_score,edge_indices,score_final,output.size(3),output.size(2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

output.size(3) and output.size(2) are not tensors anymore, they probably won't have meaningful data once exported


extra_info['time_verification'] = time.time() - extra_info['time_verification']

return wireframe, extra_info
Expand Down Expand Up @@ -189,7 +193,7 @@ def proposal_lines(self, md_maps, dis_maps, scale=5.0):
cs_ed = torch.cos(ed_).clamp(min=1e-3)
ss_ed = torch.sin(ed_).clamp(max=-1e-3)

x_standard = torch.ones_like(cs_st)
#x_standard = torch.ones_like(cs_st)

y_st = ss_st/cs_st
y_ed = ss_ed/cs_ed
Expand Down
77 changes: 77 additions & 0 deletions hawp/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import argparse

import numpy as np
import onnx
import onnxruntime as ort
import torch.onnx

from .config import cfg
from .predicting import WireframeParser


def cli():
parser = argparse.ArgumentParser(
prog="python -m hawp.export",
usage="%(prog)s [options] image",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("image", help="input image")
parser.add_argument(
"-o", "--output", default="hawp.onnx", nargs="?", help="Path at which to write the exported model"
)

return parser.parse_args()


def to_numpy(tensor):
if isinstance(tensor, int):
return tensor
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def check_model(model_file):
model = onnx.load(model_file)
onnx.checker.check_model(model)


def verify_model(model_file, input, output):
ort_session = ort.InferenceSession(model_file, providers=["CUDAExecutionProvider"])
ort_inputs = {"image": to_numpy(input)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
for obs, exp in zip(ort_outs, output):
np.testing.assert_allclose(to_numpy(exp), obs, rtol=1e-03, atol=1e-05)


def export(input_file, output_file):
wireframe_parser = WireframeParser(export_mode=True)
for _ in wireframe_parser.images([input_file]):
pass

model = wireframe_parser.model
[input] = wireframe_parser.inputs
[output] = wireframe_parser.outputs
output_names = [
"vertices",
"v_confidences",
"edges",
"edge_weights",
"frame_width",
"frame_height",
]
torch.onnx.export(
model, input, output_file, opset_version=11, input_names=["image"], output_names=output_names
)

return input, output


if __name__ == "__main__":
cfg.freeze()
args = cli()

input, output = export(args.image, args.output)
check_model(args.output)
img, _ = input
verify_model(args.output, img, output)
17 changes: 14 additions & 3 deletions hawp/predicting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from hawp import show
from hawp.config import cfg
from hawp.graph import WireframeGraph
from hawp.utils.comm import to_device
from hawp.dataset.build import build_transform

Expand All @@ -18,14 +19,18 @@
class WireframeParser(object):
loader_workers = None
device = 'cuda'
def __init__(self, json_data = False,
def __init__(self, json_data = False,
visualize_image = False,
visualize_processed_image = False):
visualize_processed_image = False,
export_mode = False):
self.model = get_hawp_model(pretrained=True).eval()
self.model = self.model.to(self.device)
self.model.export_mode = export_mode
self.preprocessor_transform = build_transform(cfg)
self.visualize_image = visualize_image
self.visualize_processed_image = visualize_processed_image
self.inputs = []
self.outputs = []

def dataset(self, data):
loader_workers = self.loader_workers
Expand All @@ -51,7 +56,13 @@ def dataloader(self, dataloader):
visualizer.Base.image(image_batch[0])
processed_image_batch = processed_image_batch.to(self.device)
with torch.no_grad():
wireframe, _ = self.model(processed_image_batch, meta_batch)
if self.model.export_mode:
self.inputs.append((processed_image_batch, meta_batch))
output = self.model(processed_image_batch, meta_batch)
self.outputs.append(output)
wireframe = WireframeGraph(*output)
else:
wireframe, _ = self.model(processed_image_batch, meta_batch)

yield wireframe, gt_anns_batch[0], meta_batch[0]

Expand Down