Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix gpt bigcode ONNX export for transformers<4.39.0 #1973

Merged
merged 5 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):


class Qwen2OnnxConfig(LlamaOnnxConfig):
    """ONNX export configuration for Qwen2 models, reusing the Llama config.

    Fix: the span contained a leftover `pass` (pre-change diff residue) next to
    the real class body; a class with attributes needs no `pass`, so it is removed.
    """

    # Presumably the transformers release that introduced Qwen2 modeling code;
    # exporting with an older transformers would fail. TODO(review): confirm
    # v4.37.0 against the transformers release notes.
    MIN_TRANSFORMERS_VERSION = version.parse("4.37.0")


class GemmaOnnxConfig(LlamaOnnxConfig):
Expand All @@ -291,6 +291,7 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    """ONNX export configuration for Phi models."""

    DEFAULT_ONNX_OPSET = 14  # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    # Minimum transformers release supported for Phi export.
    # NOTE(review): presumably the version where Phi modeling landed — confirm v4.36.0.
    MIN_TRANSFORMERS_VERSION = version.parse("4.36.0")


class Phi3OnnxConfig(PhiOnnxConfig):
Expand All @@ -299,6 +300,7 @@ class Phi3OnnxConfig(PhiOnnxConfig):
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA
MIN_TRANSFORMERS_VERSION = version.parse("4.41.0")


class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
Expand Down Expand Up @@ -1173,7 +1175,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:


class OwlV2OnnxConfig(OwlViTOnnxConfig):
    """ONNX export configuration for OwlV2, reusing the OwlViT configuration.

    Fix: the span contained a leftover `pass` (pre-change diff residue) next to
    the real class body; a class with attributes needs no `pass`, so it is removed.
    """

    # Presumably the transformers release that introduced OwlV2 modeling code.
    # TODO(review): confirm v4.35.0 against the transformers release notes.
    MIN_TRANSFORMERS_VERSION = version.parse("4.35.0")


class LayoutLMOnnxConfig(TextAndVisionOnnxConfig):
Expand Down
15 changes: 11 additions & 4 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,16 @@ def __init__(
model.decoder.model.decoder.config.use_cache = True


def _unmask_unattended_patched(
expanded_mask: torch.Tensor,
min_dtype: float,
def _unmask_unattended_patched_legacy(
expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
):
return expanded_mask


def _unmask_unattended_patched(expanded_mask: torch.Tensor, min_dtype: float):
return expanded_mask


def _make_causal_mask_patched(
input_ids_shape: torch.Size,
dtype: torch.dtype,
Expand Down Expand Up @@ -316,7 +319,11 @@ def _make_causal_mask_patched(


_make_causal_mask_patched_staticmethod = staticmethod(_make_causal_mask_patched)

# transformers 4.39.0 changed the signature of
# AttentionMaskConverter._unmask_unattended (attention_mask/unmasked_value ->
# min_dtype), so select the patch whose signature matches the installed
# version.
# Fix: a superseded unconditional assignment of
# _unmask_unattended_patched_staticmethod (pre-change diff residue that the
# if/else immediately overwrote) is removed as redundant.
if _transformers_version >= version.parse("4.39.0"):
    _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched)
else:
    _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched_legacy)


# Adapted from _prepare_4d_causal_attention_mask
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
REQUIRED_PKGS = [
"coloredlogs",
"sympy",
"transformers[sentencepiece]>=4.26.0,<4.43.0",
"transformers[sentencepiece]>=4.33.0,<4.43.0",
echarlaix marked this conversation as resolved.
Show resolved Hide resolved
"torch>=1.11",
"packaging",
"numpy<2.0", # transformers requires numpy<2.0 https://github.com/huggingface/transformers/pull/31569
Expand Down
Loading