
Commit

make style
weak-kajuma committed Dec 21, 2024
1 parent 396ebfd commit 5f1063a
Showing 2 changed files with 128 additions and 29 deletions.
143 changes: 119 additions & 24 deletions src/transformers/models/diffllama/modeling_diffllama.py
@@ -31,15 +31,17 @@
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -57,6 +59,93 @@
_CONFIG_FOR_DOC = "DiffLlamaConfig"


class DiffLlamaRotaryEmbedding(nn.Module):
def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[DiffLlamaConfig] = None,
):
super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`DiffLlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)

# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
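
# For reference (not part of this commit): the (cos, sin) pair returned above is
# consumed by the standard rotate_half / apply_rotary_pos_emb helpers used across
# Llama-family models. A sketch of those helpers, for readers following the diff
# (the copies in this file are not shown in this hunk and may differ slightly):
def rotate_half(x):
    # Split the last dimension in half and swap the halves with a sign flip.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # q, k: (batch, num_heads, seq_len, head_dim); cos, sin: (batch, seq_len, head_dim)
    cos = cos.unsqueeze(unsqueeze_dim)  # broadcast over the head dimension
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed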


class DiffLlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
@@ -481,23 +570,23 @@ def forward(


class DiffLlamaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
def __init__(self, hidden_size, eps=1e-6):
"""
DiffLlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst DiffLlama is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


DIFFLLAMA_ATTENTION_CLASSES = {
@@ -511,7 +600,9 @@ class DiffLlamaDecoderLayer(nn.Module):
def __init__(self, config: DiffLlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size

self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

self.mlp = DiffLlamaMLP(config)
self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -525,6 +616,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -542,6 +634,9 @@ def forward(
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
into the model
@@ -559,6 +654,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -722,6 +818,7 @@ def __init__(self, config: DiffLlamaConfig):
[DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)

self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
@@ -749,6 +846,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -770,9 +868,9 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)

# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False # noqa: F841
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True # noqa: F841
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
@@ -788,22 +886,16 @@ def forward(
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)

# embed positions
hidden_states = inputs_embeds

# normalized
# DiffLlama downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -824,6 +916,7 @@
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
@@ -834,6 +927,8 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

hidden_states = layer_outputs[0]
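
The main behavioral change in modeling_diffllama.py is that the rotary (cos, sin) pair is now
computed once in DiffLlamaModel.forward and handed to every decoder layer as position_embeddings,
instead of being rebuilt inside each attention module. A stripped-down skeleton of that pattern
(illustrative class names, not the actual implementation):

from torch import nn


class TinyDecoderLayer(nn.Module):
    def forward(self, hidden_states, position_embeddings):
        cos, sin = position_embeddings  # the same tuple for every layer; no per-layer recompute
        # ... self-attention with RoPE applied via (cos, sin), followed by the MLP ...
        return hidden_states


class TinyModel(nn.Module):
    def __init__(self, num_layers, rotary_emb):
        super().__init__()
        self.rotary_emb = rotary_emb  # e.g. a DiffLlamaRotaryEmbedding instance
        self.layers = nn.ModuleList(TinyDecoderLayer() for _ in range(num_layers))

    def forward(self, inputs_embeds, position_ids):
        hidden_states = inputs_embeds
        # Compute cos/sin once at the top level ...
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        # ... and share the tuple across all layers.
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_embeddings=position_embeddings)
        return hidden_states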
14 changes: 9 additions & 5 deletions src/transformers/models/diffllama/modular_diffllama.py
@@ -24,20 +24,16 @@

from ...cache_utils import Cache, StaticCache
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
is_flash_attn_greater_or_equal_2_10,
logging,
)
from ..gemma.modeling_gemma import GemmaForCausalLM
from ..llama.modeling_llama import (
LlamaDecoderLayer,
LlamaForQuestionAnswering,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel,
LlamaPreTrainedModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
@@ -52,6 +48,10 @@
_CONFIG_FOR_DOC = "DiffLlamaConfig"


class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
pass


class DiffLlamaMLP(MistralMLP):
pass

@@ -419,6 +419,10 @@ def forward(
return attn_output, None, past_key_value


class DiffLlamaModel(LlamaModel):
pass


class DiffLlamaForCausalLM(GemmaForCausalLM):
pass

@@ -436,7 +440,7 @@ class DiffLlamaForTokenClassification(LlamaForTokenClassification):


__all__ = [
"DiffLlamaPreTrainedModel",
"DiffLlamaPreTrainedModel", # noqa: F822
"DiffLlamaModel",
"DiffLlamaForCausalLM",
"DiffLlamaForSequenceClassification",
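
modular_diffllama.py follows the modular-transformers pattern: it declares thin subclasses of
existing Llama/Gemma/Mistral components, and the expanded modeling_diffllama.py shown above is
generated from them. A toy example of why a bare `pass` body is enough (class names here are
illustrative, not from the library):

import torch
from torch import nn


class UpstreamMLP(nn.Module):
    # Stand-in for an inherited component such as MistralMLP or LlamaRotaryEmbedding.
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(x)


class DiffStyleMLP(UpstreamMLP):
    # A bare `pass` body inherits __init__ and forward unchanged; only the class name
    # differs, which is all the generated modeling file needs to rename the component.
    pass


mlp = DiffStyleMLP(hidden_size=8)
print(mlp(torch.randn(2, 8)).shape)  # torch.Size([2, 8])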
