Skip to content

Commit

Permalink
Move check_dummy_inputs_allowed to common export utils (#2114)
Browse files Browse the repository at this point in the history
* move check_dummy_inputs_allowed to common export utils

* move decoder_merge import

* Update optimum/exporters/utils.py

* Update optimum/exporters/utils.py

* avoid onnx import if not necessary

* move merge decoders import

* fix style

* add comment

---------

Co-authored-by: Ilyas Moutawwakil <[email protected]>
Co-authored-by: Ella Charlaix <[email protected]>
  • Loading branch information
3 people authored Dec 19, 2024
1 parent 0c42291 commit 35d35bd
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 32 deletions.
12 changes: 8 additions & 4 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,12 @@
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import onnx
from transformers.utils import is_accelerate_available, is_torch_available

from ...onnx import remove_duplicate_weights_from_tied_info


if is_torch_available():
import torch.nn as nn

from ...onnx import merge_decoders
from ...utils import (
DEFAULT_DUMMY_SHAPES,
DummyInputGenerator,
Expand All @@ -54,6 +50,8 @@
from .model_patcher import ModelPatcher, Seq2SeqModelPatcher


# TODO: move back the onnx imports changed in https://github.com/huggingface/optimum/pull/2114/files after refactoring

if is_accelerate_available():
from accelerate.utils import find_tied_parameters

Expand Down Expand Up @@ -542,6 +540,10 @@ def post_process_exported_models(
first_key = next(iter(models_and_onnx_configs))
if is_torch_available() and isinstance(models_and_onnx_configs[first_key][0], nn.Module):
if is_accelerate_available():
import onnx

from ...onnx import remove_duplicate_weights_from_tied_info

logger.info("Deduplicating shared (tied) weights...")
for subpath, key in zip(onnx_files_subpaths, models_and_onnx_configs):
torch_model = models_and_onnx_configs[key][0]
Expand Down Expand Up @@ -934,6 +936,8 @@ def post_process_exported_models(
decoder_with_past_path = Path(path, onnx_files_subpaths[2])
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
try:
from ...onnx import merge_decoders

# The decoder with past does not output the cross attention past key values as they are constant,
# hence the need for strict=False
merge_decoders(
Expand Down
6 changes: 5 additions & 1 deletion optimum/exporters/onnx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from transformers.utils import is_tf_available

from ...onnx import merge_decoders
from ...utils import (
DummyAudioInputGenerator,
DummyBboxInputGenerator,
Expand All @@ -38,6 +37,9 @@
from .model_patcher import DecoderModelPatcher


# TODO: move back the onnx imports changed in https://github.com/huggingface/optimum/pull/2114/files after refactoring


if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel

Expand Down Expand Up @@ -129,6 +131,8 @@ def post_process_exported_models(

# Attempt to merge only if the decoder-only was exported separately without/with past
if self.use_past is True and len(models_and_onnx_configs) == 2:
from ...onnx import merge_decoders

decoder_path = Path(path, onnx_files_subpaths[0])
decoder_with_past_path = Path(path, onnx_files_subpaths[1])
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
Expand Down
29 changes: 4 additions & 25 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from inspect import signature
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import onnx
Expand All @@ -45,6 +45,7 @@
from ...utils.save_utils import maybe_save_preprocessors
from ..error_utils import AtolError, MinimumVersionError, OutputMatchError, ShapeError
from ..tasks import TasksManager
from ..utils import check_dummy_inputs_are_allowed
from .base import OnnxConfig
from .constants import UNPICKABLE_ARCHS
from .model_configs import SpeechT5OnnxConfig
Expand All @@ -56,6 +57,8 @@
)


# TODO: move back the onnx imports changed in https://github.com/huggingface/optimum/pull/2114/files after refactoring

if is_torch_available():
import torch
import torch.nn as nn
Expand All @@ -75,30 +78,6 @@ class DynamicAxisNameError(ValueError):
pass


def check_dummy_inputs_are_allowed(
    model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str]
):
    """
    Checks that the dummy input names from the ONNX config are a subset of the inputs accepted by `model`.

    Args:
        model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel]`):
            The model instance. Its `forward` signature is inspected when it is a `torch.nn.Module`,
            otherwise its `call` signature is inspected.
        dummy_input_names (`Iterable[str]`):
            The names of the dummy inputs to validate.

    Raises:
        ValueError: If at least one dummy input name is not a parameter of the inspected method.
    """
    # Torch modules are inspected through `forward`; every other model through `call`.
    if is_torch_available() and isinstance(model, nn.Module):
        entry_point = model.forward
    else:
        entry_point = model.call

    forward_inputs_set = set(signature(entry_point).parameters)
    dummy_input_names = set(dummy_input_names)

    # The model may accept more inputs than the config provides, but not the other way round.
    if dummy_input_names <= forward_inputs_set:
        return

    raise ValueError(
        f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}"
    )


def validate_models_outputs(
models_and_onnx_configs: Dict[
str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]
Expand Down
6 changes: 5 additions & 1 deletion optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from packaging import version
from transformers.utils import is_tf_available

from ...onnx import merge_decoders
from ...utils import (
DEFAULT_DUMMY_SHAPES,
BloomDummyPastKeyValuesGenerator,
Expand Down Expand Up @@ -94,6 +93,9 @@
)


# TODO: move back the onnx imports changed in https://github.com/huggingface/optimum/pull/2114/files after refactoring


if TYPE_CHECKING:
from transformers import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
Expand Down Expand Up @@ -2018,6 +2020,8 @@ def post_process_exported_models(
decoder_with_past_path = Path(path, onnx_files_subpaths[3])
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
try:
from ...onnx import merge_decoders

# The decoder with past does not output the cross attention past key values as they are constant,
# hence the need for strict=False
merge_decoders(
Expand Down
27 changes: 26 additions & 1 deletion optimum/exporters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
"""Utilities for model preparation to export."""

import copy
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from inspect import signature
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
from packaging import version
Expand Down Expand Up @@ -675,3 +676,27 @@ def _get_submodels_and_export_configs(
export_config = next(iter(models_and_export_configs.values()))[1]

return export_config, models_and_export_configs


def check_dummy_inputs_are_allowed(
    model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str]
):
    """
    Checks that the dummy input names from the export config are a subset of the inputs accepted by `model`.

    Args:
        model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel]`):
            The model instance. Its `forward` signature is inspected when it is a `torch.nn.Module`,
            otherwise its `call` signature is inspected.
        dummy_input_names (`Iterable[str]`):
            The names of the dummy inputs to validate.

    Raises:
        ValueError: If at least one dummy input name is not a parameter of the inspected method.
    """
    # Torch modules are inspected through `forward`; every other model through `call`.
    uses_torch_forward = is_torch_available() and isinstance(model, torch.nn.Module)
    entry_point = model.forward if uses_torch_forward else model.call

    forward_inputs_set = set(signature(entry_point).parameters)
    dummy_input_names = set(dummy_input_names)

    # The model may accept more inputs than the config provides, but not the other way round.
    if dummy_input_names <= forward_inputs_set:
        return

    raise ValueError(
        f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}"
    )

0 comments on commit 35d35bd

Please sign in to comment.