move check_dummy_inputs_allowed to common export utils
eaidova committed Dec 4, 2024
1 parent d6de676 commit 7680b3e
Showing 2 changed files with 28 additions and 26 deletions.
27 changes: 2 additions & 25 deletions optimum/exporters/onnx/convert.py
@@ -22,7 +22,7 @@
from inspect import signature
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import onnx
@@ -45,6 +45,7 @@
from ...utils.save_utils import maybe_save_preprocessors
from ..error_utils import AtolError, MinimumVersionError, OutputMatchError, ShapeError
from ..tasks import TasksManager
from ..utils import check_dummy_inputs_are_allowed
from .base import OnnxConfig
from .constants import UNPICKABLE_ARCHS
from .model_configs import SpeechT5OnnxConfig
@@ -75,30 +76,6 @@ class DynamicAxisNameError(ValueError):
pass


def check_dummy_inputs_are_allowed(
    model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str]
):
    """
    Checks that the dummy inputs from the ONNX config are a subset of the allowed inputs for `model`.

    Args:
        model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel, ModelMixin]`):
            The model instance.
        dummy_input_names (`Iterable[str]`):
            The names of the dummy inputs to check against the model's forward signature.
    """

    forward = model.forward if is_torch_available() and isinstance(model, nn.Module) else model.call
    forward_parameters = signature(forward).parameters
    forward_inputs_set = set(forward_parameters.keys())
    dummy_input_names = set(dummy_input_names)

    # It is fine if the model's forward() accepts more inputs than the config provides, but not the reverse.
    if not dummy_input_names.issubset(forward_inputs_set):
        raise ValueError(
            f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}"
        )


def validate_models_outputs(
models_and_onnx_configs: Dict[
str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]
27 changes: 26 additions & 1 deletion optimum/exporters/utils.py
@@ -16,7 +16,8 @@
"""Utilities for model preparation to export."""

import copy
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from inspect import signature
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
from packaging import version
@@ -675,3 +676,27 @@ def _get_submodels_and_export_configs(
export_config = next(iter(models_and_export_configs.values()))[1]

return export_config, models_and_export_configs


def check_dummy_inputs_are_allowed(
    model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str]
):
    """
    Checks that the dummy inputs from the ONNX config are a subset of the allowed inputs for `model`.

    Args:
        model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel, ModelMixin]`):
            The model instance.
        dummy_input_names (`Iterable[str]`):
            The names of the dummy inputs to check against the model's forward signature.
    """

    forward = model.forward if is_torch_available() and isinstance(model, nn.Module) else model.call
    forward_parameters = signature(forward).parameters
    forward_inputs_set = set(forward_parameters.keys())
    dummy_input_names = set(dummy_input_names)

    # It is fine if the model's forward() accepts more inputs than the config provides, but not the reverse.
    if not dummy_input_names.issubset(forward_inputs_set):
        raise ValueError(
            f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}"
        )
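
For reference, below is a minimal usage sketch of the relocated helper, imported from its new location in optimum.exporters.utils. The checkpoint name and input names are illustrative assumptions, not part of this commit.

# Minimal sketch: validate dummy input names against a model's forward() signature.
# Assumes transformers and an optimum build that includes this change are installed;
# "distilbert-base-uncased" and the input names are illustrative choices.
from transformers import AutoModelForSequenceClassification

from optimum.exporters.utils import check_dummy_inputs_are_allowed

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

# Passes silently: both names are parameters of the model's forward().
check_dummy_inputs_are_allowed(model, ["input_ids", "attention_mask"])

# Raises ValueError: "pixel_values" is not a forward() parameter of this text model.
try:
    check_dummy_inputs_are_allowed(model, ["pixel_values"])
except ValueError as err:
    print(err)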
