From 43f3302e0c99792aa0057494ed3dd26dd402d27a Mon Sep 17 00:00:00 2001 From: Matthias Seeger Date: Fri, 20 Dec 2024 15:16:28 +0100 Subject: [PATCH] Several fixes related to rotary position embeddings First part of resolution of #35233 - Changes related to `position_embeddings` being a mandatory argument - Remove `position_ids` argument of `apply_rotary_pos_emb` - Replace `torch.stack` by `torch.cat`, former requires equal shapes - `esm`: RoPE depends on `position_ids`, which was ignored. - `gpt_neox`: Selection of attention compute type via class removed - `gptj`: RoPE must be applied per head, and some shape issues. - `nemotron`: `config.partial_rotary_factor` was not implemented. --- src/transformers/models/aria/modeling_aria.py | 14 ++-- .../models/bamba/modeling_bamba.py | 2 + .../models/chameleon/modeling_chameleon.py | 4 +- .../models/codegen/modeling_codegen.py | 3 +- .../models/cohere/modeling_cohere.py | 25 ++++--- .../models/cohere2/modeling_cohere2.py | 3 +- src/transformers/models/dbrx/modeling_dbrx.py | 4 +- src/transformers/models/esm/modeling_esm.py | 69 ++++++++++++++----- .../models/falcon/modeling_falcon.py | 20 +++--- .../models/gemma/modeling_gemma.py | 10 +-- .../models/gemma2/modeling_gemma2.py | 4 +- src/transformers/models/glm/modeling_glm.py | 10 +-- .../models/gpt_neox/configuration_gpt_neox.py | 4 +- .../models/gpt_neox/modeling_gpt_neox.py | 36 ++++------ .../modeling_gpt_neox_japanese.py | 16 ++--- .../models/gptj/configuration_gptj.py | 4 +- src/transformers/models/gptj/modeling_gptj.py | 31 +++++++-- .../models/granite/modeling_granite.py | 14 ++-- .../models/granite/modular_granite.py | 8 +-- .../models/granitemoe/modeling_granitemoe.py | 27 +++++--- .../models/jamba/modeling_jamba.py | 1 + .../models/jetmoe/modeling_jetmoe.py | 5 +- .../models/llama/modeling_flax_llama.py | 3 +- .../models/llama/modeling_llama.py | 14 ++-- src/transformers/models/mimi/modeling_mimi.py | 4 +- .../models/mistral/modeling_flax_mistral.py | 3 +- .../models/mistral/modeling_mistral.py | 14 ++-- .../models/mistral/modular_mistral.py | 2 + .../models/mixtral/modeling_mixtral.py | 13 ++-- .../models/mixtral/modular_mixtral.py | 7 +- .../models/mllama/modeling_mllama.py | 4 +- .../models/moshi/modeling_moshi.py | 4 +- .../models/nemotron/configuration_nemotron.py | 5 +- .../models/nemotron/modeling_nemotron.py | 55 +++++++++++---- src/transformers/models/olmo/modeling_olmo.py | 12 ++-- .../models/olmo2/modeling_olmo2.py | 12 ++-- .../models/olmo2/modular_olmo2.py | 4 +- .../models/olmoe/modeling_olmoe.py | 27 +++++--- .../models/persimmon/modeling_persimmon.py | 16 ++--- src/transformers/models/phi/modeling_phi.py | 14 ++-- src/transformers/models/phi/modular_phi.py | 10 +-- .../models/phi3/configuration_phi3.py | 7 +- src/transformers/models/phi3/modeling_phi3.py | 8 +-- .../models/phimoe/modeling_phimoe.py | 32 ++++++--- .../models/pixtral/modeling_pixtral.py | 18 ++--- .../models/qwen2/modeling_qwen2.py | 14 ++-- .../models/qwen2/modular_qwen2.py | 2 + .../models/qwen2_moe/modeling_qwen2_moe.py | 5 +- .../models/qwen2_vl/configuration_qwen2_vl.py | 2 + .../models/qwen2_vl/modeling_qwen2_vl.py | 39 +++++------ .../configuration_recurrent_gemma.py | 6 +- .../modeling_recurrent_gemma.py | 19 +++-- .../models/stablelm/modeling_stablelm.py | 26 +++---- .../models/starcoder2/modeling_starcoder2.py | 10 ++- .../models/starcoder2/modular_starcoder2.py | 2 +- 55 files changed, 413 insertions(+), 314 deletions(-) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 6481d6f3c434c7..8849bbfd320bc0 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -437,7 +437,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -445,8 +445,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -537,6 +535,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -603,13 +603,13 @@ def __init__(self, config: AriaTextConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -619,13 +619,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -963,24 +963,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index c89d8d7853008d..1d68e7d3b589e9 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -305,6 +305,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 11bc411a00c005..4ef7cddae0521f 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -153,7 +153,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -161,8 +161,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 616c93a46e4f4a..a7ffd7efd5151e 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -48,8 +48,7 @@ def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor: def rotate_every_two(x: torch.Tensor) -> torch.Tensor: x1 = x[:, :, :, ::2] x2 = x[:, :, :, 1::2] - x = torch.stack((-x2, x1), dim=-1) - return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') + return torch.concat((-x2, x1), dim=-1) # Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 7b8b9547ac1c33..e0af6c143b4290 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -176,8 +176,7 @@ def rotate_half(x): # Split and rotate. Note that this function is different from e.g. Llama. x1 = x[..., ::2] x2 = x[..., 1::2] - rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) - return rot_x + return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): @@ -286,15 +285,17 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -372,15 +373,17 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") if isinstance(past_key_value, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " @@ -476,14 +479,16 @@ class CohereSdpaAttention(CohereAttention): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( @@ -492,13 +497,13 @@ def forward( ) return super().forward( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -581,13 +586,13 @@ def __init__(self, config: CohereConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -615,13 +620,13 @@ def forward( # Self Attention hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) # Fully Connected @@ -861,24 +866,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 1ffa4bffddc3df..4057744675ed1f 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -161,8 +161,7 @@ def rotate_half(x): # Split and rotate. Note that this function is different from e.g. Llama. x1 = x[..., ::2] x2 = x[..., 1::2] - rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) - return rot_x + return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 0d2c4297e0d473..6e8a6ac946f331 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -84,7 +84,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -92,8 +92,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 5df5435bb1229a..e7268ee1f00c83 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -88,31 +88,60 @@ def __init__(self, dim: int): super().__init__() # Generate and save the inverse frequency buffer (non trainable) inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) - inv_freq = inv_freq self.register_buffer("inv_freq", inv_freq) self._seq_len_cached = None self._cos_cached = None self._sin_cached = None + self._positions_ids_cached = None - def _update_cos_sin_tables(self, x, seq_dimension=2): - seq_len = x.shape[seq_dimension] - - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if seq_len != self._seq_len_cached or self._cos_cached.device != x.device: - self._seq_len_cached = seq_len - t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq) - freqs = torch.outer(t, self.inv_freq) + def _update_cos_sin_tables( + self, + x: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + seq_dimension: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Reset the tables if the sequence length has changed, position_ids + # has changed, or if we're on a new device (possibly due to tracing for + # instance) + device_changed = (self._cos_cached is not None) and (self._cos_cached.device != x.device) + t = None + if position_ids is not None: + if ( + device_changed + or self._positions_ids_cached is None + or not torch.equal(position_ids, self._positions_ids_cached) + ): + # RoPE embeddings depends on position_ids + # Caching makes sense: position_ids is the same for every layer + if position_ids.dim() == 1: + position_ids = position_ids.unsqueeze(0) + self._positions_ids_cached = torch.clone(position_ids) + t = position_ids.unsqueeze(-1).type_as(self.inv_freq).to(x.device) + else: + seq_len = x.shape[seq_dimension] + if device_changed or seq_len != self._seq_len_cached: + self._seq_len_cached = seq_len + t = torch.arange(seq_len, device=x.device)[None, :, None].type_as(self.inv_freq) + if t is not None: + inv_freq = self.inv_freq[None, None, :].expand(*t.shape[:2], -1) + t = t.expand(-1, -1, inv_freq.shape[-1]) + freqs = t * inv_freq emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - - self._cos_cached = emb.cos()[None, None, :, :] - self._sin_cached = emb.sin()[None, None, :, :] + self._cos_cached = emb.cos().unsqueeze(1) + self._sin_cached = emb.sin().unsqueeze(1) return self._cos_cached, self._sin_cached - def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + self._cos_cached, self._sin_cached = self._update_cos_sin_tables( + x=k, position_ids=position_ids, seq_dimension=-2 + ) return ( apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), @@ -284,6 +313,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, @@ -334,7 +364,7 @@ def forward( past_key_value = (key_layer, value_layer) if self.position_embedding_type == "rotary": - query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) + query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer, position_ids) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -426,6 +456,7 @@ def forward( self, hidden_states, attention_mask=None, + position_ids: Optional[torch.Tensor] = None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, @@ -436,6 +467,7 @@ def forward( self_outputs = self.self( hidden_states_ln, attention_mask, + position_ids, head_mask, encoder_hidden_states, encoder_attention_mask, @@ -491,6 +523,7 @@ def forward( self, hidden_states, attention_mask=None, + position_ids: Optional[torch.Tensor] = None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, @@ -502,6 +535,7 @@ def forward( self_attention_outputs = self.attention( hidden_states, attention_mask, + position_ids, head_mask, output_attentions=output_attentions, past_key_value=self_attn_past_key_value, @@ -569,6 +603,7 @@ def forward( self, hidden_states, attention_mask=None, + position_ids: Optional[torch.Tensor] = None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, @@ -602,6 +637,7 @@ def forward( layer_module.__call__, hidden_states, attention_mask, + position_ids, layer_head_mask, encoder_hidden_states, encoder_attention_mask, @@ -612,6 +648,7 @@ def forward( layer_outputs = layer_module( hidden_states, attention_mask, + position_ids, layer_head_mask, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index e0e4ff424cb47d..6167f80fe4e1a3 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -81,7 +81,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -89,8 +89,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -330,6 +328,7 @@ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, @@ -338,8 +337,9 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads # 3 x [batch_size, seq_length, num_heads, head_dim] @@ -480,6 +480,7 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, @@ -488,8 +489,9 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads # 3 x [batch_size, seq_length, num_heads, head_dim] @@ -618,6 +620,7 @@ def __init__(self, config: FalconConfig, layer_idx=None): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, @@ -626,7 +629,6 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ): residual = hidden_states @@ -640,6 +642,7 @@ def forward( # Self attention. attn_outputs = self.self_attention( attention_layernorm_out, + position_embeddings=position_embeddings, layer_past=layer_past, attention_mask=attention_mask, position_ids=position_ids, @@ -648,7 +651,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, - position_embeddings=position_embeddings, ) attention_output = attn_outputs[0] @@ -961,6 +963,7 @@ def forward( outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, + position_embeddings, alibi, causal_mask, position_ids, @@ -969,11 +972,11 @@ def forward( use_cache, output_attentions, cache_position, - position_embeddings, ) else: outputs = block( hidden_states, + position_embeddings=position_embeddings, layer_past=past_key_values, attention_mask=causal_mask, position_ids=position_ids, @@ -982,7 +985,6 @@ def forward( output_attentions=output_attentions, alibi=alibi, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = outputs[0] diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index e2ea12b03fe434..8ddd4dbb0b47da 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -163,7 +163,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -171,8 +171,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -263,6 +261,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -318,13 +318,13 @@ def __init__(self, config: GemmaConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -334,13 +334,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 67fc6c86a3bac6..8c50c5a2ac5c3a 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -97,7 +97,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -105,8 +105,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 95ad0d9719951d..23b05a877d027d 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -194,6 +194,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -334,13 +336,13 @@ def __init__(self, config: GlmConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -350,13 +352,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -595,24 +597,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index 07514a37c6f2fa..c632085ff092e6 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -49,7 +49,9 @@ class GPTNeoXConfig(PretrainedConfig): The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported. rotary_pct (`float`, *optional*, defaults to 0.25): - percentage of hidden dimensions to allocate to rotary embeddings + Percentage of hidden dimensions to allocate to rotary embeddings. + Note: In most other models, this parameter is called + `partial_rotary_factor`. rotary_emb_base (`int`, *optional*, defaults to 10000) base for computing rotary embeddings frequency attention_dropout (`float`, *optional*, defaults to 0.0): diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 7152d72f5b7fc8..3ab8191d12a78b 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -269,12 +269,10 @@ def __init__(self, config, layer_idx=None): "The hidden size is not divisble by the number of attention heads! Make sure to update them" ) self.head_size = self.hidden_size // self.num_attention_heads - self.rotary_ndims = int(self.head_size * config.rotary_pct) - self.rope_theta = config.rotary_emb_base + self.rotary_ndims = int(self.head_size * config.partial_rotary_factor) self._init_bias(config.max_position_embeddings) self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False) - self.rotary_emb = GPTNeoXRotaryEmbedding(config=self.config) if layer_idx is None: logger.warning_once( @@ -303,6 +301,7 @@ def _init_bias(self, max_positions, device=None): def forward( self, hidden_states: torch.FloatTensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: torch.FloatTensor, position_ids: torch.LongTensor, head_mask: Optional[torch.FloatTensor] = None, @@ -311,18 +310,19 @@ def forward( output_attentions: Optional[bool] = False, padding_mask: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, seq_len, _ = hidden_states.shape # Apply attention-specific projections and rope query, key, value, present = self._attn_projections_and_rope( hidden_states=hidden_states, + position_embeddings=position_embeddings, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) # Checking for fallbacks in case an unsupported feature is requested @@ -400,12 +400,14 @@ def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): def _attn_projections_and_rope( self, hidden_states: torch.FloatTensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], position_ids: torch.LongTensor, layer_past: Optional[Tuple[torch.Tensor]] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") # Compute QKV # Attention heads [batch, seq_len, hidden_size] # --> [batch, seq_len, (np * 3 * head_size)] @@ -560,7 +562,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -568,8 +570,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -601,14 +601,6 @@ def forward(self, hidden_states): return hidden_states -GPT_NEOX_ATTENTION_CLASSES = { - "eager": GPTNeoXAttention, - "flash_attention_2": GPTNeoXFlashAttention2, - "sdpa": GPTNeoXSdpaAttention, - "flex_attention": GPTNeoXAttention, -} - - class GPTNeoXLayer(nn.Module): def __init__(self, config, layer_idx): super().__init__() @@ -617,12 +609,13 @@ def __init__(self, config, layer_idx): self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_dropout = nn.Dropout(config.hidden_dropout) self.post_mlp_dropout = nn.Dropout(config.hidden_dropout) - self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.attention = GPTNeoXAttention(config, layer_idx) self.mlp = GPTNeoXMLP(config) def forward( self, hidden_states: Optional[torch.FloatTensor], + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -630,10 +623,10 @@ def forward( layer_past: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): attention_layer_outputs = self.attention( self.input_layernorm(hidden_states), + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, layer_past=layer_past, @@ -641,7 +634,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, - position_embeddings=position_embeddings, ) attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights) attn_output = self.post_attention_dropout(attn_output) @@ -875,6 +867,7 @@ def forward( outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, head_mask[i], @@ -882,11 +875,11 @@ def forward( None, output_attentions, cache_position, - position_embeddings, ) else: outputs = layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, head_mask=head_mask[i], @@ -894,7 +887,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 71602f01e7d6f8..a37c8372628355 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -83,9 +83,7 @@ def __init__(self, config, use_bias=False, layer_idx=None): ) self.layer_idx = layer_idx - self.rotary_ndims = int(self.head_size * config.rotary_pct) - self.rope_theta = config.rotary_emb_base - self.rotary_emb = GPTNeoXJapaneseRotaryEmbedding(config=config) + self.rotary_ndims = int(self.head_size * config.partial_rotary_factor) self.attention_dropout = nn.Dropout(config.attention_dropout) self.norm_factor = math.sqrt(self.head_size) @@ -98,6 +96,7 @@ def __init__(self, config, use_bias=False, layer_idx=None): def forward( self, hidden_states: torch.FloatTensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: torch.FloatTensor, position_ids: torch.LongTensor, head_mask: Optional[torch.FloatTensor] = None, @@ -105,7 +104,6 @@ def forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): # Compute QKV # Attention heads [batch, seq_len, hidden_size] @@ -297,7 +295,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -305,8 +303,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -377,6 +373,7 @@ def __init__(self, config, layer_number): def forward( self, hidden_states: Optional[torch.FloatTensor], + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -384,12 +381,12 @@ def forward( layer_past: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): residual = hidden_states ln_out = self.input_layernorm(hidden_states) attention_layer_outputs, attn_bias = self.attention( ln_out, + position_embeddings=position_embeddings, attention_mask=attention_mask, layer_past=layer_past, head_mask=head_mask, @@ -397,7 +394,6 @@ def forward( output_attentions=output_attentions, position_ids=position_ids, cache_position=cache_position, - position_embeddings=position_embeddings, ) attn_output = attention_layer_outputs[0] # output_attn: a, present, (attentions) outputs = attention_layer_outputs[1:] @@ -623,6 +619,7 @@ def forward( outputs = layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, head_mask=head_mask[i], @@ -630,7 +627,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/gptj/configuration_gptj.py b/src/transformers/models/gptj/configuration_gptj.py index 1b93f259b05b12..1c8f31e0394ec1 100644 --- a/src/transformers/models/gptj/configuration_gptj.py +++ b/src/transformers/models/gptj/configuration_gptj.py @@ -49,7 +49,9 @@ class GPTJConfig(PretrainedConfig): n_head (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer encoder. rotary_dim (`int`, *optional*, defaults to 64): - Number of dimensions in the embedding that Rotary Position Embedding is applied to. + Number of dimensions in the embedding of each head that Rotary Position + Embedding is applied to. If `rotary_dim=None`, RoPE is applied to the + full size `n_embd // n_head` (this is not the default). n_inner (`int`, *optional*, defaults to None): Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd activation_function (`str`, *optional*, defaults to `"gelu_new"`): diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 4af8f73b5f5eea..448b4347686463 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -66,14 +66,13 @@ def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor: @torch.fx.wrap def get_embed_positions(embed_positions, position_ids): - return embed_positions.to(position_ids.device).repeat(position_ids.shape[0], 1, 1) + return embed_positions.to(position_ids.device).unsqueeze(0).repeat(position_ids.shape[0], 1, 1) def rotate_every_two(x: torch.Tensor) -> torch.Tensor: x1 = x[:, :, :, ::2] x2 = x[:, :, :, 1::2] - x = torch.stack((-x2, x1), dim=-1) - return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') + return torch.concat((-x2, x1), dim=-1) def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor: @@ -115,7 +114,8 @@ def __init__(self, config, layer_idx=None): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.rotary_dim = config.rotary_dim - pos_embd_dim = self.rotary_dim or self.embed_dim + pos_embd_dim = self.rotary_dim or self.head_dim + # `embed_positions` of shape `(max_positions, 2 * pos_embd_dim)` self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim) def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): @@ -178,11 +178,22 @@ def _attn( return attn_output, attn_weights def _get_embed_positions(self, position_ids): + """ + This method does not subselect according to `position_ids`, it only + deals with device and shape. + + Args: + position_ids: Position indices + + Returns: + `embed_positions`, with device and shape according to `position_ids` + + """ embed_positions = self.embed_positions if embed_positions.device != position_ids.device: embed_positions = embed_positions.to(position_ids.device) self.embed_positions = embed_positions - return embed_positions.repeat(position_ids.shape[0], 1, 1) + return embed_positions.unsqueeze(0).repeat(position_ids.shape[0], 1, 1) def forward( self, @@ -198,6 +209,8 @@ def forward( Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], ]: + if position_ids is None: + raise ValueError("position_ids must be given") query = self.q_proj(hidden_states) key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) @@ -205,6 +218,9 @@ def forward( query = self._split_heads(query, self.num_attention_heads, self.head_dim, True) key = self._split_heads(key, self.num_attention_heads, self.head_dim, True) value = self._split_heads(value, self.num_attention_heads, self.head_dim, False) + # Attention, different shapes: + # query, key: (B, T, n_head, rotary_dim) + # value: (B, n_head, T, rotary_dim) if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): # The logic to conditionally copy to GPU could not be traced, so we do this @@ -235,6 +251,7 @@ def forward( key = key.permute(0, 2, 1, 3) query = query.permute(0, 2, 1, 3) + # At this point, all have shape (B, n_head, T, rotary_dim) if layer_past is not None: cache_kwargs = { @@ -288,6 +305,8 @@ def forward( Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], ]: + if position_ids is None: + raise ValueError("position_ids must be given") query = self.q_proj(hidden_states) key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) @@ -323,7 +342,7 @@ def forward( key = apply_rotary_pos_emb(key, sin, cos) query = apply_rotary_pos_emb(query, sin, cos) - # tanspose to have the desired shape + # transpose to have the desired shape # before transpose: batch_size x seq_length x num_attention_heads x head_dim # after transpose: batch_size x num_attention_heads x seq_length x head_dim key = key.permute(0, 2, 1, 3) diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 2e045e149d95de..3eb7a27ba19207 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -54,7 +54,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -62,8 +62,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -154,6 +152,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -245,13 +245,13 @@ def __init__(self, config: GraniteConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -283,13 +283,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states * self.residual_multiplier @@ -597,24 +597,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 698280085f1852..36ba842689abc1 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -48,13 +48,13 @@ def __init__(self, config: GraniteConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -86,13 +86,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states * self.residual_multiplier @@ -187,24 +187,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 1c4c06bbc8d71e..34b94986544381 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -82,6 +82,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) @@ -231,7 +232,7 @@ def rotate_half(x): # Copied from transformers.models.granite.modeling_granite.apply_rotary_pos_emb with Granite->GraniteMoe -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -239,8 +240,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -457,15 +456,17 @@ def __init__(self, config: GraniteMoeConfig, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -535,14 +536,16 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -636,15 +639,17 @@ class GraniteMoeSdpaAttention(GraniteMoeAttention): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( @@ -653,13 +658,13 @@ def forward( ) return super().forward( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -739,6 +744,7 @@ def __init__(self, config: GraniteMoeConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, @@ -746,7 +752,6 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -781,13 +786,13 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) @@ -1058,6 +1063,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -1065,11 +1071,11 @@ def forward( use_cache, cache_position, output_router_logits, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -1077,7 +1083,6 @@ def forward( use_cache=use_cache, cache_position=cache_position, output_router_logits=output_router_logits, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index ae7470d789b27e..0a7fc071b2e7c5 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -112,6 +112,7 @@ def load_balancing_loss_func( if router_logits is None or not isinstance(router_logits, tuple): return 0 + compute_device = None if isinstance(router_logits, tuple): compute_device = router_logits[0].device concatenated_router_logits = torch.cat( diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 7b7fd5a90d69ed..4507253468bf3f 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -87,6 +87,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) @@ -459,7 +460,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -467,8 +468,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note diff --git a/src/transformers/models/llama/modeling_flax_llama.py b/src/transformers/models/llama/modeling_flax_llama.py index 26a2c2bb09a3d2..dee04225d0bee3 100644 --- a/src/transformers/models/llama/modeling_flax_llama.py +++ b/src/transformers/models/llama/modeling_flax_llama.py @@ -133,7 +133,8 @@ def create_sinusoidal_positions(num_pos, dim): freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") emb = np.concatenate((freqs, freqs), axis=-1) - out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1) + emb = emb[:, None, :] + out = np.concatenate((np.sin(emb), np.cos(emb)), axis=-1) return jnp.array(out[:, :, :num_pos]) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5be33c26414cd7..c980381cd8c95c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -151,7 +151,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -159,8 +159,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -267,6 +265,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -322,13 +322,13 @@ def __init__(self, config: LlamaConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -338,13 +338,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -583,24 +583,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 1440ce1e075c95..8ffc341804e6e8 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -438,7 +438,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -446,8 +446,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note diff --git a/src/transformers/models/mistral/modeling_flax_mistral.py b/src/transformers/models/mistral/modeling_flax_mistral.py index 3bff2a6281220e..b65d3803a29947 100644 --- a/src/transformers/models/mistral/modeling_flax_mistral.py +++ b/src/transformers/models/mistral/modeling_flax_mistral.py @@ -203,7 +203,8 @@ def create_sinusoidal_positions(num_pos, dim): freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") emb = np.concatenate((freqs, freqs), axis=-1) - out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1) + emb = emb[:, None, :] + out = np.concatenate((np.sin(emb), np.cos(emb)), axis=-1) return jnp.array(out[:, :, :num_pos]) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 90c38895b4280b..e9fc0f5698e424 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -64,7 +64,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -72,8 +72,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -155,6 +153,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -229,13 +229,13 @@ def __init__(self, config: MistralConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -245,13 +245,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -555,24 +555,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 362233a21b70f4..c85ad9e730dc04 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -56,6 +56,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 84ed327d9be920..3183a3c3f2e857 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -177,7 +177,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -185,8 +185,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -268,6 +266,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -324,6 +324,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, @@ -331,7 +332,6 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -683,6 +683,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -690,11 +691,11 @@ def forward( output_router_logits, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -702,7 +703,6 @@ def forward( output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) @@ -915,6 +915,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index a6069f69b33421..52cf2876852057 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -88,6 +88,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) @@ -246,6 +247,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, @@ -253,7 +255,6 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -392,6 +393,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -399,11 +401,11 @@ def forward( output_router_logits, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -411,7 +413,6 @@ def forward( output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 3e0c4d7a5123a7..37cd66d73dd5e9 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -617,7 +617,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -625,8 +625,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index f0281f57cf1c75..424a7b94d100fe 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -381,7 +381,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -389,8 +389,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note diff --git a/src/transformers/models/nemotron/configuration_nemotron.py b/src/transformers/models/nemotron/configuration_nemotron.py index 7690703127ac92..0720ede2cdb74d 100644 --- a/src/transformers/models/nemotron/configuration_nemotron.py +++ b/src/transformers/models/nemotron/configuration_nemotron.py @@ -76,7 +76,8 @@ class NemotronConfig(PretrainedConfig): Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. - partial_rotary_factor (`float`, *optional*, defaults to 0.5): Percentage of the query and keys which will have rotary embedding. + partial_rotary_factor (`float`, *optional*, defaults to 1.0): + Percentage of the query and keys which will have rotary embedding. attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): @@ -119,7 +120,7 @@ def __init__( eos_token_id=3, tie_word_embeddings=False, rope_theta=10000.0, - partial_rotary_factor=0.5, + partial_rotary_factor=1.0, attention_bias=False, attention_dropout=0.0, mlp_bias=False, diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index a0a10bdc6f3550..f50697c2093388 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -237,7 +237,8 @@ def __init__(self, config: NemotronConfig, layer_idx: Optional[int] = None): self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta - self.partial_rotary_factor = config.partial_rotary_factor + head_size = config.hidden_size // config.num_attention_heads + self.rotary_ndims = int(head_size * config.partial_rotary_factor) self.is_causal = True self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) @@ -256,6 +257,8 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) is required") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -266,9 +269,15 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is not None: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Compute rotary embeddings on rotary_ndims + query_rot = query_states[..., : self.rotary_ndims] + query_pass = query_states[..., self.rotary_ndims :] + key_rot = key_states[..., : self.rotary_ndims] + key_pass = key_states[..., self.rotary_ndims :] + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query_states = torch.cat((query_states, query_pass), dim=-1) + key_states = torch.cat((key_states, key_pass), dim=-1) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -330,6 +339,8 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) is required") if isinstance(past_key_value, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " @@ -351,9 +362,15 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is not None: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Compute rotary embeddings on rotary_ndims + query_rot = query_states[..., : self.rotary_ndims] + query_pass = query_states[..., self.rotary_ndims :] + key_rot = key_states[..., : self.rotary_ndims] + key_pass = key_states[..., self.rotary_ndims :] + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query_states = torch.cat((query_states, query_pass), dim=-1) + key_states = torch.cat((key_states, key_pass), dim=-1) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -438,6 +455,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) is required") if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( @@ -446,13 +465,13 @@ def forward( ) return super().forward( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -465,9 +484,15 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is not None: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Compute rotary embeddings on rotary_ndims + query_rot = query_states[..., : self.rotary_ndims] + query_pass = query_states[..., self.rotary_ndims :] + key_rot = key_states[..., : self.rotary_ndims] + key_pass = key_states[..., self.rotary_ndims :] + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query_states = torch.cat((query_states, query_pass), dim=-1) + key_states = torch.cat((key_states, key_pass), dim=-1) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -533,13 +558,13 @@ def __init__(self, config: NemotronConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -571,13 +596,13 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -823,24 +848,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 11d3d99f4f72c9..630e08e79efc50 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -70,7 +70,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -78,8 +78,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -233,13 +231,13 @@ def __init__(self, config: OlmoConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -249,13 +247,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -559,24 +557,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 49ae798e7f1101..c07d9338e90626 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -59,7 +59,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -67,8 +67,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -235,13 +233,13 @@ def __init__(self, config: Olmo2Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -249,13 +247,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -560,24 +558,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 5f119170804466..c5b74f41a8544c 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -234,13 +234,13 @@ def __init__(self, config: Olmo2Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -248,13 +248,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index fa3c2f3cd4d11b..44abe4f6d81e48 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -83,6 +83,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) @@ -231,7 +232,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -239,8 +240,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -330,15 +329,17 @@ def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() query_states = self.q_norm(self.q_proj(hidden_states)) @@ -412,15 +413,17 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -513,14 +516,16 @@ class OlmoeSdpaAttention(OlmoeAttention): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( @@ -529,13 +534,13 @@ def forward( ) return super().forward( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -666,6 +671,7 @@ def __init__(self, config: OlmoeConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, @@ -673,7 +679,6 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -708,13 +713,13 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -981,6 +986,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -988,11 +994,11 @@ def forward( output_router_logits, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -1000,7 +1006,6 @@ def forward( output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8d3c20b9ace717..5c2261d75f6681 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -130,7 +130,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -138,8 +138,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -231,14 +229,16 @@ def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() # [batch_size, seq_length, 3 x hidden_size] @@ -326,13 +326,13 @@ def __init__(self, config: PersimmonConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -365,13 +365,13 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = residual + hidden_states @@ -626,24 +626,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 477896decd5318..180bf0fcd7bb60 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -47,7 +47,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -55,8 +55,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -147,6 +145,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -232,13 +232,13 @@ def __init__(self, config: PhiConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -248,13 +248,13 @@ def forward( # Self Attention attn_outputs, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) attn_outputs = self.resid_dropout(attn_outputs) @@ -557,24 +557,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 0faa4629f1a768..a4ed37f0596864 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -54,6 +54,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -128,13 +130,13 @@ def __init__(self, config: PhiConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -144,13 +146,13 @@ def forward( # Self Attention attn_outputs, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) attn_outputs = self.resid_dropout(attn_outputs) @@ -242,24 +244,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/phi3/configuration_phi3.py b/src/transformers/models/phi3/configuration_phi3.py index 4940f43e5bffe3..f87f0b55940a95 100644 --- a/src/transformers/models/phi3/configuration_phi3.py +++ b/src/transformers/models/phi3/configuration_phi3.py @@ -15,6 +15,8 @@ """Phi-3 model configuration""" +import math + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -204,7 +206,8 @@ def _rope_scaling_validation(self): raise ValueError( f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" ) - if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2: + required_size = int(math.ceil((self.hidden_size // self.num_attention_heads) / 2)) + if not len(rope_scaling_short_factor) == required_size: raise ValueError( f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" ) @@ -215,7 +218,7 @@ def _rope_scaling_validation(self): raise ValueError( f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" ) - if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2: + if not len(rope_scaling_long_factor) == required_size: raise ValueError( f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" ) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 908fd982b9c73c..e9c790c7b20482 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -80,11 +80,11 @@ class Phi3RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() - self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base + self.dim = dim - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() @@ -241,7 +241,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -249,8 +249,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 8f6b092da6e6ad..9ffeafd4da30c8 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -91,6 +91,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) @@ -156,8 +157,11 @@ def __init__( else: self.rope_type = "default" self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + self.rotary_ndims = config.hidden_size // config.num_attention_heads - def forward(self, x, seq_len=None): + def forward(self, x, seq_len: int): + if seq_len is None: + raise ValueError("seq_len must be given") mscale = None if self.config.rope_scaling and seq_len: mscale = ( @@ -171,7 +175,9 @@ def forward(self, x, seq_len=None): freqs = torch.outer(t, inv_freq) emb = torch.cat((freqs, freqs), dim=-1) - return (emb.cos() * mscale).to(x.dtype), (emb.sin() * mscale).to(x.dtype) + cos = emb.cos() + sin = emb.sin() + return (cos * mscale).to(x.dtype), (sin * mscale).to(x.dtype) # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -270,14 +276,16 @@ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -338,14 +346,16 @@ class PhimoeFlashAttention2(PhimoeAttention): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -432,14 +442,16 @@ class PhimoeSdpaAttention(PhimoeAttention): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( @@ -448,12 +460,12 @@ def forward( ) return super().forward( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -813,6 +825,7 @@ def __init__(self, config: PhimoeConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, @@ -820,7 +833,6 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -852,13 +864,13 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = residual + hidden_states @@ -1116,6 +1128,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -1123,11 +1136,11 @@ def forward( output_router_logits, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -1135,7 +1148,6 @@ def forward( output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 03886d4a528478..fc8258c92093ed 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -127,7 +127,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -135,8 +135,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -175,11 +173,13 @@ def __init__(self, config): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - position_embeddings: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") batch_size, patches, _ = hidden_states.size() @@ -261,8 +261,8 @@ def __init__(self, config): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor, - position_embeddings: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.FloatTensor]: """ @@ -280,8 +280,8 @@ def forward( hidden_states = self.attention_norm(hidden_states) hidden_states, attn_weights = self.attention( hidden_states=hidden_states, - attention_mask=attention_mask, position_embeddings=position_embeddings, + attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = residual + hidden_states @@ -310,8 +310,8 @@ def __init__(self, config): def forward( self, inputs_embeds, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - position_embeddings: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -353,15 +353,15 @@ def forward( layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, - attention_mask, position_embeddings, + attention_mask, output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, - attention_mask, position_embeddings=position_embeddings, + attention_mask=attention_mask, output_attentions=output_attentions, ) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 36fb1ddf1390ac..055a72e73a7686 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -64,7 +64,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -72,8 +72,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -155,6 +153,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -242,13 +242,13 @@ def __init__(self, config: Qwen2Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -258,13 +258,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -568,24 +568,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 718abd01090c2b..ccad8d533bfd98 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -52,6 +52,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 1ce41509a5c0d1..f96c7799db5584 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -94,6 +94,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) @@ -240,7 +241,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -248,8 +249,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py index ef98ae5e3f508f..fb2102491aebf5 100644 --- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py @@ -54,6 +54,8 @@ def __init__( self.temporal_patch_size = temporal_patch_size +# TODO: Add comment for `rope_scaling["mrope_section"]`. This parameter +# is mandatory, but it is unclear what it should be set to. class Qwen2VLConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 566141d3f75c27..92f33481410e27 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -217,17 +217,14 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. mrope_section(`List(int)`): Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and + sin so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos and sin have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + cos and sin broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. @@ -524,23 +521,19 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = Qwen2VLRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -617,14 +610,16 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -719,14 +714,16 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( @@ -735,13 +732,13 @@ def forward( ) return super().forward( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -825,13 +822,13 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -863,13 +860,13 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = residual + hidden_states @@ -1120,24 +1117,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py index 7f45a41710cf29..e59f1b16d4fc1a 100644 --- a/src/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py @@ -14,6 +14,8 @@ # limitations under the License. """RecurrentGemma model configuration""" +import math + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -155,4 +157,6 @@ def __init__( @property def layers_block_type(self): - return (self.block_types * 100)[: self.num_hidden_layers] + len_bt = len(self.block_types) + sz = int(math.ceil(self.num_hidden_layers / len_bt)) + return (self.block_types * sz)[: self.num_hidden_layers] diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 74fc2085c36519..6da1fc69c026ad 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -71,9 +71,8 @@ def extra_repr(self): class RecurrentGemmaRotaryEmbedding(nn.Module): def __init__(self, dim, base=10000, device=None): super().__init__() - self.dim = dim self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() @@ -103,7 +102,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -111,8 +110,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -155,14 +152,14 @@ def __init__(self, config: RecurrentGemmaConfig): self.head_dim = config.head_dim self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads - self.partial_rotary_factor = config.partial_rotary_factor + self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True) self.rotary_emb = RecurrentGemmaRotaryEmbedding( - int(self.partial_rotary_factor * self.head_dim), + self.rotary_ndims, base=config.rope_theta, ) @@ -187,9 +184,11 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) # Partial rotary embedding - query_rot, query_pass = torch.chunk(query_states, int(1 / self.partial_rotary_factor), dim=-1) - key_rot, key_pass = torch.chunk(key_states, int(1 / self.partial_rotary_factor), dim=-1) - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) + query_rot = query_states[..., : self.rotary_ndims] + query_pass = query_states[..., self.rotary_ndims :] + key_rot = key_states[..., : self.rotary_ndims] + key_pass = key_states[..., self.rotary_ndims :] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) query_states = torch.cat((query_rot, query_pass), dim=-1) key_states = torch.cat((key_rot, key_pass), dim=-1) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 88dc437cdcb91d..eb18d3677788f8 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -136,7 +136,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -144,8 +144,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -227,7 +225,6 @@ def __init__(self, config: StableLmConfig, layer_idx: Optional[int] = None): self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.rope_theta = config.rope_theta self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) self.is_causal = True @@ -249,19 +246,20 @@ def __init__(self, config: StableLmConfig, layer_idx: Optional[int] = None): ) self.attention_dropout = nn.Dropout(config.attention_dropout) - self.rotary_emb = StableLmRotaryEmbedding(config=self.config) def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -341,14 +339,16 @@ class StableLmSdpaAttention(StableLmAttention): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( @@ -357,13 +357,13 @@ def forward( ) return super().forward( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -463,13 +463,13 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # StableLmFlashAttention2 attention does not support output_attentions @@ -571,13 +571,13 @@ def __init__(self, config: StableLmConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -611,13 +611,13 @@ def forward( # Self Attention self_attn_output, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward @@ -881,24 +881,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 3b4fdbcb81ccc4..814eeabad4fa85 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -83,7 +83,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -91,8 +91,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -233,13 +231,13 @@ def __init__(self, config: Starcoder2Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -249,13 +247,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -561,13 +559,13 @@ def forward( layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index 32d64cd167ba50..61fc3271d81f5a 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -223,13 +223,13 @@ def forward( layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, )