Skip to content

Commit

Permalink
Fix onnx non-expotable inplace aten op (#34376)
Browse files Browse the repository at this point in the history
* fix onnx non-expotable inplace op

* mistral, qwen2, qwen2_vl, starcoder2

* fixup copies
  • Loading branch information
IlyasMoutawwakil authored and ArthurZucker committed Oct 29, 2024
1 parent dcfe3c7 commit 565f0e9
Show file tree
Hide file tree
Showing 10 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/mimi/modeling_mimi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,7 +1156,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,7 +1174,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/moshi/modeling_moshi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
Expand Down Expand Up @@ -1689,7 +1689,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/phi3/modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,7 +1139,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/phimoe/modeling_phimoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/starcoder2/modeling_starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
Expand Down

0 comments on commit 565f0e9

Please sign in to comment.