Commit

some cleaning
fxmarty committed Jul 2, 2024
1 parent 32c2df8 commit 54a9fb0
Showing 7 changed files with 45 additions and 8 deletions.
src/transformers/flash_attention_utils.py (9 changes: 5 additions & 4 deletions)
@@ -1,4 +1,5 @@
 import inspect
+from typing import Optional

 import torch
 import torch.nn.functional as F
@@ -67,10 +68,10 @@ def _flash_attention_forward(
     key_states,
     value_states,
     attention_mask,
-    query_length,
-    dropout=0.0,
-    softmax_scale=None,
-    is_causal=False,
+    query_length: int,
+    is_causal: bool,
+    dropout: float = 0.0,
+    softmax_scale: Optional[float] = None,
     sliding_window=None,
     use_top_left_mask: bool = False,
 ):
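Taken together, the two hunks above annotate query_length and turn is_causal into a required argument of _flash_attention_forward (it previously defaulted to False), so every flash-attention call site now has to state its causality explicitly. Below is a minimal sketch of the new calling convention; the signature is reproduced from the hunk, while the placeholder body, the tensor shapes, and the example call are illustrative assumptions rather than part of the commit:

from typing import Optional

import torch


# Signature as shown in the hunk above; the body is a placeholder for this sketch.
def _flash_attention_forward(
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    softmax_scale: Optional[float] = None,
    sliding_window=None,
    use_top_left_mask: bool = False,
):
    ...


# Example tensors in the (batch, seq_len, num_heads, head_dim) layout used by the
# flash-attention kernels; the sizes here are arbitrary.
q = k = v = torch.randn(1, 8, 4, 64)

# Omitting is_causal now raises a TypeError, which is why the model hunks below
# all pass is_causal=self.is_causal explicitly.
_flash_attention_forward(q, k, v, None, query_length=8, is_causal=True, dropout=0.0)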
src/transformers/models/mixtral/modeling_mixtral.py (1 change: 1 addition & 0 deletions)
@@ -490,6 +490,7 @@ def forward(
             q_len,
             dropout=dropout_rate,
             sliding_window=getattr(self.config, "sliding_window", None),
+            is_causal=self.is_causal,
         )

         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
src/transformers/models/opt/modeling_opt.py (19 changes: 18 additions & 1 deletion)
@@ -35,6 +35,7 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
@@ -246,6 +247,15 @@ class OptFlashAttention2(OPTAttention):
     attention and deal with padding tokens in case the input contains any of them.
     """

+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -330,7 +340,14 @@ def forward(
             value_states = value_states.to(target_dtype)

         attn_output = _flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            query_length,
+            dropout=attn_dropout,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
         )

         attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
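The comments in the new __init__ only describe the mask-alignment issue; the flag itself is consumed inside _flash_attention_forward together with is_causal and query_length. The helper below is a sketch of how the two flags plausibly combine into the causal argument handed to the flash-attn kernels. It is inferred from those comments and from the equivalent pre-existing pattern in other flash-attention modules, not copied from this commit, and the name _resolve_causal is invented for the illustration:

# Sketch only: _resolve_causal is a hypothetical name, and the logic is inferred
# from the comments above rather than taken from this commit.
def _resolve_causal(is_causal: bool, use_top_left_mask: bool, query_length: int) -> bool:
    if not use_top_left_mask:
        # flash_attn >= 2.1: the causal mask is bottom-right aligned, so the
        # module's is_causal flag can be forwarded unchanged.
        return is_causal
    # flash_attn < 2.1 (e.g. current RoCm builds): the causal mask is top-left
    # aligned. During single-token decoding (query_length == 1) the causal flag
    # is dropped, since the one new query should attend to every cached key and
    # a top-left aligned mask would not allow that.
    return is_causal and query_length != 1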
src/transformers/models/phi3/modeling_phi3.py (2 changes: 2 additions & 0 deletions)
@@ -521,6 +521,8 @@ def forward(
             q_len,
             dropout=attn_dropout,
             sliding_window=getattr(self.config, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
         )

         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
src/transformers/models/qwen2/modeling_qwen2.py (9 changes: 8 additions & 1 deletion)
@@ -414,7 +414,14 @@ def forward(
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)

-        sliding_window = self.sliding_window if self.layer_idx >= self.config.max_window_layers else None
+        if (
+            self.config.use_sliding_window
+            and getattr(self.config, "sliding_window", None) is not None
+            and self.layer_idx >= self.config.max_window_layers
+        ):
+            sliding_window = self.config.sliding_window
+        else:
+            sliding_window = None

         attn_output = _flash_attention_forward(
             query_states,
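The new guard changes behavior when sliding-window attention is disabled in the config while a window size is still set, or when no window size is configured at all. Below is a small, self-contained illustration of the difference with a stand-in config object and made-up values; the original code reads the same fields off self and self.config, and the identical guard appears again in the qwen2_moe hunk that follows:

from types import SimpleNamespace

# Stand-in for the model config; the values are invented for this illustration.
config = SimpleNamespace(use_sliding_window=False, sliding_window=4096, max_window_layers=28)
layer_idx = 30  # a layer index past max_window_layers

# Before this commit: the window was chosen purely from the layer index, even
# though use_sliding_window is False.
old_window = config.sliding_window if layer_idx >= config.max_window_layers else None

# After this commit: the window is only applied when sliding-window attention is
# enabled and a window size is actually configured.
if (
    config.use_sliding_window
    and getattr(config, "sliding_window", None) is not None
    and layer_idx >= config.max_window_layers
):
    new_window = config.sliding_window
else:
    new_window = None

print(old_window, new_window)  # prints: 4096 None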
src/transformers/models/qwen2_moe/modeling_qwen2_moe.py (9 changes: 8 additions & 1 deletion)
@@ -493,7 +493,14 @@ def forward(
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)

-        sliding_window = self.sliding_window if self.layer_idx >= self.config.max_window_layers else None
+        if (
+            self.config.use_sliding_window
+            and getattr(self.config, "sliding_window", None) is not None
+            and self.layer_idx >= self.config.max_window_layers
+        ):
+            sliding_window = self.config.sliding_window
+        else:
+            sliding_window = None

         attn_output = _flash_attention_forward(
             query_states,
src/transformers/models/starcoder2/modeling_starcoder2.py (4 changes: 3 additions & 1 deletion)
@@ -405,7 +405,9 @@ def forward(
             attention_mask,
             q_len,
             dropout=dropout_rate,
-            sliding_windows=getattr(self.config, "sliding_window", None),
+            sliding_window=getattr(self.config, "sliding_window", None),
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
         )

         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
