diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index b39d19ec782..12cf9c6108b 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -21,7 +21,6 @@ from packaging import version from transformers.utils import is_tf_available -from ...onnx import merge_decoders from ...utils import ( DEFAULT_DUMMY_SHAPES, BloomDummyPastKeyValuesGenerator, @@ -1875,6 +1874,7 @@ def post_process_exported_models( decoder_with_past_path = Path(path, onnx_files_subpaths[3]) decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") try: + from ...onnx import merge_decoders # The decoder with past does not output the cross attention past key values as they are constant, # hence the need for strict=False merge_decoders(