
resolve missing docstring entries
weak-kajuma committed Dec 21, 2024
1 parent 87d034d commit 4660c6e
Showing 4 changed files with 44 additions and 162 deletions.
22 changes: 11 additions & 11 deletions src/transformers/models/diffllama/configuration_diffllama.py
@@ -24,7 +24,8 @@
class DiffLlamaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DiffLlamaModel`]. It is used to instantiate a DiffLlama
model according to the specified arguments, defining the model architecture.
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults
will yield a similar configuration to that of the [kajuma/DiffLlama-0.3B-handcut](https://huggingface.co/kajuma/DiffLlama-0.3B-handcut).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
@@ -34,11 +35,11 @@ class DiffLlamaConfig(PretrainedConfig):
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the DiffLlama model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`DiffLlamaModel`]
hidden_size (`int`, *optional*, defaults to 4096):
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
intermediate_size (`int`, *optional*, defaults to 8192):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
num_hidden_layers (`int`, *optional*, defaults to 16):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
@@ -53,11 +54,10 @@ class DiffLlamaConfig(PretrainedConfig):
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. DiffLlama 1 supports up to 2048 tokens,
DiffLlama 2 up to 4096, CodeDiffLlama up to 16384.
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
@@ -137,15 +137,15 @@ class DiffLlamaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=16,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
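(Illustrative sketch, not part of the diff.) With the updated defaults above, a freshly constructed DiffLlamaConfig now matches the documented kajuma/DiffLlama-0.3B-handcut shape; the snippet below assumes a transformers build that already ships the DiffLlama classes:

from transformers import DiffLlamaConfig, DiffLlamaModel

# New defaults documented and set above: hidden_size=2048, intermediate_size=8192,
# num_hidden_layers=16, rms_norm_eps=1e-5.
config = DiffLlamaConfig()
model = DiffLlamaModel(config)  # randomly initialized model with a DiffLlama-0.3B-handcut-like configuration
print(config.hidden_size, config.intermediate_size, config.num_hidden_layers, config.rms_norm_eps)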
159 changes: 32 additions & 127 deletions src/transformers/models/diffllama/modeling_diffllama.py
@@ -31,17 +31,15 @@
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -55,117 +53,10 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "meta-diffllama/DiffLlama-2-7b-hf"
_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
_CONFIG_FOR_DOC = "DiffLlamaConfig"


class DiffLlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
DiffLlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class DiffLlamaRotaryEmbedding(nn.Module):
def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[DiffLlamaConfig] = None,
):
super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`DiffLlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)

# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class DiffLlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
@@ -589,6 +480,26 @@ def forward(
return attn_output, None, past_key_value


class DiffLlamaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst DiffLlama is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)

def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
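The comment inside the new DiffLlamaRMSNorm above refers to a cast-ordering difference: Llama multiplies by the weight after downcasting the normalized activations, while this class multiplies by (1.0 + weight) in float32 and downcasts once at the end. A self-contained sketch of why the ordering can matter (illustrative only, not library code):

import torch

x = torch.randn(8, dtype=torch.float32)
w = torch.full((8,), 1.001, dtype=torch.float32)

cast_then_multiply = x.to(torch.float16) * w.to(torch.float16)  # Llama-style ordering
multiply_then_cast = (x * w).to(torch.float16)                  # DiffLlama/Gemma-style ordering
print(torch.equal(cast_then_multiply, multiply_then_cast))      # can be False: rounding happens at different points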


DIFFLLAMA_ATTENTION_CLASSES = {
"eager": DiffLlamaAttention,
"flash_attention_2": DiffLlamaFlashAttention2,
@@ -600,9 +511,7 @@ class DiffLlamaDecoderLayer(nn.Module):
def __init__(self, config: DiffLlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size

self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

self.mlp = DiffLlamaMLP(config)
self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -616,7 +525,6 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -634,9 +542,6 @@
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
@@ -654,7 +559,6 @@
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -818,7 +722,6 @@ def __init__(self, config: DiffLlamaConfig):
[DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)

self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
@@ -846,7 +749,6 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -868,9 +770,9 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)

# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
return_legacy_cache = False # noqa: F841
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
return_legacy_cache = True # noqa: F841
if past_key_values is None:
past_key_values = DynamicCache()
else:
@@ -886,16 +788,22 @@ def forward(
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)

# embed positions
hidden_states = inputs_embeds

# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# normalized
# DiffLlama downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer

# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -916,7 +824,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
@@ -927,8 +834,6 @@
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

hidden_states = layer_outputs[0]
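One detail of the forward pass above: the embedding scale sqrt(hidden_size) is materialized as a tensor in the hidden states' dtype, so it is rounded in reduced precision, which is what the "downcasts" comment refers to. A quick numeric sketch with the new default hidden_size (illustrative, not part of the diff; printed values are approximate):

import torch

hidden_size = 2048  # new default from the configuration change above
print(torch.tensor(hidden_size**0.5, dtype=torch.float32))  # tensor(45.2548)
print(torch.tensor(hidden_size**0.5, dtype=torch.float16))  # tensor(45.2500, dtype=torch.float16)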
24 changes: 1 addition & 23 deletions src/transformers/models/diffllama/modular_diffllama.py
@@ -48,20 +48,10 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
_CONFIG_FOR_DOC = "DiffLlamaConfig"


class DiffLlamaRMSNorm(LlamaRMSNorm):
pass


ALL_LAYERNORM_LAYERS.append(DiffLlamaRMSNorm)


class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
pass


class DiffLlamaMLP(MistralMLP):
pass

@@ -429,18 +419,6 @@ def forward(
return attn_output, None, past_key_value


class DiffLlamaDecoderLayer(LlamaDecoderLayer):
pass


class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
pass


class DiffLlamaModel(DiffLlamaPreTrainedModel, LlamaModel):
pass


class DiffLlamaForCausalLM(GemmaForCausalLM):
pass

1 change: 0 additions & 1 deletion utils/check_config_docstrings.py
@@ -35,7 +35,6 @@


CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
"DiffLlamaConfig",
"DecisionTransformerConfig",
"EncoderDecoderConfig",
"MusicgenConfig",
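Removing DiffLlamaConfig from this ignore list means the docstring checkpoint check now applies to it, which the new kajuma/DiffLlama-0.3B-handcut link added to the configuration docstring satisfies. A rough sketch of that kind of check (illustrative; the regex is an assumption, not necessarily the utility's exact pattern):

import re

# Assumed pattern: a markdown link to a huggingface.co checkpoint somewhere in the config docstring.
CHECKPOINT_LINK = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")

def has_checkpoint_link(config_docstring: str) -> bool:
    return CHECKPOINT_LINK.search(config_docstring) is not None

# With the docstring change in configuration_diffllama.py, has_checkpoint_link(DiffLlamaConfig.__doc__)
# would now return True.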
