From 3f85c22829b83dbd428e086113af182b5a485780 Mon Sep 17 00:00:00 2001
From: weak-kajuma
Date: Fri, 6 Dec 2024 11:59:14 +0000
Subject: [PATCH] remove more and more in modular_diffllama.py

---
 .../models/diffllama/modeling_diffllama.py | 156 ++++++------------
 .../models/diffllama/modular_diffllama.py  |  38 +----
 2 files changed, 54 insertions(+), 140 deletions(-)

diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index a4cce31a3b8732..87157d604257dd 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -25,15 +25,13 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -43,10 +41,11 @@
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
 from ...utils import (
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
@@ -54,8 +53,10 @@
 from .configuration_diffllama import DiffLlamaConfig
 
 
-if is_flash_attn_2_available():
-    from ...modeling_flash_attention_utils import _flash_attention_forward
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "meta-diffllama/DiffLlama-2-7b-hf"
+_CONFIG_FOR_DOC = "DiffLlamaConfig"
 
 
 class DiffLlamaRMSNorm(nn.Module):
@@ -78,9 +79,6 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
-logger = logging.get_logger(__name__)
-
-
 class DiffLlamaRotaryEmbedding(nn.Module):
     def __init__(
         self,
@@ -168,6 +166,20 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+class DiffLlamaMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -202,20 +214,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-class DiffLlamaMLP(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
-
-    def forward(self, hidden_state):
-        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -269,19 +267,16 @@ def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
         self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
         self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
 
-        # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
-        self.rotary_emb = DiffLlamaRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, target_len, _ = hidden_states.size()
@@ -295,16 +290,7 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
@@ -368,13 +354,13 @@ def __init__(self, *args, **kwargs):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
             raise ValueError(
@@ -511,13 +497,13 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
@@ -547,16 +533,7 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
@@ -744,9 +721,6 @@ def _init_weights(self, module):
             module.weight.data[module.padding_idx].zero_()
 
 
-_CONFIG_FOR_DOC = "DiffLlamaConfig"
-
-
 DIFFLLAMA_INPUTS_DOCSTRING = r"""
     Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -840,13 +814,15 @@ def __init__(self, config: DiffLlamaConfig):
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-
         self.layers = nn.ModuleList(
             [DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+        if getattr(config, "pretraining_tp", 1) != 1:
+            logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -870,6 +846,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -925,7 +902,7 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
@@ -951,6 +928,7 @@ def forward(
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
+                    **flash_attn_kwargs,
                 )
 
             hidden_states = layer_outputs[0]
@@ -1054,6 +1032,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
+        **kwargs,
     ):
         """
         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@@ -1103,10 +1082,10 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 
 class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
 
     def __init__(self, config):
         super().__init__(config)
-
         self.model = DiffLlamaModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -1148,6 +1127,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1205,18 +1185,7 @@ def forward(
 
         loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -1320,27 +1289,8 @@ def forward(
 
         loss = None
         if labels is not None:
-            labels = labels.to(logits.device)
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(pooled_logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(pooled_logits, labels)
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1391,6 +1341,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple, QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1422,29 +1373,16 @@ def forward(
         start_logits = start_logits.squeeze(-1).contiguous()
         end_logits = end_logits.squeeze(-1).contiguous()
 
-        total_loss = None
+        loss = None
         if start_positions is not None and end_positions is not None:
-            # If we are on multi-GPU, split add a dimension
-            if len(start_positions.size()) > 1:
-                start_positions = start_positions.squeeze(-1).to(start_logits.device)
-            if len(end_positions.size()) > 1:
-                end_positions = end_positions.squeeze(-1).to(end_logits.device)
-            # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1)
-            start_positions = start_positions.clamp(0, ignored_index)
-            end_positions = end_positions.clamp(0, ignored_index)
-
-            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
-            start_loss = loss_fct(start_logits, start_positions)
-            end_loss = loss_fct(end_logits, end_positions)
-            total_loss = (start_loss + end_loss) / 2
+            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
 
         if not return_dict:
             output = (start_logits, end_logits) + outputs[2:]
-            return ((total_loss,) + output) if total_loss is not None else output
+            return ((loss,) + output) if loss is not None else output
 
         return QuestionAnsweringModelOutput(
-            loss=total_loss,
+            loss=loss,
             start_logits=start_logits,
             end_logits=end_logits,
             hidden_states=outputs.hidden_states,
@@ -1483,6 +1421,11 @@ def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
     @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -1521,8 +1464,7 @@ def forward(
 
         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = self.loss_function(logits, labels, self.config)
 
         if not return_dict:
             output = (logits,) + outputs[2:]
diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py
index e231a064d6bf35..9ea3db372a5118 100644
--- a/src/transformers/models/diffllama/modular_diffllama.py
+++ b/src/transformers/models/diffllama/modular_diffllama.py
@@ -429,51 +429,23 @@ def forward(
         return attn_output, None, past_key_value
 
 
-DIFFLLAMA_ATTENTION_CLASSES = {
-    "eager": DiffLlamaAttention,
-    "flash_attention_2": DiffLlamaFlashAttention2,
-    "sdpa": DiffLlamaSdpaAttention,
-}
+# DIFFLLAMA_ATTENTION_CLASSES = {
+#     "eager": DiffLlamaAttention,
+#     "flash_attention_2": DiffLlamaFlashAttention2,
+#     "sdpa": DiffLlamaSdpaAttention,
+# }
 
 
 class DiffLlamaDecoderLayer(LlamaDecoderLayer):
     pass
-    # def __init__(self, config: DiffLlamaConfig, layer_idx: int):
-    #     super().__init__(config, layer_idx)
-
-    #     self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
-
-    #     self.mlp = DiffLlamaMLP(config)
-    #     self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-    #     self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
 class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
     pass
-    # config_class = DiffLlamaConfig
-    # _no_split_modules = ["DiffLlamaDecoderLayer"]
 
 
 class DiffLlamaModel(DiffLlamaPreTrainedModel, LlamaModel):
     pass
-    # """
-    # Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DiffLlamaDecoderLayer`]
-
-    # Args:
-    #     config: DiffLlamaConfig
-    # """
-
-    # def __init__(self, config: DiffLlamaConfig):
-    #     super().__init__(config)
-
-    #     self.layers = nn.ModuleList(
-    #         [DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
-    #     )
-    #     self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-    #     self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)
-
-    #     # Initialize weights and apply final processing
-    #     self.post_init()
 
 
 class DiffLlamaForCausalLM(GemmaForCausalLM):
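
Note on the loss changes: each task head above now delegates to the base-class `self.loss_function(...)` instead of building its own criterion. For the causal-LM head, the removed shift-and-flatten cross-entropy is, in effect, the following standalone sketch (an illustrative helper, not code from this patch; it assumes the usual Transformers convention that label positions set to -100 are ignored):

import torch
import torch.nn.functional as F


def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # Shift so that the logits at position t are scored against the label at t + 1,
    # mirroring the block removed from DiffLlamaForCausalLM.forward (logits upcast to float).
    shift_logits = logits[..., :-1, :].float().contiguous()
    shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
    # Flatten to (batch * (seq_len - 1), vocab_size); labels of -100 are ignored.
    return F.cross_entropy(shift_logits.view(-1, vocab_size), shift_labels.view(-1), ignore_index=-100)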
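
Note on the RoPE changes: `position_embeddings` is now a required `(cos, sin)` tuple, computed once by the model-level `DiffLlamaRotaryEmbedding` and passed into every attention class, rather than being recomputed per layer from `position_ids`. A minimal self-contained sketch of how such a tuple is consumed (toy shapes and frequencies, assumed purely for illustration; only `rotate_half` and `apply_rotary_pos_emb` mirror the patch):

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap and negate the two halves of the last dimension, as in the patch.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin arrive precomputed as position_embeddings and broadcast over the head dim.
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


# Toy example: batch=1, heads=2, seq_len=5, head_dim=8.
q = torch.randn(1, 2, 5, 8)
k = torch.randn(1, 2, 5, 8)
angles = torch.outer(torch.arange(5.0), 1.0 / (10000.0 ** (torch.arange(0, 8, 2) / 8)))
cos = torch.cat((angles.cos(), angles.cos()), dim=-1).unsqueeze(0)  # (1, 5, 8)
sin = torch.cat((angles.sin(), angles.sin()), dim=-1).unsqueeze(0)
position_embeddings = (cos, sin)
q_rot, k_rot = apply_rotary_pos_emb(q, k, *position_embeddings)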