From 1d828b67593e1347c416eca2490de411d4cfbd8f Mon Sep 17 00:00:00 2001
From: weak-kajuma
Date: Sat, 21 Dec 2024 11:16:57 +0000
Subject: [PATCH] remove past_key_value return in attention

---
 .../models/diffllama/modeling_diffllama.py | 179 +++++++++---------
 .../models/diffllama/modular_diffllama.py  |   9 +-
 2 files changed, 94 insertions(+), 94 deletions(-)

diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index 696a4a4ae960ab..44924447e486de 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -226,7 +226,7 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
 
 
 class DiffLlamaFlashAttention2(DiffLlamaAttention):
@@ -376,7 +376,7 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
 
 
 class DiffLlamaSdpaAttention(DiffLlamaAttention):
@@ -479,7 +479,7 @@ def forward(
         attn_output = attn_output.view(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output)
 
-        return attn_output, None, past_key_value
+        return attn_output, None
 
 
 class DiffLlamaRMSNorm(nn.Module):
@@ -502,6 +502,92 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
+DIFFLLAMA_ATTENTION_CLASSES = {
+    "eager": DiffLlamaAttention,
+    "flash_attention_2": DiffLlamaFlashAttention2,
+    "sdpa": DiffLlamaSdpaAttention,
+}
+
+
+class DiffLlamaDecoderLayer(nn.Module):
+    def __init__(self, config: DiffLlamaConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+        self.mlp = DiffLlamaMLP(config)
+        self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+                into the model
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
 class DiffLlamaRotaryEmbedding(nn.Module):
     def __init__(
         self,
@@ -589,93 +675,6 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-DIFFLLAMA_ATTENTION_CLASSES = {
-    "eager": DiffLlamaAttention,
-    "flash_attention_2": DiffLlamaFlashAttention2,
-    "sdpa": DiffLlamaSdpaAttention,
-}
-
-
-class DiffLlamaDecoderLayer(nn.Module):
-    def __init__(self, config: DiffLlamaConfig, layer_idx: int):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-
-        self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
-
-        self.mlp = DiffLlamaMLP(config)
-        self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*):
-                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
-                query_sequence_length, key_sequence_length)` if default attention is used.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
-                Indices depicting the position of the input sequence tokens in the sequence
-            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
-                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
-                with `head_dim` being the embedding dimension of each attention head.
-            kwargs (`dict`, *optional*):
-                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
-                into the model
-        """
-        residual = hidden_states
-
-        hidden_states = self.input_layernorm(hidden_states)
-
-        # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            cache_position=cache_position,
-            position_embeddings=position_embeddings,
-            **kwargs,
-        )
-        hidden_states = residual + hidden_states
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (present_key_value,)
-
-        return outputs
-
-
 DIFFLLAMA_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py
index 74d168caa96f21..e84eab311f25ce 100644
--- a/src/transformers/models/diffllama/modular_diffllama.py
+++ b/src/transformers/models/diffllama/modular_diffllama.py
@@ -159,7 +159,7 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
 
 
 class DiffLlamaFlashAttention2(DiffLlamaAttention):
@@ -309,7 +309,7 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
 
 
 class DiffLlamaSdpaAttention(DiffLlamaAttention):
@@ -412,8 +412,8 @@ def forward(
         attn_output = attn_output.view(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output)
 
-        return attn_output, None, past_key_value
-
+        return attn_output, None
+
 
 DIFFLLAMA_ATTENTION_CLASSES = {
     "eager": DiffLlamaAttention,
@@ -427,6 +427,7 @@ def __init__(self, config: DiffLlamaConfig, layer_idx: int):
         super().__init__(config, layer_idx)
         self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
+
 class DiffLlamaModel(LlamaModel):
     pass
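Not part of the patch itself: the snippet below is a minimal, self-contained Python sketch of the calling convention this patch moves toward, in which an attention `forward` returns only `(attn_output, attn_weights)` and the key/value cache is mutated in place through the `past_key_value` object the caller already holds. `ToyCache` and `ToyAttention` are hypothetical stand-ins for illustration only, not transformers classes, and the attention math is deliberately simplified (single head, no rotary embeddings, no masking).

from typing import List, Optional, Tuple

import torch
from torch import nn


class ToyCache:
    """Hypothetical stand-in for a mutable KV cache that is updated in place."""

    def __init__(self) -> None:
        self.keys: List[torch.Tensor] = []
        self.values: List[torch.Tensor] = []

    def update(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Append the new states and return the full key/value history.
        self.keys.append(k)
        self.values.append(v)
        return torch.cat(self.keys, dim=1), torch.cat(self.values, dim=1)


class ToyAttention(nn.Module):
    """Mimics the new convention: the cache is mutated in place, not returned."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[ToyCache] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        if past_key_value is not None:
            # The caller's cache object is updated here; there is nothing to hand back.
            k, v = past_key_value.update(k, v)
        attn_weights = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        attn_output = self.o_proj(attn_weights @ v)
        # Old convention: return attn_output, attn_weights, past_key_value
        # New convention (what the patch moves to): only two values.
        return attn_output, attn_weights if output_attentions else None


# Caller side: unpack two values and keep using the cache reference it already owns.
layer = ToyAttention(hidden_size=16)
cache = ToyCache()
x = torch.randn(1, 4, 16)
attn_output, attn_weights = layer(x, past_key_value=cache, output_attentions=True)
print(attn_output.shape, attn_weights.shape, len(cache.keys))

Note that callers which previously unpacked three values, such as the decoder layer's `hidden_states, self_attn_weights, present_key_value = self.self_attn(...)` still visible in this patch, would presumably need to unpack two values once the new return convention is fully propagated.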