Commit

add DiffLlamaDecoderLayer as it was before #35235
weak-kajuma committed Dec 21, 2024
1 parent 25e6235 commit 8843732
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/transformers/models/diffllama/modular_diffllama.py
@@ -30,6 +30,7 @@
)
from ..gemma.modeling_gemma import GemmaForCausalLM
from ..llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaForQuestionAnswering,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
@@ -412,7 +413,19 @@ def forward(
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


DIFFLLAMA_ATTENTION_CLASSES = {
    "eager": DiffLlamaAttention,
    "flash_attention_2": DiffLlamaFlashAttention2,
    "sdpa": DiffLlamaSdpaAttention,
}


class DiffLlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: DiffLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

class DiffLlamaModel(LlamaModel):
    pass
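For readers skimming the diff: the re-added DiffLlamaDecoderLayer reuses LlamaDecoderLayer and only swaps in the attention implementation named by config._attn_implementation, looked up in the DIFFLLAMA_ATTENTION_CLASSES mapping. Below is a minimal, self-contained sketch of that dispatch pattern; the Toy* names are hypothetical stand-ins, not classes from transformers.

# Minimal sketch of the attention-class dispatch pattern shown in the diff above.
# ToyConfig, ToyEagerAttention, ToySdpaAttention, and ToyDecoderLayer are
# hypothetical stand-ins, not transformers classes.
from dataclasses import dataclass


@dataclass
class ToyConfig:
    hidden_size: int = 64
    _attn_implementation: str = "eager"  # or "sdpa"


class ToyEagerAttention:
    def __init__(self, config, layer_idx):
        self.config, self.layer_idx = config, layer_idx


class ToySdpaAttention(ToyEagerAttention):
    pass


TOY_ATTENTION_CLASSES = {
    "eager": ToyEagerAttention,
    "sdpa": ToySdpaAttention,
}


class ToyDecoderLayer:
    def __init__(self, config: ToyConfig, layer_idx: int):
        # Look up the attention class requested by the config and instantiate it,
        # mirroring how DiffLlamaDecoderLayer overrides self.self_attn.
        attn_cls = TOY_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attn_cls(config=config, layer_idx=layer_idx)


layer = ToyDecoderLayer(ToyConfig(_attn_implementation="sdpa"), layer_idx=0)
print(type(layer.self_attn).__name__)  # prints: ToySdpaAttention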

0 comments on commit 8843732
