Skip to content

Commit

Permalink
Enable ONNX export for transformers 4.45
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Oct 9, 2024
1 parent 049b00f commit 39b940e
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import numpy as np
import onnx
import transformers
from transformers.modeling_utils import get_parameter_dtype
from transformers.utils import is_tf_available, is_torch_available

Expand Down Expand Up @@ -531,6 +530,10 @@ def export_pytorch(
logger.info(f"Using framework PyTorch: {torch.__version__}")
FORCE_ONNX_EXTERNAL_DATA = os.getenv("FORCE_ONNX_EXTERNAL_DATA", "0") == "1"

model_kwargs = model_kwargs or {}
if check_if_transformers_greater("4.44.99") and "num_logits_to_keep" in signature(model.forward).parameters.keys():
model_kwargs["num_logits_to_keep"] = 0

with torch.no_grad():
model.config.return_dict = True
model = model.eval()
Expand Down Expand Up @@ -1001,11 +1004,6 @@ def onnx_export_from_model(
>>> onnx_export_from_model(model, output="gpt2_onnx/")
```
"""
if check_if_transformers_greater("4.44.99"):
raise ImportError(
f"ONNX conversion disabled for now for transformers version greater than v4.45, found {transformers.__version__}"
)

TasksManager.standardize_model_attributes(model)

if hasattr(model.config, "export_model_type"):
Expand Down

0 comments on commit 39b940e

Please sign in to comment.