From 4e25753420decbd67f7a18330a05bdb3b610b97b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 20:59:48 +0100 Subject: [PATCH] be more explicit --- .../modular-transformers/modeling_dummy.py | 4 +- .../modeling_multimodal1.py | 4 +- .../modeling_my_new_model2.py | 4 +- .../modular-transformers/modeling_super.py | 4 +- .../integrations/flash_attention.py | 6 +- .../integrations/flex_attention.py | 6 +- .../integrations/sdpa_attention.py | 6 +- src/transformers/models/aria/modeling_aria.py | 4 +- .../models/gemma/modeling_gemma.py | 4 +- .../models/gemma2/modeling_gemma2.py | 24 ++-- .../models/gemma2/modular_gemma2.py | 24 ++-- src/transformers/models/glm/modeling_glm.py | 4 +- .../models/granite/modeling_granite.py | 4 +- .../models/llama/modeling_llama.py | 4 +- .../models/mistral/modeling_mistral.py | 4 +- .../models/mistral/modular_mistral.py | 2 + .../models/mixtral/modeling_mixtral.py | 4 +- src/transformers/models/olmo/modeling_olmo.py | 9 +- src/transformers/models/olmo/modular_olmo.py | 10 +- .../models/olmo2/modeling_olmo2.py | 111 +++++++++--------- .../models/olmo2/modular_olmo2.py | 27 ++--- src/transformers/models/phi/modeling_phi.py | 4 +- src/transformers/models/phi/modular_phi.py | 2 + .../models/qwen2/modeling_qwen2.py | 4 +- .../models/qwen2/modular_qwen2.py | 2 + .../models/starcoder2/modeling_starcoder2.py | 4 +- .../models/starcoder2/modular_starcoder2.py | 2 + 27 files changed, 166 insertions(+), 121 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 0e73f298c48665..c672788aee54ca 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -176,7 +176,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -230,6 +230,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -264,6 +265,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 9ee097e950f019..39b221b0a60355 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -176,7 +176,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -230,6 +230,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -264,6 +265,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git 
a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 4e88c5aa1d1306..0324e192d97fd8 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -176,7 +176,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -230,6 +230,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -264,6 +265,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 530c25721d2ae9..a43dd0d4948e98 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -176,7 +176,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -230,6 +230,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -264,6 +265,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index ed2d06a98cf517..b33286627da92d 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import torch @@ -14,14 +14,14 @@ def flash_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, target_dtype: torch.dtype = torch.float16, **kwargs, -): +) -> Tuple[torch.Tensor, None]: if attention_mask is not None: seq_len = attention_mask.shape[1] query = query[:, :, :seq_len] diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index dd4287921d2c37..eacfb2b568b55b 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import torch @@ -14,11 +14,11 @@ def flex_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, softcap: Optional[float] = None, **kwargs, -): 
+) -> Tuple[torch.Tensor, torch.Tensor]: causal_mask = attention_mask if causal_mask is not None: causal_mask = causal_mask[:, :, :, : key.shape[-2]] diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 3a90ef9af2824e..805b379d2e18e4 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import torch @@ -20,11 +20,11 @@ def sdpa_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, -): +) -> Tuple[torch.Tensor, None]: key = repeat_kv(key, module.num_key_value_groups) value = repeat_kv(value, module.num_key_value_groups) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 232d169595fe8b..743367d7e5202a 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -481,7 +481,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -535,6 +535,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -569,6 +570,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 5e1bc5bcfac12f..4d6c5698a9c3b6 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -207,7 +207,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -261,6 +261,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -295,6 +296,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index ff02bc8c06b4de..62d82ee92cc820 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -140,25 +140,31 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - mask: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + 
scaling = module.head_dim**-0.5 + key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.scaling + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if module.attn_logit_softcapping is not None: - attn_weights = attn_weights / module.attn_logit_softcapping + if softcap is not None: + attn_weights = attn_weights / softcap attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * module.attn_logit_softcapping - if mask is not None: # no matter the length, we just slice it - causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights * softcap + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights @@ -197,6 +203,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -231,6 +238,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=self.attention_dropout if self.training else 0.0, scaling=self.scaling, sliding_window=self.sliding_window, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index c184d171337272..0cb9811e0e577d 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -205,25 +205,31 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - mask: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.scaling + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if module.attn_logit_softcapping is not None: - attn_weights = attn_weights / module.attn_logit_softcapping + if softcap is not None: + attn_weights = attn_weights / softcap attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * module.attn_logit_softcapping - if mask is not None: # no matter the length, we just slice it - causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights * softcap + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, 
dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights @@ -242,6 +248,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -276,6 +283,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=self.attention_dropout if self.training else 0.0, scaling=self.scaling, sliding_window=self.sliding_window, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index b1c85f5177ada2..16d9cd2e464c71 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -91,7 +91,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -192,6 +192,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -226,6 +227,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 5bcde06c46f208..b9a57e2f816057 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -134,7 +134,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -188,6 +188,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -222,6 +223,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 17691d72452c77..24d22481396620 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -211,7 +211,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -265,6 +265,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], 
+ attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -299,6 +300,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 05c7bac2618624..6fb7ee21572f4a 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -108,7 +108,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -153,6 +153,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -187,6 +188,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=getattr(self.config, "sliding_window", None), diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index a200d48085052c..855c87c363ee4c 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -49,6 +49,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -83,6 +84,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=getattr(self.config, "sliding_window", None), diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 27226dc7417a9c..e05f9177124d65 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -221,7 +221,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -266,6 +266,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -300,6 +301,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=getattr(self.config, "sliding_window", None), diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index c518bc3a0fa623..02da8de50c1e0e 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ 
b/src/transformers/models/olmo/modeling_olmo.py @@ -134,7 +134,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -187,7 +187,8 @@ def __init__(self, config: OlmoConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -208,7 +209,7 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -231,6 +232,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, @@ -247,6 +249,7 @@ def __init__(self, config: OlmoConfig, layer_idx: int): self.hidden_size = config.hidden_size self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) + self.mlp = OlmoMLP(config) self.input_layernorm = OlmoLayerNorm(config.hidden_size) self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 6cd0f0d21ee89e..0ca1135ec31c79 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -53,7 +53,8 @@ class OlmoAttention(LlamaAttention): def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -74,7 +75,7 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -97,6 +98,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, @@ -110,10 +112,6 @@ def forward( class OlmoDecoderLayer(LlamaDecoderLayer): def __init__(self, config: OlmoConfig, layer_idx: int): super().__init__(config, layer_idx) - self.hidden_size = config.hidden_size - - self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) - self.mlp = OlmoMLP(config) self.input_layernorm = OlmoLayerNorm(config.hidden_size) self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index d4a38f20ad1cc1..ccca0499c4c1c1 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -103,7 +103,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: 
Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -158,11 +158,9 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -177,7 +175,7 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -200,6 +198,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, @@ -232,6 +231,7 @@ def __init__(self, config: Olmo2Config, layer_idx: int): self.hidden_size = config.hidden_size self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx) + self.mlp = Olmo2MLP(config) self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -245,12 +245,13 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -258,6 +259,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -272,54 +274,8 @@ def forward( outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs - -OLMO2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Olmo2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare Olmo2 Model outputting raw hidden-states without any specific head on top.", - OLMO2_START_DOCSTRING, -) -class Olmo2PreTrainedModel(PreTrainedModel): - config_class = Olmo2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Olmo2DecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + return outputs class Olmo2RotaryEmbedding(nn.Module): @@ -387,6 +343,51 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +OLMO2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Olmo2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare Olmo2 Model outputting raw hidden-states without any specific head on top.", + OLMO2_START_DOCSTRING, +) +class Olmo2PreTrainedModel(PreTrainedModel): + config_class = Olmo2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Olmo2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + OLMO2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -732,7 +733,7 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - def __init__(self, config: Olmo2Config): + def __init__(self, config): super().__init__(config) self.model = Olmo2Model(config) self.vocab_size = config.vocab_size diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 74873d57b40982..fab2db95f10fb7 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -173,11 +173,9 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -192,7 +190,7 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -215,6 +213,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, @@ -244,12 +243,13 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -257,6 +257,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + 
position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -271,13 +272,8 @@ def forward( outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs - -class Olmo2PreTrainedModel(OlmoPreTrainedModel): - pass + return outputs # The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of @@ -285,17 +281,12 @@ class Olmo2PreTrainedModel(OlmoPreTrainedModel): class Olmo2Model(OlmoModel): def __init__(self, config: Olmo2Config): super().__init__(config) - self.layers = nn.ModuleList( - [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # The heads now only need to redefine the model inside to the correct `RobertaModel` class Olmo2ForCausalLM(OlmoForCausalLM): - def __init__(self, config: Olmo2Config): - super().__init__(config) - self.model = Olmo2Model(config) + pass __all__ = [ diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 0e409c0c953f32..e9b946f24780f8 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -93,7 +93,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -157,6 +157,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -209,6 +210,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index b3f273298643d4..46fd51e9afcb0d 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -40,6 +40,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -92,6 +93,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8dd0659942b5bc..269075e1cfb57b 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -108,7 +108,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -153,6 +153,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ 
-195,6 +196,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=sliding_window, diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 47a785a27b76f2..46c8b0f8c658f7 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -47,6 +47,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -89,6 +90,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=sliding_window, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 3c76732e0ee6ba..fcfbfd26bb26ca 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -127,7 +127,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -173,6 +173,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -207,6 +208,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index cb1122f4dbc44f..f5124e07e5dbae 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -91,6 +91,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -125,6 +126,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs,
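

For reference, the hunks above all converge on one calling convention for the attention backends (eager, SDPA, flash, flex): `attention_mask` becomes an explicit, defaultless positional argument, attention `forward` receives `position_embeddings` (cos, sin) instead of recomputing them from `position_ids`, and knobs such as `dropout`, `scaling`, and `softcap` are passed in as arguments rather than read off module attributes. The snippet below is a minimal, self-contained sketch of that convention in plain PyTorch; the `DummyModule` class, the tensor shapes, and this standalone `repeat_kv` are illustrative stand-ins, not the library code touched by this patch.

```python
# Minimal sketch (assumption: DummyModule and the shapes below are illustrative
# stand-ins, not the actual Transformers classes). It mirrors the explicit
# signature the patch standardizes: attention_mask has no default, and
# dropout/scaling/softcap are arguments instead of module attributes.
from typing import Optional, Tuple

import torch
from torch import nn


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand key/value heads to match the number of query heads (GQA).
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],  # explicit, no default
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if scaling is None:
        scaling = module.head_dim**-0.5

    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if softcap is not None:
        # Logit soft-capping: scale down, squash with tanh, scale back up.
        attn_weights = softcap * torch.tanh(attn_weights / softcap)
    if attention_mask is not None:  # no matter the length, just slice the mask
        attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]

    # Upcast the softmax to fp32, then apply dropout only in training mode.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights


class DummyModule(nn.Module):
    # Hypothetical stand-in carrying only the attributes the interface expects.
    def __init__(self, head_dim: int, num_key_value_groups: int):
        super().__init__()
        self.head_dim = head_dim
        self.num_key_value_groups = num_key_value_groups


if __name__ == "__main__":
    batch, q_heads, kv_heads, seq, head_dim = 2, 8, 4, 16, 64
    module = DummyModule(head_dim=head_dim, num_key_value_groups=q_heads // kv_heads)
    q = torch.randn(batch, q_heads, seq, head_dim)
    k = torch.randn(batch, kv_heads, seq, head_dim)
    v = torch.randn(batch, kv_heads, seq, head_dim)
    causal = torch.full((seq, seq), float("-inf")).triu(1)[None, None]  # (1, 1, seq, seq)

    out, weights = eager_attention_forward(module, q, k, v, causal, dropout=0.0, softcap=50.0)
    print(out.shape, weights.shape)  # torch.Size([2, 16, 8, 64]) torch.Size([2, 8, 16, 16])
```

Dropping the `= None` default means a caller can no longer forget to thread the mask through silently; the omission now surfaces as a `TypeError` at the call site, which is the "be more explicit" intent of the commit.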