From 21f709c6785a39412f9791ad799b11dacef1d670 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 25 Jul 2024 16:48:38 +0200 Subject: [PATCH] fix gpt bigcode ONNX export for transformers<=4.36.0 --- .github/workflows/test_export_onnx_cli.yml | 2 ++ optimum/exporters/onnx/model_patcher.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_export_onnx_cli.yml b/.github/workflows/test_export_onnx_cli.yml index 8fa4ebb045f..2380688a8f1 100644 --- a/.github/workflows/test_export_onnx_cli.yml +++ b/.github/workflows/test_export_onnx_cli.yml @@ -17,6 +17,7 @@ jobs: matrix: python-version: [3.8, 3.9] os: [ubuntu-20.04] + transformers-version: ["4.26.0", "4.42.*"] runs-on: ${{ matrix.os }} steps: @@ -27,6 +28,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies for pytorch export run: | + pip install transformers==${{ matrix.transformers-version }} pip install .[tests,exporters] - name: Test with unittest working-directory: tests diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 1f873e4e716..af3c222e866 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -276,12 +276,13 @@ def __init__( model.decoder.model.decoder.config.use_cache = True -def _unmask_unattended_patched( - expanded_mask: torch.Tensor, - min_dtype: float, +def _unmask_unattended_patched_legacy( + expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float] ): return expanded_mask +def _unmask_unattended_patched(expanded_mask: torch.Tensor, min_dtype: float): + return expanded_mask def _make_causal_mask_patched( input_ids_shape: torch.Size, @@ -316,7 +317,11 @@ def _make_causal_mask_patched( _make_causal_mask_patched_staticmethod = staticmethod(_make_causal_mask_patched) -_unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched) + +if _transformers_version >= version.parse("4.39.0"): + _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched) +else: + _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched_legacy) # Adapted from _prepare_4d_causal_attention_mask