From 924c46d40c20c8ce2599b5ecfb18173f97c3de16 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 3 Jun 2024 18:29:31 +0200 Subject: [PATCH] Cohere: Fix copied from (#31213) Update modeling_cohere.py --- src/transformers/models/cohere/modeling_cohere.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 7d1b0e19fc4df6..9f01450bb6ed0d 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -310,7 +310,7 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 Llama->Cohere +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere class CohereFlashAttention2(CohereAttention): """ Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays @@ -326,6 +326,7 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + # Ignore copy def forward( self, hidden_states: torch.Tensor,