diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py
index cffc95cb47bdd2..74d168caa96f21 100644
--- a/src/transformers/models/diffllama/modular_diffllama.py
+++ b/src/transformers/models/diffllama/modular_diffllama.py
@@ -30,6 +30,7 @@
 )
 from ..gemma.modeling_gemma import GemmaForCausalLM
 from ..llama.modeling_llama import (
+    LlamaDecoderLayer,
     LlamaForQuestionAnswering,
     LlamaForSequenceClassification,
     LlamaForTokenClassification,
@@ -412,7 +413,19 @@ def forward(
         attn_output = self.o_proj(attn_output)
 
         return attn_output, None, past_key_value
 
+
+DIFFLLAMA_ATTENTION_CLASSES = {
+    "eager": DiffLlamaAttention,
+    "flash_attention_2": DiffLlamaFlashAttention2,
+    "sdpa": DiffLlamaSdpaAttention,
+}
+
+
+class DiffLlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(self, config: DiffLlamaConfig, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
 class DiffLlamaModel(LlamaModel):
     pass
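
The added DiffLlamaDecoderLayer inherits everything from LlamaDecoderLayer and only swaps in the attention module, looked up from DIFFLLAMA_ATTENTION_CLASSES by config._attn_implementation. Below is a minimal, self-contained sketch of that dispatch pattern; the underscore-prefixed classes are stand-ins for illustration only, not the real transformers classes.

    # Sketch of the registry-based attention dispatch used in the diff above.
    # _EagerAttention/_SdpaAttention/_Config/_DecoderLayer are hypothetical
    # stand-ins; only the lookup-by-config pattern mirrors the real code.


    class _EagerAttention:
        def __init__(self, config=None, layer_idx=None):
            self.kind = "eager"


    class _SdpaAttention:
        def __init__(self, config=None, layer_idx=None):
            self.kind = "sdpa"


    ATTENTION_CLASSES = {
        "eager": _EagerAttention,
        "sdpa": _SdpaAttention,
    }


    class _Config:
        # Mirrors the private attribute read in DiffLlamaDecoderLayer.__init__.
        _attn_implementation = "sdpa"


    class _DecoderLayer:
        def __init__(self, config, layer_idx):
            # Same idea as the diff: let the parent build the default layer,
            # then replace self_attn with the implementation named in the config.
            self.self_attn = ATTENTION_CLASSES[config._attn_implementation](
                config=config, layer_idx=layer_idx
            )


    layer = _DecoderLayer(_Config(), layer_idx=0)
    print(layer.self_attn.kind)  # -> "sdpa"

Keeping the registry next to the layer means only the attention class differs from the Llama parent, which is what the modular file format is meant to express.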