From ee4aee5dcf3af11756a32a9b2e00fa8bc8132d65 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 25 Jul 2024 18:04:24 +0200 Subject: [PATCH] fix min transformers version --- .github/workflows/test_export_onnx_cli.yml | 2 -- optimum/exporters/onnx/model_configs.py | 6 ++++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_export_onnx_cli.yml b/.github/workflows/test_export_onnx_cli.yml index f4393986133..8fa4ebb045f 100644 --- a/.github/workflows/test_export_onnx_cli.yml +++ b/.github/workflows/test_export_onnx_cli.yml @@ -17,7 +17,6 @@ jobs: matrix: python-version: [3.8, 3.9] os: [ubuntu-20.04] - transformers-version: ["4.33.0", "4.42.*"] runs-on: ${{ matrix.os }} steps: @@ -28,7 +27,6 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies for pytorch export run: | - pip install transformers==${{ matrix.transformers-version }} pip install .[tests,exporters] - name: Test with unittest working-directory: tests diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index e2bcd7fe20d..26202e889b8 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -279,7 +279,7 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): class Qwen2OnnxConfig(LlamaOnnxConfig): - pass + MIN_TRANSFORMERS_VERSION = version.parse("4.37.0") class GemmaOnnxConfig(LlamaOnnxConfig): @@ -291,6 +291,7 @@ class GemmaOnnxConfig(LlamaOnnxConfig): class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 14 # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1. NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + MIN_TRANSFORMERS_VERSION = version.parse("4.36.0") class Phi3OnnxConfig(PhiOnnxConfig): @@ -299,6 +300,7 @@ class Phi3OnnxConfig(PhiOnnxConfig): ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA + MIN_TRANSFORMERS_VERSION = version.parse("4.41.0") class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): @@ -1173,7 +1175,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]: class OwlV2OnnxConfig(OwlViTOnnxConfig): - pass + MIN_TRANSFORMERS_VERSION = version.parse("4.35.0") class LayoutLMOnnxConfig(TextAndVisionOnnxConfig):