-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add export module #1
base: inference
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,10 @@ def get_junctions(jloc, joff, topk = 300, th=0): | |
jloc = jloc.reshape(-1) | ||
joff = joff.reshape(2, -1) | ||
|
||
scores, index = torch.topk(jloc, k=topk) | ||
#scores, index = torch.topk(jloc, k=topk) | ||
jloc_sorted, jloc_index = jloc.sort(descending=True) | ||
scores = jloc_sorted[:topk] | ||
index = jloc_index[:topk] | ||
y = (index // width).float() + torch.gather(joff[1], 0, index) + 0.5 | ||
x = (index % width).float() + torch.gather(joff[0], 0, index) + 0.5 | ||
|
||
|
@@ -61,6 +64,7 @@ def __init__(self, cfg): | |
nn.Linear(self.dim_fc, 1), | ||
) | ||
self.train_step = 0 | ||
self.export_mode = False | ||
|
||
def pooling(self, features_per_image, lines_per_im): | ||
h,w = features_per_image.size(1), features_per_image.size(2) | ||
|
@@ -107,27 +111,24 @@ def forward(self, images, annotations = None): | |
joff_pred= output[:,7:9].sigmoid() - 0.5 | ||
extra_info['time_backbone'] = time.time() - extra_info['time_backbone'] | ||
|
||
|
||
batch_size = md_pred.size(0) | ||
assert batch_size == 1 | ||
|
||
extra_info['time_proposal'] = time.time() | ||
if self.use_residual: | ||
lines_pred = self.proposal_lines_new(md_pred[0],dis_pred[0],res_pred[0]).view(-1,4) | ||
else: | ||
lines_pred = self.proposal_lines_new(md_pred[0], dis_pred[0], None).view(-1, 4) | ||
|
||
jloc_pred_nms = non_maximum_suppression(jloc_pred[0]) | ||
topK = min(300, int((jloc_pred_nms>0.008).float().sum().item())) | ||
|
||
juncs_pred, _ = get_junctions(non_maximum_suppression(jloc_pred[0]),joff_pred[0], topk=topK) | ||
nms_jloc_pred = non_maximum_suppression(jloc_pred)[0] | ||
juncs_pred, _ = get_junctions(nms_jloc_pred, joff_pred[0], topk=300, th=0.008) | ||
extra_info['time_proposal'] = time.time() - extra_info['time_proposal'] | ||
extra_info['time_matching'] = time.time() | ||
dis_junc_to_end1, idx_junc_to_end1 = torch.sum((lines_pred[:,:2]-juncs_pred[:,None])**2,dim=-1).min(0) | ||
dis_junc_to_end2, idx_junc_to_end2 = torch.sum((lines_pred[:,2:] - juncs_pred[:, None]) ** 2, dim=-1).min(0) | ||
|
||
idx_junc_to_end_min = torch.min(idx_junc_to_end1,idx_junc_to_end2) | ||
idx_junc_to_end_max = torch.max(idx_junc_to_end1,idx_junc_to_end2) | ||
# idx_junc_to_end_min = torch.min(idx_junc_to_end1,idx_junc_to_end2) | ||
# idx_junc_to_end_max = torch.max(idx_junc_to_end1,idx_junc_to_end2) | ||
Comment on lines
+127
to
+128
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. torch.min and torch.max signatures are (tensor, dim, keep_dims, ..) The fix is to stack the two tensors together, then take the min/max along the stacked dimension. |
||
idx_junc_to_end_stacked = torch.stack((idx_junc_to_end1, idx_junc_to_end2)) | ||
idx_junc_to_end_min = idx_junc_to_end_stacked.min(dim=0)[0] | ||
idx_junc_to_end_max = idx_junc_to_end_stacked.max(dim=0)[0] | ||
|
||
iskeep = (idx_junc_to_end_min < idx_junc_to_end_max)# * (dis_junc_to_end1< 10*10)*(dis_junc_to_end2<10*10) # *(dis_junc_to_end2<100) | ||
|
||
|
@@ -155,10 +156,13 @@ def forward(self, images, annotations = None): | |
v2e_idx1 = v2e_dis1.argmin(dim=0) | ||
v2e_idx2 = v2e_dis2.argmin(dim=0) | ||
edge_indices = torch.stack((v2e_idx1,v2e_idx2)).t() | ||
wireframe = WireframeGraph(juncs_final,juncs_score,edge_indices,score_final,output.size(3),output.size(2)) | ||
|
||
wireframe.rescale(annotations[0]['width'],annotations[0]['height']) | ||
|
||
if not self.export_mode: | ||
wireframe = WireframeGraph(juncs_final,juncs_score,edge_indices,score_final,output.size(3),output.size(2)) | ||
wireframe.rescale(annotations[0]['width'],annotations[0]['height']) | ||
else: | ||
return juncs_final,juncs_score,edge_indices,score_final,output.size(3),output.size(2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. output.size(3) and output.size(2) are not tensors anymore, they probably won't have meaningful data once exported |
||
|
||
extra_info['time_verification'] = time.time() - extra_info['time_verification'] | ||
|
||
return wireframe, extra_info | ||
|
@@ -189,7 +193,7 @@ def proposal_lines(self, md_maps, dis_maps, scale=5.0): | |
cs_ed = torch.cos(ed_).clamp(min=1e-3) | ||
ss_ed = torch.sin(ed_).clamp(max=-1e-3) | ||
|
||
x_standard = torch.ones_like(cs_st) | ||
#x_standard = torch.ones_like(cs_st) | ||
|
||
y_st = ss_st/cs_st | ||
y_ed = ss_ed/cs_ed | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import argparse | ||
|
||
import numpy as np | ||
import onnx | ||
import onnxruntime as ort | ||
import torch.onnx | ||
|
||
from .config import cfg | ||
from .predicting import WireframeParser | ||
|
||
|
||
def cli(): | ||
parser = argparse.ArgumentParser( | ||
prog="python -m hawp.export", | ||
usage="%(prog)s [options] image", | ||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||
) | ||
parser.add_argument("image", help="input image") | ||
parser.add_argument( | ||
"-o", "--output", default="hawp.onnx", nargs="?", help="Path at which to write the exported model" | ||
) | ||
|
||
return parser.parse_args() | ||
|
||
|
||
def to_numpy(tensor): | ||
if isinstance(tensor, int): | ||
return tensor | ||
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() | ||
|
||
|
||
def check_model(model_file): | ||
model = onnx.load(model_file) | ||
onnx.checker.check_model(model) | ||
|
||
|
||
def verify_model(model_file, input, output): | ||
ort_session = ort.InferenceSession(model_file, providers=["CUDAExecutionProvider"]) | ||
ort_inputs = {"image": to_numpy(input)} | ||
ort_outs = ort_session.run(None, ort_inputs) | ||
|
||
# compare ONNX Runtime and PyTorch results | ||
for obs, exp in zip(ort_outs, output): | ||
np.testing.assert_allclose(to_numpy(exp), obs, rtol=1e-03, atol=1e-05) | ||
|
||
|
||
def export(input_file, output_file): | ||
wireframe_parser = WireframeParser(export_mode=True) | ||
for _ in wireframe_parser.images([input_file]): | ||
pass | ||
|
||
model = wireframe_parser.model | ||
[input] = wireframe_parser.inputs | ||
[output] = wireframe_parser.outputs | ||
output_names = [ | ||
"vertices", | ||
"v_confidences", | ||
"edges", | ||
"edge_weights", | ||
"frame_width", | ||
"frame_height", | ||
] | ||
torch.onnx.export( | ||
model, input, output_file, opset_version=11, input_names=["image"], output_names=output_names | ||
) | ||
|
||
return input, output | ||
|
||
|
||
if __name__ == "__main__": | ||
cfg.freeze() | ||
args = cli() | ||
|
||
input, output = export(args.image, args.output) | ||
check_model(args.output) | ||
img, _ = input | ||
verify_model(args.output, img, output) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the part where I pass the 4 dimensional tensor, and then take out the batch dim after max_pool2d