Commit afc56bb: nits

ArthurZucker committed Dec 14, 2023
1 parent 18864b3 commit afc56bb
Showing 2 changed files with 37 additions and 3 deletions.
35 changes: 34 additions & 1 deletion src/transformers/models/deci/configuration_deci.py
@@ -122,7 +122,40 @@ def __init__(
         rope_theta=10000.0,
         sliding_window=4096,
         attention_dropout=0.0,
-        num_key_value_heads_per_layer=(4,4,4,4,4,2,2,2,2,2,4,2,2,2,2,2,2,2,2,2,1,1,1,1,1,1,1,1,1,1,1,4), # fmt: skip
+        num_key_value_heads_per_layer=(
+            4,
+            4,
+            4,
+            4,
+            4,
+            2,
+            2,
+            2,
+            2,
+            2,
+            4,
+            2,
+            2,
+            2,
+            2,
+            2,
+            2,
+            2,
+            2,
+            2,
+            1,
+            1,
+            1,
+            1,
+            1,
+            1,
+            1,
+            1,
+            1,
+            1,
+            1,
+            4,
+        ), # fmt: skip
         **kwargs,
     ):
         self.vocab_size = vocab_size
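For context, this per-layer tuple assigns a separate number of key/value heads to each of the 32 decoder layers (grouped-query attention with a layer-dependent group size): layers 0-4 use 4 KV heads, most middle layers use 2 (layer 10 uses 4), layers 20-30 use 1, and the final layer uses 4 again. A minimal sketch of how the default might be consumed, assuming the file above defines a DeciConfig class that stores the tuple under the same attribute name:

# Sketch only: the class name DeciConfig and the attribute access are assumed
# from the diff above, not confirmed by the rest of the file.
from transformers.models.deci.configuration_deci import DeciConfig

config = DeciConfig()  # picks up the 32-entry default shown in the diff
for layer_idx, num_kv_heads in enumerate(config.num_key_value_heads_per_layer):
    # Prints how many key/value heads each decoder layer will use.
    print(f"layer {layer_idx}: {num_kv_heads} key/value heads")
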
5 changes: 3 additions & 2 deletions src/transformers/models/deci/modeling_deci.py
@@ -522,6 +522,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
             (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
         )
 
+
 # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Deci
 class DeciSdpaAttention(DeciAttention):
     """
@@ -610,9 +611,9 @@ def forward(
         return attn_output, None, past_key_value
 
 
-LLAMA_ATTENTION_CLASSES = {
+DECI_ATTENTION_CLASSES = {
     "eager": DeciAttention,
-    "flash_attention_2":DeciFlashAttention2,
+    "flash_attention_2": DeciFlashAttention2,
     "sdpa": DeciSdpaAttention,
 }
 
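The renamed DECI_ATTENTION_CLASSES dict follows the usual transformers dispatch pattern, where the decoder layer picks its attention implementation from config._attn_implementation. A short sketch of how such a mapping is typically consumed, assuming a DeciDecoderLayer modeled on the Llama one (not shown in this hunk):

import torch.nn as nn

class DeciDecoderLayer(nn.Module):  # hypothetical; mirrors LlamaDecoderLayer, not part of this diff
    def __init__(self, config, layer_idx):
        super().__init__()
        # Choose the eager, flash_attention_2, or sdpa attention class based on
        # the backend selected in the config.
        self.self_attn = DECI_ATTENTION_CLASSES[config._attn_implementation](
            config=config, layer_idx=layer_idx
        )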
