From 4660c6e3df897b56bfbf1654e3def61d8aad2a04 Mon Sep 17 00:00:00 2001 From: weak-kajuma Date: Sat, 21 Dec 2024 05:56:53 +0000 Subject: [PATCH] resolve missing docstring entries --- .../diffllama/configuration_diffllama.py | 22 +-- .../models/diffllama/modeling_diffllama.py | 159 ++++-------------- .../models/diffllama/modular_diffllama.py | 24 +-- utils/check_config_docstrings.py | 1 - 4 files changed, 44 insertions(+), 162 deletions(-) diff --git a/src/transformers/models/diffllama/configuration_diffllama.py b/src/transformers/models/diffllama/configuration_diffllama.py index 7304a149967dc0..1b38f55d3903f8 100644 --- a/src/transformers/models/diffllama/configuration_diffllama.py +++ b/src/transformers/models/diffllama/configuration_diffllama.py @@ -24,7 +24,8 @@ class DiffLlamaConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`DiffLlamaModel`]. It is used to instantiate an DiffLlama - model according to the specified arguments, defining the model architecture. + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults + will yield a similar configuration to that of the [kajuma/DiffLlama-0.3B-handcut](https://huggingface.co/kajuma/DiffLlama-0.3B-handcut). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -34,11 +35,11 @@ class DiffLlamaConfig(PretrainedConfig): vocab_size (`int`, *optional*, defaults to 32000): Vocabulary size of the DiffLlama model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`DiffLlamaModel`] - hidden_size (`int`, *optional*, defaults to 4096): + hidden_size (`int`, *optional*, defaults to 2048): Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): + intermediate_size (`int`, *optional*, defaults to 8192): Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): + num_hidden_layers (`int`, *optional*, defaults to 16): Number of hidden layers in the Transformer decoder. num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer decoder. @@ -53,11 +54,10 @@ class DiffLlamaConfig(PretrainedConfig): hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. DiffLlama 1 supports up to 2048 tokens, - DiffLlama 2 up to 4096, CodeDiffLlama up to 16384. + The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): + rms_norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). 
Only @@ -137,15 +137,15 @@ class DiffLlamaConfig(PretrainedConfig): def __init__( self, vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=16, num_attention_heads=32, num_key_value_heads=None, hidden_act="silu", max_position_embeddings=2048, initializer_range=0.02, - rms_norm_eps=1e-6, + rms_norm_eps=1e-5, use_cache=True, pad_token_id=None, bos_token_id=1, diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 87157d604257dd..d719b8224e3bb4 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -31,7 +31,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -39,9 +39,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel -from ...processing_utils import Unpack from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -55,117 +53,10 @@ logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "meta-diffllama/DiffLlama-2-7b-hf" +_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut" _CONFIG_FOR_DOC = "DiffLlamaConfig" -class DiffLlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - DiffLlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class DiffLlamaRotaryEmbedding(nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[DiffLlamaConfig] = None, - ): - super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`DiffLlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. 
All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - class DiffLlamaMLP(nn.Module): def __init__(self, config): super().__init__() @@ -589,6 +480,26 @@ def forward( return attn_output, None, past_key_value +class DiffLlamaRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst DiffLlama is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + DIFFLLAMA_ATTENTION_CLASSES = { "eager": DiffLlamaAttention, "flash_attention_2": DiffLlamaFlashAttention2, @@ -600,9 +511,7 @@ class DiffLlamaDecoderLayer(nn.Module): def __init__(self, config: DiffLlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - self.mlp = DiffLlamaMLP(config) self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -616,7 +525,6 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -634,9 +542,6 @@ def forward( past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
kwargs (`dict`, *optional*): Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model @@ -654,7 +559,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -818,7 +722,6 @@ def __init__(self, config: DiffLlamaConfig): [DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = DiffLlamaRotaryEmbedding(config=config) self.gradient_checkpointing = False if getattr(config, "pretraining_tp", 1) != 1: @@ -846,7 +749,6 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -868,9 +770,9 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False + return_legacy_cache = False # noqa: F841 if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + return_legacy_cache = True # noqa: F841 if past_key_values is None: past_key_values = DynamicCache() else: @@ -886,16 +788,22 @@ def forward( cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + + # embed positions hidden_states = inputs_embeds - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) + # normalized + # DiffLlama downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer # decoder layers all_hidden_states = () if output_hidden_states else None @@ -916,7 +824,6 @@ def forward( output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( @@ -927,8 +834,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 90feffb7f1b84b..40e728850d9cb8 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -48,20 +48,10 @@ logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut" _CONFIG_FOR_DOC = "DiffLlamaConfig" -class DiffLlamaRMSNorm(LlamaRMSNorm): - pass - - -ALL_LAYERNORM_LAYERS.append(DiffLlamaRMSNorm) - - -class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding): - pass - - class DiffLlamaMLP(MistralMLP): pass @@ -429,18 +419,6 @@ def forward( return attn_output, None, past_key_value -class DiffLlamaDecoderLayer(LlamaDecoderLayer): 
- pass - - -class DiffLlamaPreTrainedModel(LlamaPreTrainedModel): - pass - - -class DiffLlamaModel(DiffLlamaPreTrainedModel, LlamaModel): - pass - - class DiffLlamaForCausalLM(GemmaForCausalLM): pass diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py index c1525eb7fbf78b..d243dd0c35b612 100644 --- a/utils/check_config_docstrings.py +++ b/utils/check_config_docstrings.py @@ -35,7 +35,6 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = { - "DiffLlamaConfig", "DecisionTransformerConfig", "EncoderDecoderConfig", "MusicgenConfig",
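With the defaults changed above, an argument-free DiffLlamaConfig() now describes the 0.3B checkpoint the docstring points at rather than a Llama-7B-sized model. A minimal sketch of what that looks like in practice, assuming a transformers build that already ships this DiffLlama port:

from transformers import DiffLlamaConfig

# No arguments: the defaults should now mirror kajuma/DiffLlama-0.3B-handcut
# rather than the old Llama-7B-style values (4096 / 11008 / 32 layers).
config = DiffLlamaConfig()
print(config.hidden_size)         # 2048
print(config.intermediate_size)   # 8192
print(config.num_hidden_layers)   # 16
print(config.rms_norm_eps)        # 1e-05

Loading a checkpoint with from_pretrained still takes every field from the checkpoint's stored config, so only argument-free construction is affected by the new defaults.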
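The DiffLlamaRotaryEmbedding block removed above delegated its inverse-frequency setup to ROPE_INIT_FUNCTIONS; for the plain "default" rope type that reduces to the classic 1 / base^(2i/dim) schedule, with the half-size frequency table duplicated before taking cos and sin. The sketch below is an illustrative standalone helper (default_rope is not an in-tree function) that mirrors only that default path and ignores the dynamic/yarn scaling variants the removed class also handled:

import torch

def default_rope(position_ids: torch.Tensor, dim: int, base: float = 10000.0):
    # Classic RoPE frequencies: 1 / base^(2i/dim) for i = 0 .. dim/2 - 1.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # (batch, seq_len, dim/2) angles, then duplicated to dim as in the removed forward().
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()

position_ids = torch.arange(16)[None, :]   # batch of 1, 16 positions
cos, sin = default_rope(position_ids, dim=64)
print(cos.shape, sin.shape)                # both torch.Size([1, 16, 64])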
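The DiffLlamaRMSNorm that the modeling file gains is the Gemma-style variant: the weight starts at zero and scales as (1 + weight), with the downcast back to the input dtype happening after the scale, whereas the removed Llama-style norm starts the weight at one and scales after casting. At initialisation the two are numerically identical in float32; only the cast ordering and the weight parameterisation differ. A small self-contained check, with both forwards condensed from the diff above under illustrative names:

import torch
from torch import nn

class LlamaStyleRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))   # ones init
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)    # cast first, then scale

class GemmaStyleRMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))           # zero init, used as (1 + w)

    def forward(self, x):
        output = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        output = output * (1.0 + self.weight.float())          # scale in float32, cast last
        return output.type_as(x)

x = torch.randn(2, 4, 8)
a = LlamaStyleRMSNorm(8)(x)
b = GemmaStyleRMSNorm(8)(x)
print(torch.allclose(a, b))   # True at init: ones vs. zeros parameterise the same identity scale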
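The other behavioural addition is the embedding normalizer: hidden states are multiplied by sqrt(hidden_size) before the decoder layers, with the scalar built in the dtype of the hidden states, which is what the in-line comment about sqrt(3072) rounding under half precision refers to (the 3072 figure looks carried over from the Gemma code this mirrors; the defaults above give hidden_size 2048). A short sketch of that rounding and of the multiplication as the forward pass applies it:

import torch

for hidden_size in (3072, 2048):
    exact = hidden_size ** 0.5
    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        normalizer = torch.tensor(exact, dtype=dtype)
        print(f"hidden_size={hidden_size} {str(dtype):>14} -> {normalizer.item():.4f} (exact {exact:.4f})")

# Applying it the way the forward pass above does, in the hidden-state dtype:
hidden_states = torch.randn(1, 4, 2048, dtype=torch.float16)
normalizer = torch.tensor(2048 ** 0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer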