diff --git a/nnunetv2/model_sharing/entry_points.py b/nnunetv2/model_sharing/entry_points.py
index 8d356adc1..076ff15ac 100644
--- a/nnunetv2/model_sharing/entry_points.py
+++ b/nnunetv2/model_sharing/entry_points.py
@@ -173,12 +173,12 @@ def export_pretrained_model_onnx_entry():
     print("!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!")
     print("######################################################")
     print(
-        "You are responsible for creating the ONNX pipeline \n"
+        "You are responsible for creating the ONNX pipeline\n"
         "yourself.\n\n"
-        "This script will only export the model \n"
-        "weights to an onnx file, and some basic information \n"
-        "about the model. You will have to create the ONNX \n"
-        "pipeline yourself. \n"
+        "This script will only export the model weights to\n"
+        "an onnx file, and some basic information about\n"
+        "the model. You will have to create the ONNX pipeline\n"
+        "yourself.\n"
     )
     print(
         "See\n"
diff --git a/nnunetv2/model_sharing/onnx_export.py b/nnunetv2/model_sharing/onnx_export.py
index eab2a307a..14ea95454 100644
--- a/nnunetv2/model_sharing/onnx_export.py
+++ b/nnunetv2/model_sharing/onnx_export.py
@@ -73,10 +73,7 @@ def export_onnx_model(
     config = predictor.configuration_manager
 
     for fold, params in zip(folds, list_of_parameters):
-        if not isinstance(network, OptimizedModule):
-            network.load_state_dict(params)
-        else:
-            network._orig_mod.load_state_dict(params)
+        network.load_state_dict(params)
 
         network.eval()
 
@@ -115,15 +112,21 @@ def export_onnx_model(
         ort_inputs = {ort_session.get_inputs()[0].name: rand_input.numpy()}
         ort_outs = ort_session.run(None, ort_inputs)
 
-        np.testing.assert_allclose(
-            torch_output.detach().cpu().numpy(),
-            ort_outs[0],
-            rtol=1e-03,
-            atol=1e-05,
-        )
+        try:
+            np.testing.assert_allclose(
+                torch_output.detach().cpu().numpy(),
+                ort_outs[0],
+                rtol=1e-03,
+                atol=1e-05,
+                verbose=True,
+            )
+        except AssertionError as e:
+            print(f"WARN: Differences found between torch and onnx:\n")
+            print(e)
+            print("\nExport will continue, but please verify that your pipeline matches the original.")
 
         print(
-            f"Successfully exported and verified {curr_output_dir / output_name}"
+            f"Exported {curr_output_dir / output_name}"
         )
 
         with open(curr_output_dir / "config.json", "w") as f:
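For context on the strengthened warning above: the export writes only the raw network weights to an .onnx file, so inference still needs a hand-built pipeline around it. Below is a minimal sketch (not part of this PR) of running the exported graph with onnxruntime on an already-preprocessed patch; the file path, patch size, and channel count are hypothetical placeholders, and nnU-Net's preprocessing, sliding-window aggregation, and resampling are deliberately left out.

```python
# Minimal sketch, NOT from this PR: run an exported nnU-Net ONNX graph on one
# already-preprocessed patch. Path, patch size, and channel count are
# hypothetical placeholders; replace them with the values from your export.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("exported_model/fold_0.onnx")  # hypothetical path
input_name = session.get_inputs()[0].name
print("expected input shape:", session.get_inputs()[0].shape)

# Dummy single-channel 3D patch (batch, channels, x, y, z) standing in for a
# patch produced by your own preprocessing pipeline.
patch = np.random.randn(1, 1, 128, 128, 128).astype(np.float32)

logits = session.run(None, {input_name: patch})[0]
segmentation = np.argmax(logits, axis=1)  # per-voxel argmax over the class axis
print("segmentation patch shape:", segmentation.shape)
```

Sliding-window tiling over the full image, Gaussian weighting, and resampling back to the original geometry still have to be reimplemented to match nnUNetPredictor's behaviour, which is exactly what the reworded warning in entry_points.py points at.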