From ea622ce11b860def500578f3768ca82ccc6ec8cd Mon Sep 17 00:00:00 2001 From: weak-kajuma Date: Fri, 6 Dec 2024 11:49:12 +0000 Subject: [PATCH] leaner modular diffllama --- .../models/diffllama/modular_diffllama.py | 119 ++++++------------ 1 file changed, 35 insertions(+), 84 deletions(-) diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 672ecb81d82b59..e231a064d6bf35 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -59,17 +59,7 @@ class DiffLlamaRMSNorm(LlamaRMSNorm): class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[DiffLlamaConfig] = None, - ): - super().__init__(dim, max_position_embeddings, base, device, scaling_factor, rope_type, config) + pass class DiffLlamaMLP(MistralMLP): @@ -117,19 +107,16 @@ def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None): self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False) - # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) - self.rotary_emb = DiffLlamaRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, target_len, _ = hidden_states.size() @@ -143,16 +130,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -216,13 +194,13 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -359,13 +337,13 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -395,16 +373,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -468,77 +437,59 @@ def forward( class DiffLlamaDecoderLayer(LlamaDecoderLayer): - def __init__(self, config: DiffLlamaConfig, layer_idx: int): - super().__init__(config, layer_idx) + pass + # def __init__(self, config: DiffLlamaConfig, layer_idx: int): + # super().__init__(config, layer_idx) - self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + # self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - self.mlp = DiffLlamaMLP(config) - self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # self.mlp = DiffLlamaMLP(config) + # self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) class DiffLlamaPreTrainedModel(LlamaPreTrainedModel): - config_class = DiffLlamaConfig - _no_split_modules = ["DiffLlamaDecoderLayer"] + pass + # config_class = DiffLlamaConfig + # _no_split_modules = ["DiffLlamaDecoderLayer"] class DiffLlamaModel(DiffLlamaPreTrainedModel, LlamaModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DiffLlamaDecoderLayer`] + pass + # """ + # Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DiffLlamaDecoderLayer`] - Args: - config: DiffLlamaConfig - """ + # Args: + # config: DiffLlamaConfig + # """ - def __init__(self, config: DiffLlamaConfig): - super().__init__(config) + # def __init__(self, config: DiffLlamaConfig): + # super().__init__(config) - self.layers = nn.ModuleList( - [DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = DiffLlamaRotaryEmbedding(config=config) + # self.layers = nn.ModuleList( + # [DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + # ) + # self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # self.rotary_emb = DiffLlamaRotaryEmbedding(config=config) - # Initialize weights and apply final processing - self.post_init() + # # Initialize weights and apply final processing + # self.post_init() class DiffLlamaForCausalLM(GemmaForCausalLM): - def __init__(self, config): - super().__init__(config) - - self.model = DiffLlamaModel(config) - - # Initialize weights and apply final processing - self.post_init() + pass class DiffLlamaForSequenceClassification(LlamaForSequenceClassification): - def __init__(self, config): - super().__init__(config) - self.model = DiffLlamaModel(config) - - # Initialize weights and apply final processing - self.post_init() + pass class DiffLlamaForQuestionAnswering(LlamaForQuestionAnswering): - def __init__(self, config): - super().__init__(config) - self.transformer = DiffLlamaModel(config) - - # Initialize weights and apply final processing - self.post_init() + pass class DiffLlamaForTokenClassification(LlamaForTokenClassification): - def __init__(self, config): - super().__init__(config) - self.model = DiffLlamaModel(config) - - # Initialize weights and apply final processing - self.post_init() + pass __all__ = [