diff --git a/hawp/export.py b/hawp/export.py index 10bce3c..361d68f 100644 --- a/hawp/export.py +++ b/hawp/export.py @@ -1,3 +1,5 @@ +import argparse + import numpy as np import onnx import onnxruntime as ort @@ -7,44 +9,69 @@ from .predicting import WireframeParser -INPUT_FILE = "/home/ckanesan/Data/semantic-keypoints/826840/Image_000001.jpg" -OUTPUT_FILE = "hawp.onnx" +def cli(): + parser = argparse.ArgumentParser( + prog="python -m hawp.export", + usage="%(prog)s [options] image", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("image", help="input image") + parser.add_argument( + "-o", "--output", default="hawp.onnx", nargs="?", help="Path at which to write the exported model" + ) + + return parser.parse_args() + def to_numpy(tensor): + if isinstance(tensor, int): + return tensor return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() -def check_model(): - model = onnx.load(OUTPUT_FILE) + +def check_model(model_file): + model = onnx.load(model_file) onnx.checker.check_model(model) - inferred_model = onnx.shape_inference.infer_shapes(model) - onnx.save(inferred_model, "hawp_inf.onnx") -def verify_model(input, output): - ort_session = ort.InferenceSession(OUTPUT_FILE, providers=["CUDAExecutionProvider"]) - ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input)} +def verify_model(model_file, input, output): + ort_session = ort.InferenceSession(model_file, providers=["CUDAExecutionProvider"]) + ort_inputs = {"image": to_numpy(input)} ort_outs = ort_session.run(None, ort_inputs) # compare ONNX Runtime and PyTorch results - np.testing.assert_allclose(to_numpy(output), ort_outs[0], rtol=1e-03, atol=1e-05) + for obs, exp in zip(ort_outs, output): + np.testing.assert_allclose(to_numpy(exp), obs, rtol=1e-03, atol=1e-05) -def export(cfg): +def export(input_file, output_file): wireframe_parser = WireframeParser(export_mode=True) - for _ in wireframe_parser.images([INPUT_FILE]): + for _ in wireframe_parser.images([input_file]): pass model = wireframe_parser.model [input] = wireframe_parser.inputs [output] = wireframe_parser.outputs - torch.onnx.export(model, input, OUTPUT_FILE, opset_version=11, ) + output_names = [ + "vertices", + "v_confidences", + "edges", + "edge_weights", + "frame_width", + "frame_height", + ] + torch.onnx.export( + model, input, output_file, opset_version=11, input_names=["image"], output_names=output_names + ) return input, output if __name__ == "__main__": cfg.freeze() - input, output = export(cfg) - check_model() - #verify_model(input, output) + args = cli() + input, output = export(args.image, args.output) + check_model(args.output) + img, _ = input + verify_model(args.output, img, output)