Skip to content

Commit

Permalink
Switch from dynamo export to regular export
Browse files Browse the repository at this point in the history
Dynamo seems to be a bit buggy still
  • Loading branch information
Mathijs de Boer committed Dec 1, 2023
1 parent 438e633 commit 979a716
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 9 deletions.
8 changes: 8 additions & 0 deletions nnunetv2/model_sharing/entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ def export_pretrained_model_onnx_entry():
required=False,
help="Set this to export the cross-validation predictions as well",
)
parser.add_argument(
"-v",
action="store_false",
default=False,
required=False,
help="Set this to get verbose output",
)
args = parser.parse_args()

print("######################################################")
Expand Down Expand Up @@ -191,4 +198,5 @@ def export_pretrained_model_onnx_entry():
folds=args.f,
strict=not args.not_strict,
save_checkpoints=args.chk,
verbose=args.v,
)
50 changes: 41 additions & 9 deletions nnunetv2/model_sharing/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from pathlib import Path
from typing import Tuple, Union

import numpy as np
import onnx
import onnxruntime
import torch
from torch._dynamo import OptimizedModule

Expand All @@ -27,6 +30,7 @@ def export_onnx_model(
strict: bool = True,
save_checkpoints: Tuple[str, ...] = ("checkpoint_final.pth",),
output_names: tuple[str, ...] = None,
verbose: bool = False,
) -> None:
if not output_names:
output_names = (f"{checkpoint[:-4]}.onnx" for checkpoint in save_checkpoints)
Expand Down Expand Up @@ -71,14 +75,6 @@ def export_onnx_model(

network.eval()

export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
rand_input = torch.rand((1, 1, *config.patch_size))
traced_model = torch.onnx.dynamo_export(
network,
rand_input,
export_options=export_options,
)

curr_output_dir = output_dir / c / f"fold_{fold}"
if not curr_output_dir.exists():
curr_output_dir.mkdir(parents=True)
Expand All @@ -88,7 +84,43 @@ def export_onnx_model(
f"Output directory {curr_output_dir} is not empty"
)

traced_model.save(str(curr_output_dir / output_name))
rand_input = torch.rand((1, 1, *config.patch_size))
torch_output = network(rand_input)

torch.onnx.export(
network,
rand_input,
curr_output_dir / output_name,
export_params=True,
verbose=verbose,
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size"},
"output": {0: "batch_size"},
},
)

onnx_model = onnx.load(curr_output_dir / output_name)
onnx.checker.check_model(onnx_model)

ort_session = onnxruntime.InferenceSession(
curr_output_dir / output_name, providers=["CPUExecutionProvider"]
)
ort_inputs = {ort_session.get_inputs()[0].name: rand_input.numpy()}
ort_outs = ort_session.run(None, ort_inputs)

np.testing.assert_allclose(
torch_output.detach().cpu().numpy(),
ort_outs[0],
rtol=1e-03,
atol=1e-05,
)

print(
f"Successfully exported and verified {curr_output_dir / output_name}"
)

with open(curr_output_dir / "config.json", "w") as f:
json.dump(
{
Expand Down

0 comments on commit 979a716

Please sign in to comment.