From 5779bac4c45b2c881603cafd20663892869d5860 Mon Sep 17 00:00:00 2001
From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Date: Fri, 25 Oct 2024 09:44:09 +0200
Subject: [PATCH] Fix onnx non-exportable inplace aten op (#34376)

* fix onnx non-exportable inplace op

* mistral, qwen2, qwen2_vl, starcoder2

* fixup copies
---
 src/transformers/models/mimi/modeling_mimi.py              | 2 +-
 src/transformers/models/mistral/modeling_mistral.py        | 2 +-
 src/transformers/models/mixtral/modeling_mixtral.py        | 2 +-
 src/transformers/models/moshi/modeling_moshi.py            | 4 ++--
 src/transformers/models/phi3/modeling_phi3.py              | 2 +-
 src/transformers/models/phimoe/modeling_phimoe.py          | 2 +-
 src/transformers/models/qwen2/modeling_qwen2.py            | 2 +-
 src/transformers/models/qwen2_moe/modeling_qwen2_moe.py    | 2 +-
 src/transformers/models/qwen2_vl/modeling_qwen2_vl.py      | 2 +-
 src/transformers/models/starcoder2/modeling_starcoder2.py  | 2 +-
 10 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py
index 514f9de706ec63..cbdd2c663c5844 100644
--- a/src/transformers/models/mimi/modeling_mimi.py
+++ b/src/transformers/models/mimi/modeling_mimi.py
@@ -1156,7 +1156,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                     sliding_attend_mask = torch.arange(target_length, device=device) <= (
                         cache_position.reshape(-1, 1) - config.sliding_window
                     )
-                    diagonal_attend_mask |= sliding_attend_mask
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index f198e4abc85511..321d3dc0daf378 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -961,7 +961,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                     sliding_attend_mask = torch.arange(target_length, device=device) <= (
                         cache_position.reshape(-1, 1) - config.sliding_window
                     )
-                    diagonal_attend_mask |= sliding_attend_mask
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 192b7801af0575..78a17178ecdda8 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -1174,7 +1174,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                     sliding_attend_mask = torch.arange(target_length, device=device) <= (
                         cache_position.reshape(-1, 1) - config.sliding_window
                     )
-                    diagonal_attend_mask |= sliding_attend_mask
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py
index 97200b7d042e61..9975996d21d144 100644
--- a/src/transformers/models/moshi/modeling_moshi.py
+++ b/src/transformers/models/moshi/modeling_moshi.py
@@ -1385,7 +1385,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                     sliding_attend_mask = torch.arange(target_length, device=device) <= (
                         cache_position.reshape(-1, 1) - config.sliding_window
                     )
-                    diagonal_attend_mask |= sliding_attend_mask
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
@@ -1689,7 +1689,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                     sliding_attend_mask = torch.arange(target_length, device=device) <= (
                         cache_position.reshape(-1, 1) - config.sliding_window
                     )
-                    diagonal_attend_mask |= sliding_attend_mask
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index a1a86e3672d5fc..bae3f6d4cdaeaa 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -1136,7 +1136,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                     sliding_attend_mask = torch.arange(target_length, device=device) <= (
                         cache_position.reshape(-1, 1) - config.sliding_window
                     )
-                    diagonal_attend_mask |= sliding_attend_mask
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py
index 791f6df50bb40f..f3690e5f686fbb 100644
--- a/src/transformers/models/phimoe/modeling_phimoe.py
+++ b/src/transformers/models/phimoe/modeling_phimoe.py
@@ -1305,7 +1305,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                     sliding_attend_mask = torch.arange(target_length, device=device) <= (
                         cache_position.reshape(-1, 1) - config.sliding_window
                     )
-                    diagonal_attend_mask |= sliding_attend_mask
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index 0d97f2ffb724a0..0883fac1aebafc 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -1059,7 +1059,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                     sliding_attend_mask = torch.arange(target_length, device=device) <= (
                         cache_position.reshape(-1, 1) - config.sliding_window
                     )
-                    diagonal_attend_mask |= sliding_attend_mask
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 36de586265ce60..7f4f19aba1f3eb 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -1239,7 +1239,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                     sliding_attend_mask = torch.arange(target_length, device=device) <= (
                         cache_position.reshape(-1, 1) - config.sliding_window
                     )
-                    diagonal_attend_mask |= sliding_attend_mask
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 4e9401c77e4d7d..90bf29c8b5d66a 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1321,7 +1321,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                     sliding_attend_mask = torch.arange(target_length, device=device) <= (
                         cache_position.reshape(-1, 1) - config.sliding_window
                     )
-                    diagonal_attend_mask |= sliding_attend_mask
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index c8f22dee43fe2c..1a8b6412e738e1 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -1033,7 +1033,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                     sliding_attend_mask = torch.arange(target_length, device=device) <= (
                         cache_position.reshape(-1, 1) - config.sliding_window
                     )
-                    diagonal_attend_mask |= sliding_attend_mask
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
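
Editorial note on the change: every hunk above makes the same substitution, replacing the augmented assignment diagonal_attend_mask |= sliding_attend_mask with the equivalent in-place method call diagonal_attend_mask.bitwise_or_(sliding_attend_mask). Per the patch subject, the motivation is ONNX export: the in-place aten op that the |= spelling produces was not exportable in this code path, while the explicit bitwise_or_ method is. The sketch below is a minimal, standalone rendering of the pattern under assumed shapes and a hypothetical function name, not the actual _prepare_4d_causal_attention_mask_with_cache_position helper from transformers.

```python
# Minimal standalone sketch of the pattern touched by this patch (illustrative
# function name and shapes, not the transformers helper). Both spellings
# compute the same boolean mask; only the in-place update syntax differs.
import torch


def sliding_causal_block_mask(target_length: int, cache_position: torch.Tensor, sliding_window: int) -> torch.Tensor:
    # True where a query position must NOT attend: future tokens, or tokens
    # further back than the sliding window.
    diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
    sliding_attend_mask = torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)

    # Before this patch: diagonal_attend_mask |= sliding_attend_mask
    # After this patch:  the equivalent explicit in-place method below.
    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
    return diagonal_attend_mask


# Quick equivalence check between the two spellings.
cache_position = torch.arange(8)
inplace_method = sliding_causal_block_mask(16, cache_position, sliding_window=4)
augmented_assign = (torch.arange(16) > cache_position.reshape(-1, 1)) | (
    torch.arange(16) <= (cache_position.reshape(-1, 1) - 4)
)
assert torch.equal(inplace_method, augmented_assign)
```

If desired, the same pattern can be wrapped in a small nn.Module and passed through torch.onnx.export to compare how the two spellings trace; that check is not part of this patch.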