diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 7d1b0e19fc4df6..9f01450bb6ed0d 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -310,7 +310,7 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 Llama->Cohere +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere class CohereFlashAttention2(CohereAttention): """ Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays @@ -326,6 +326,7 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + # Ignore copy def forward( self, hidden_states: torch.Tensor,