Skip to content

Commit

Permalink
some cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Oct 5, 2023
1 parent c5a8a1d commit 2f9661d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 14 deletions.
14 changes: 1 addition & 13 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import functools
import inspect
import types
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
from transformers.utils import is_torch_available
Expand All @@ -36,18 +36,6 @@
logger = logging.get_logger(__name__)


def get_argument(argument_name: str, args: List[Any], kwargs: Dict[str, Any], forward_signature):
"""
Get the argument argument_name from the args and kwargs according to the signature forward_signature.
"""
args = list(args)
if argument_name in forward_signature.parameters:
argument_index = list(forward_signature.parameters.keys()).index(argument_name)
return args[argument_index]
else:
return kwargs[argument_name]


def override_arguments(args, kwargs, forward_signature, model_kwargs: Dict[str, Any]):
"""
Override the args and kwargs with the argument values from model_kwargs, following the signature forward_signature corresponding to args and kwargs.
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,7 @@ def _infer_task_from_model_name_or_path(
pipeline_tag = getattr(model_info, "pipeline_tag", None)
# conversational is not a supported task per se, just an alias that may map to
# text-generaton or text2text-generation.
if pipeline_tag is not None and pipeline_tag not in ["conversational"]:
if pipeline_tag is not None and pipeline_tag != "conversational":
inferred_task_name = TasksManager.map_from_synonym(model_info.pipeline_tag)
else:
transformers_info = model_info.transformersInfo
Expand Down

0 comments on commit 2f9661d

Please sign in to comment.