diff --git a/src/transformers/models/deci/configuration_deci.py b/src/transformers/models/deci/configuration_deci.py
index 4cfc1df5b37fff..615061afa5450c 100644
--- a/src/transformers/models/deci/configuration_deci.py
+++ b/src/transformers/models/deci/configuration_deci.py
@@ -122,7 +122,40 @@ def __init__(
         rope_theta=10000.0,
         sliding_window=4096,
         attention_dropout=0.0,
-        num_key_value_heads_per_layer=(4,4,4,4,4,2,2,2,2,2,4,2,2,2,2,2,2,2,2,2,1,1,1,1,1,1,1,1,1,1,1,4),  # fmt: skip
+        num_key_value_heads_per_layer=(
+            4,
+            4,
+            4,
+            4,
+            4,
+            2,
+            2,
+            2,
+            2,
+            2,
+            4,
+            2,
+            2,
+            2,
+            2,
+            2,
+            2,
+            2,
+            2,
+            2,
+            1,
+            1,
+            1,
+            1,
+            1,
+            1,
+            1,
+            1,
+            1,
+            1,
+            1,
+            4,
+        ),  # fmt: skip
         **kwargs,
     ):
         self.vocab_size = vocab_size
diff --git a/src/transformers/models/deci/modeling_deci.py b/src/transformers/models/deci/modeling_deci.py
index 2859f65c0fda7b..97acc18f829745 100644
--- a/src/transformers/models/deci/modeling_deci.py
+++ b/src/transformers/models/deci/modeling_deci.py
@@ -522,6 +522,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
             (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
         )
 
+
 # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Deci
 class DeciSdpaAttention(DeciAttention):
     """
@@ -610,9 +611,9 @@ def forward(
         return attn_output, None, past_key_value
 
 
-LLAMA_ATTENTION_CLASSES = {
+DECI_ATTENTION_CLASSES = {
     "eager": DeciAttention,
-    "flash_attention_2":DeciFlashAttention2,
+    "flash_attention_2": DeciFlashAttention2,
     "sdpa": DeciSdpaAttention,
 }
 