From 35d35bd3da52273cdd8fd2a8300b4598b2d96cc7 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Thu, 19 Dec 2024 14:14:56 +0400 Subject: [PATCH] Move check_dummy_inputs_allowed to common export utils (#2114) * move check_dummy_inputs_allowed to common export utils * move decoder_merge import * Update optimum/exporters/utils.py * Update optimum/exporters/utils.py * avoid onnx import if not necessary * move merge decoders import * fix style * add comment --------- Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Co-authored-by: Ella Charlaix --- optimum/exporters/onnx/base.py | 12 ++++++---- optimum/exporters/onnx/config.py | 6 ++++- optimum/exporters/onnx/convert.py | 29 ++++--------------------- optimum/exporters/onnx/model_configs.py | 6 ++++- optimum/exporters/utils.py | 27 ++++++++++++++++++++++- 5 files changed, 48 insertions(+), 32 deletions(-) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 7e35691d54b..b5adb4522a2 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -27,16 +27,12 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np -import onnx from transformers.utils import is_accelerate_available, is_torch_available -from ...onnx import remove_duplicate_weights_from_tied_info - if is_torch_available(): import torch.nn as nn -from ...onnx import merge_decoders from ...utils import ( DEFAULT_DUMMY_SHAPES, DummyInputGenerator, @@ -54,6 +50,8 @@ from .model_patcher import ModelPatcher, Seq2SeqModelPatcher +# TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization + if is_accelerate_available(): from accelerate.utils import find_tied_parameters @@ -542,6 +540,10 @@ def post_process_exported_models( first_key = next(iter(models_and_onnx_configs)) if is_torch_available() and isinstance(models_and_onnx_configs[first_key][0], nn.Module): if is_accelerate_available(): + import onnx + + from ...onnx import remove_duplicate_weights_from_tied_info + logger.info("Deduplicating shared (tied) weights...") for subpath, key in zip(onnx_files_subpaths, models_and_onnx_configs): torch_model = models_and_onnx_configs[key][0] @@ -934,6 +936,8 @@ def post_process_exported_models( decoder_with_past_path = Path(path, onnx_files_subpaths[2]) decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") try: + from ...onnx import merge_decoders + # The decoder with past does not output the cross attention past key values as they are constant, # hence the need for strict=False merge_decoders( diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 9e808e392b9..69366d6be13 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -20,7 +20,6 @@ from transformers.utils import is_tf_available -from ...onnx import merge_decoders from ...utils import ( DummyAudioInputGenerator, DummyBboxInputGenerator, @@ -38,6 +37,9 @@ from .model_patcher import DecoderModelPatcher +# TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization + + if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedModel @@ -129,6 +131,8 @@ def post_process_exported_models( # Attempt to merge only if the decoder-only was exported separately without/with past if self.use_past is True and len(models_and_onnx_configs) == 2: + from ...onnx import merge_decoders + decoder_path = Path(path, onnx_files_subpaths[0]) decoder_with_past_path = Path(path, onnx_files_subpaths[1]) decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index c12a9ac222a..80d945580c7 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -22,7 +22,7 @@ from inspect import signature from itertools import chain from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import onnx @@ -45,6 +45,7 @@ from ...utils.save_utils import maybe_save_preprocessors from ..error_utils import AtolError, MinimumVersionError, OutputMatchError, ShapeError from ..tasks import TasksManager +from ..utils import check_dummy_inputs_are_allowed from .base import OnnxConfig from .constants import UNPICKABLE_ARCHS from .model_configs import SpeechT5OnnxConfig @@ -56,6 +57,8 @@ ) +# TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization + if is_torch_available(): import torch import torch.nn as nn @@ -75,30 +78,6 @@ class DynamicAxisNameError(ValueError): pass -def check_dummy_inputs_are_allowed( - model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str] -): - """ - Checks that the dummy inputs from the ONNX config is a subset of the allowed inputs for `model`. - Args: - model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel`]): - The model instance. - model_inputs (`Iterable[str]`): - The model input names. - """ - - forward = model.forward if is_torch_available() and isinstance(model, nn.Module) else model.call - forward_parameters = signature(forward).parameters - forward_inputs_set = set(forward_parameters.keys()) - dummy_input_names = set(dummy_input_names) - - # We are fine if config_inputs has more keys than model_inputs - if not dummy_input_names.issubset(forward_inputs_set): - raise ValueError( - f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}" - ) - - def validate_models_outputs( models_and_onnx_configs: Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"] diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 4c5a727a183..3a48a579c2c 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -21,7 +21,6 @@ from packaging import version from transformers.utils import is_tf_available -from ...onnx import merge_decoders from ...utils import ( DEFAULT_DUMMY_SHAPES, BloomDummyPastKeyValuesGenerator, @@ -94,6 +93,9 @@ ) +# TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization + + if TYPE_CHECKING: from transformers import PretrainedConfig from transformers.modeling_utils import PreTrainedModel @@ -2018,6 +2020,8 @@ def post_process_exported_models( decoder_with_past_path = Path(path, onnx_files_subpaths[3]) decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") try: + from ...onnx import merge_decoders + # The decoder with past does not output the cross attention past key values as they are constant, # hence the need for strict=False merge_decoders( diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 60de169de5e..d4a4111075d 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -16,7 +16,8 @@ """Utilities for model preparation to export.""" import copy -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from inspect import signature +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch from packaging import version @@ -675,3 +676,27 @@ def _get_submodels_and_export_configs( export_config = next(iter(models_and_export_configs.values()))[1] return export_config, models_and_export_configs + + +def check_dummy_inputs_are_allowed( + model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str] +): + """ + Checks that the dummy inputs from the ONNX config is a subset of the allowed inputs for `model`. + Args: + model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel`]): + The model instance. + model_inputs (`Iterable[str]`): + The model input names. + """ + + forward = model.forward if is_torch_available() and isinstance(model, torch.nn.Module) else model.call + forward_parameters = signature(forward).parameters + forward_inputs_set = set(forward_parameters.keys()) + dummy_input_names = set(dummy_input_names) + + # We are fine if config_inputs has more keys than model_inputs + if not dummy_input_names.issubset(forward_inputs_set): + raise ValueError( + f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}" + )