leaner modular diffllama
weak-kajuma committed Dec 6, 2024
1 parent c5741eb commit ea622ce
Showing 1 changed file with 35 additions and 84 deletions.
119 changes: 35 additions & 84 deletions src/transformers/models/diffllama/modular_diffllama.py
@@ -59,17 +59,7 @@ class DiffLlamaRMSNorm(LlamaRMSNorm):


class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[DiffLlamaConfig] = None,
):
super().__init__(dim, max_position_embeddings, base, device, scaling_factor, rope_type, config)
pass


class DiffLlamaMLP(MistralMLP):
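
For readers unfamiliar with the modular-transformers layout: a subclass whose body is only `pass` inherits its parent's constructor and methods unchanged, so the `__init__` removed above was pure boilerplate. A minimal, generic sketch of the pattern; `BaseRotaryEmbedding` below is an illustrative stand-in, not the real `LlamaRotaryEmbedding`:

```python
import torch
from torch import nn


class BaseRotaryEmbedding(nn.Module):
    """Illustrative stand-in for LlamaRotaryEmbedding (not the real class)."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        # x is accepted only to mirror the usual (hidden_states, position_ids) interface
        freqs = torch.einsum("bs,d->bsd", position_ids.float(), self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()


class LeanRotaryEmbedding(BaseRotaryEmbedding):
    # A body of just `pass` inherits __init__ and forward unchanged,
    # which is the same effect the removed __init__ override had.
    pass


cos, sin = LeanRotaryEmbedding(dim=64)(torch.zeros(1, 4, 64), torch.arange(4)[None])
print(cos.shape)  # torch.Size([1, 4, 64])
```
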
@@ -117,19 +107,16 @@ def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)

# TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
self.rotary_emb = DiffLlamaRotaryEmbedding(config=self.config)

def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, target_len, _ = hidden_states.size()
@@ -143,16 +130,7 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
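
After this hunk, the attention layer no longer falls back to computing rotary embeddings from `position_ids`; it simply unpacks the `(cos, sin)` pair that the model computes once and forwards to every layer. As a reference point, here is a sketch of the standard rotate-half formulation that a helper like `apply_rotary_pos_emb` applies to those externally supplied tensors (the common formulation, not a copy of the library code):

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_sketch(q, k, cos, sin):
    # cos/sin come in as (batch, seq_len, head_dim); unsqueeze so they
    # broadcast over the head dimension of q and k.
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


q = torch.randn(1, 8, 4, 64)   # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 4, 4, 64)   # fewer key/value heads, as with GQA
cos = torch.ones(1, 4, 64)     # in the model these come from the rotary embedding
sin = torch.zeros(1, 4, 64)
q_rot, k_rot = apply_rotary_pos_emb_sketch(q, k, cos, sin)
assert torch.equal(q_rot, q)   # identity rotation when cos=1, sin=0
```
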
@@ -216,13 +194,13 @@ def __init__(self, *args, **kwargs):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
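
The signature change in this hunk moves `position_embeddings` from a trailing optional keyword to the second positional parameter. A toy illustration of why call sites that already pass it by keyword keep working across the reordering; `attn_forward_old` and `attn_forward_new` are hypothetical stand-ins, not transformers APIs:

```python
from typing import Optional, Tuple

import torch


def attn_forward_old(hidden_states: torch.Tensor,
                     attention_mask: Optional[torch.Tensor] = None,
                     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
    # old shape of the signature: trailing argument with a None default
    return hidden_states


def attn_forward_new(hidden_states: torch.Tensor,
                     position_embeddings: Tuple[torch.Tensor, torch.Tensor],
                     attention_mask: Optional[torch.Tensor] = None):
    # new shape of the signature: required, directly after hidden_states
    return hidden_states


h = torch.zeros(1, 4, 8)
cos_sin = (torch.ones(1, 4, 8), torch.zeros(1, 4, 8))

# Call sites that pass the pair by keyword are unaffected by the reordering:
attn_forward_old(h, position_embeddings=cos_sin)
attn_forward_new(h, position_embeddings=cos_sin)
```
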
@@ -359,13 +337,13 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
@@ -395,16 +373,7 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
@@ -468,77 +437,59 @@ def forward(


class DiffLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: DiffLlamaConfig, layer_idx: int):
super().__init__(config, layer_idx)
pass
# def __init__(self, config: DiffLlamaConfig, layer_idx: int):
# super().__init__(config, layer_idx)

self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
# self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

self.mlp = DiffLlamaMLP(config)
self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.mlp = DiffLlamaMLP(config)
# self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
config_class = DiffLlamaConfig
_no_split_modules = ["DiffLlamaDecoderLayer"]
pass
# config_class = DiffLlamaConfig
# _no_split_modules = ["DiffLlamaDecoderLayer"]


class DiffLlamaModel(DiffLlamaPreTrainedModel, LlamaModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DiffLlamaDecoderLayer`]
pass
# """
# Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DiffLlamaDecoderLayer`]

Args:
config: DiffLlamaConfig
"""
# Args:
# config: DiffLlamaConfig
# """

def __init__(self, config: DiffLlamaConfig):
super().__init__(config)
# def __init__(self, config: DiffLlamaConfig):
# super().__init__(config)

self.layers = nn.ModuleList(
[DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)
# self.layers = nn.ModuleList(
# [DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
# )
# self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)

# Initialize weights and apply final processing
self.post_init()
# # Initialize weights and apply final processing
# self.post_init()


class DiffLlamaForCausalLM(GemmaForCausalLM):
def __init__(self, config):
super().__init__(config)

self.model = DiffLlamaModel(config)

# Initialize weights and apply final processing
self.post_init()
pass


class DiffLlamaForSequenceClassification(LlamaForSequenceClassification):
def __init__(self, config):
super().__init__(config)
self.model = DiffLlamaModel(config)

# Initialize weights and apply final processing
self.post_init()
pass


class DiffLlamaForQuestionAnswering(LlamaForQuestionAnswering):
def __init__(self, config):
super().__init__(config)
self.transformer = DiffLlamaModel(config)

# Initialize weights and apply final processing
self.post_init()
pass


class DiffLlamaForTokenClassification(LlamaForTokenClassification):
def __init__(self, config):
super().__init__(config)
self.model = DiffLlamaModel(config)

# Initialize weights and apply final processing
self.post_init()
pass
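
The task-specific heads above keep only `pass` because, as I understand the modular converter, the generated `modeling_diffllama.py` is produced from the parent classes' source with the model prefix renamed, so re-assigning `self.model = DiffLlamaModel(config)` by hand was redundant. A conceptual sketch of that renaming step (the real converter is far more careful than a bare string replacement):

```python
# Conceptual sketch only: shows roughly what regenerating a simple head class
# from its parent with the prefix renamed amounts to.
parent_source = """
class GemmaForCausalLM(GemmaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = GemmaModel(config)
        self.post_init()
"""

generated_source = parent_source.replace("Gemma", "DiffLlama")
print(generated_source)  # DiffLlamaForCausalLM with self.model = DiffLlamaModel(config)
```
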


__all__ = [
