From 7680b3e3b3775f97ff98fa5d6d2d6976e384a2c7 Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 4 Dec 2024 09:25:29 +0400 Subject: [PATCH] move check_dummy_inputs_allowed to common export utils --- optimum/exporters/onnx/convert.py | 27 ++------------------------- optimum/exporters/utils.py | 27 ++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index c12a9ac222a..0d4c544cd3a 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -22,7 +22,7 @@ from inspect import signature from itertools import chain from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import onnx @@ -45,6 +45,7 @@ from ...utils.save_utils import maybe_save_preprocessors from ..error_utils import AtolError, MinimumVersionError, OutputMatchError, ShapeError from ..tasks import TasksManager +from ..utils import check_dummy_inputs_are_allowed from .base import OnnxConfig from .constants import UNPICKABLE_ARCHS from .model_configs import SpeechT5OnnxConfig @@ -75,30 +76,6 @@ class DynamicAxisNameError(ValueError): pass -def check_dummy_inputs_are_allowed( - model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str] -): - """ - Checks that the dummy inputs from the ONNX config is a subset of the allowed inputs for `model`. - Args: - model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel`]): - The model instance. - model_inputs (`Iterable[str]`): - The model input names. - """ - - forward = model.forward if is_torch_available() and isinstance(model, nn.Module) else model.call - forward_parameters = signature(forward).parameters - forward_inputs_set = set(forward_parameters.keys()) - dummy_input_names = set(dummy_input_names) - - # We are fine if config_inputs has more keys than model_inputs - if not dummy_input_names.issubset(forward_inputs_set): - raise ValueError( - f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}" - ) - - def validate_models_outputs( models_and_onnx_configs: Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"] diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 60de169de5e..59e053ee444 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -16,7 +16,8 @@ """Utilities for model preparation to export.""" import copy -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from inspect import signature +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch from packaging import version @@ -675,3 +676,27 @@ def _get_submodels_and_export_configs( export_config = next(iter(models_and_export_configs.values()))[1] return export_config, models_and_export_configs + + +def check_dummy_inputs_are_allowed( + model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str] +): + """ + Checks that the dummy inputs from the ONNX config is a subset of the allowed inputs for `model`. + Args: + model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel`]): + The model instance. + model_inputs (`Iterable[str]`): + The model input names. + """ + + forward = model.forward if is_torch_available() and isinstance(model, nn.Module) else model.call + forward_parameters = signature(forward).parameters + forward_inputs_set = set(forward_parameters.keys()) + dummy_input_names = set(dummy_input_names) + + # We are fine if config_inputs has more keys than model_inputs + if not dummy_input_names.issubset(forward_inputs_set): + raise ValueError( + f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}" + ) \ No newline at end of file