Parse arguments
ckanesan committed Jan 7, 2022
1 parent 6c6bb27 commit e3c54c6
Showing 1 changed file with 43 additions and 16 deletions.
59 changes: 43 additions & 16 deletions hawp/export.py
@@ -1,3 +1,5 @@
+import argparse
+
 import numpy as np
 import onnx
 import onnxruntime as ort
@@ -7,44 +9,69 @@
 from .predicting import WireframeParser
 
 
-INPUT_FILE = "/home/ckanesan/Data/semantic-keypoints/826840/Image_000001.jpg"
-OUTPUT_FILE = "hawp.onnx"
+def cli():
+    parser = argparse.ArgumentParser(
+        prog="python -m hawp.export",
+        usage="%(prog)s [options] image",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument("image", help="input image")
+    parser.add_argument(
+        "-o", "--output", default="hawp.onnx", nargs="?", help="Path at which to write the exported model"
+    )
+
+    return parser.parse_args()
 
 
 def to_numpy(tensor):
+    if isinstance(tensor, int):
+        return tensor
     return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
 
-def check_model():
-    model = onnx.load(OUTPUT_FILE)
+
+def check_model(model_file):
+    model = onnx.load(model_file)
     onnx.checker.check_model(model)
 
     inferred_model = onnx.shape_inference.infer_shapes(model)
     onnx.save(inferred_model, "hawp_inf.onnx")
 
-def verify_model(input, output):
-    ort_session = ort.InferenceSession(OUTPUT_FILE, providers=["CUDAExecutionProvider"])
-    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input)}
+def verify_model(model_file, input, output):
+    ort_session = ort.InferenceSession(model_file, providers=["CUDAExecutionProvider"])
+    ort_inputs = {"image": to_numpy(input)}
     ort_outs = ort_session.run(None, ort_inputs)
 
-    # compare ONNX Runtime and PyTorch results
-    np.testing.assert_allclose(to_numpy(output), ort_outs[0], rtol=1e-03, atol=1e-05)
+    # compare ONNX Runtime and PyTorch results
+    for obs, exp in zip(ort_outs, output):
+        np.testing.assert_allclose(to_numpy(exp), obs, rtol=1e-03, atol=1e-05)
 
 
-def export(cfg):
+def export(input_file, output_file):
     wireframe_parser = WireframeParser(export_mode=True)
-    for _ in wireframe_parser.images([INPUT_FILE]):
+    for _ in wireframe_parser.images([input_file]):
         pass
 
     model = wireframe_parser.model
     [input] = wireframe_parser.inputs
     [output] = wireframe_parser.outputs
-    torch.onnx.export(model, input, OUTPUT_FILE, opset_version=11, )
+    output_names = [
+        "vertices",
+        "v_confidences",
+        "edges",
+        "edge_weights",
+        "frame_width",
+        "frame_height",
+    ]
+    torch.onnx.export(
+        model, input, output_file, opset_version=11, input_names=["image"], output_names=output_names
+    )
 
     return input, output
 
 
 if __name__ == "__main__":
-    cfg.freeze()
-    input, output = export(cfg)
-    check_model()
-    #verify_model(input, output)
+    args = cli()
+
+    input, output = export(args.image, args.output)
+    check_model(args.output)
+    img, _ = input
+    verify_model(args.output, img, output)
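
For context, a minimal sketch (not part of this commit) of how the exported file could be consumed with ONNX Runtime, using the input name "image" and the output names declared in export() above. The file name matches the exporter's default, but the dummy input shape and the CPU provider are illustrative assumptions; the real preprocessing comes from WireframeParser.

# Illustrative only: load the model produced by `python -m hawp.export <image> -o hawp.onnx`.
# The 1x3x512x512 dummy tensor is an assumed shape, not something this commit guarantees.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("hawp.onnx", providers=["CPUExecutionProvider"])
dummy_image = np.random.rand(1, 3, 512, 512).astype(np.float32)
outputs = session.run(None, {"image": dummy_image})

# Outputs arrive in the order declared via output_names in export():
vertices, v_confidences, edges, edge_weights, frame_width, frame_height = outputs
print(vertices.shape, edges.shape)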
