diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py
index 74ab6c8c979fb5..1d336845de235b 100644
--- a/examples/modular-transformers/modeling_dummy.py
+++ b/examples/modular-transformers/modeling_dummy.py
@@ -272,7 +272,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py
index b83b3036ae9fa8..78e37aa123d1ee 100644
--- a/examples/modular-transformers/modeling_multimodal1.py
+++ b/examples/modular-transformers/modeling_multimodal1.py
@@ -272,7 +272,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py
index fa89cc368380a3..cc26a94defc6a8 100644
--- a/examples/modular-transformers/modeling_my_new_model2.py
+++ b/examples/modular-transformers/modeling_my_new_model2.py
@@ -207,7 +207,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
@@ -245,6 +245,51 @@ def forward(
         return outputs


+MY_NEW_MODEL2_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`MyNewModel2Config`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+""" + + +@add_start_docstrings( + "The bare MyNewModel2 Model outputting raw hidden-states without any specific head on top.", + MY_NEW_MODEL2_START_DOCSTRING, +) +class MyNewModel2PreTrainedModel(PreTrainedModel): + config_class = MyNewModel2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MyNewModel2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + class MyNewModel2RotaryEmbedding(nn.Module): def __init__( self, @@ -310,51 +355,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -MY_NEW_MODEL2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`MyNewModel2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare MyNewModel2 Model outputting raw hidden-states without any specific head on top.", - MY_NEW_MODEL2_START_DOCSTRING, -) -class MyNewModel2PreTrainedModel(PreTrainedModel): - config_class = MyNewModel2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["MyNewModel2DecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - MY_NEW_MODEL2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 0bec832589241b..cfbaa5ac6e6459 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -269,7 +269,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 62e030e90c4fa0..333a22bbcf98d0 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -1,13 +1,25 @@ +from typing import Optional + import torch from ..modeling_flash_attention_utils import _flash_attention_forward def flash_attention_forward( - config, query, key, value, attentions_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + target_dtype: torch.dtype = torch.float16, + **kwargs, ): - if attentions_mask is not None: - seq_len = attentions_mask.shape[1] + if attention_mask is not None: + seq_len = attention_mask.shape[1] query = query[:, :, :seq_len] value = value[:, :, :seq_len] else: @@ -18,8 +30,6 @@ def flash_attention_forward( key = key.transpose(1, 2) value = value.transpose(1, 2) - dropout_rate = config.attention_dropout if training else 0.0 - input_dtype = query.dtype if input_dtype == torch.float32: query = query.to(target_dtype) @@ -30,11 +40,14 @@ def flash_attention_forward( query, key, value, - attentions_mask, + attention_mask, seq_len, - config=config, - dropout=dropout_rate, - layer_idx=layer_idx, + module.is_causal, + dropout=dropout, + softmax_scale=scaling, + sliding_window=sliding_window, + softcap=softcap, + 
+        use_top_left_mask=module._flash_attn_uses_top_left_mask,
         **kwargs,
     )

diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index 9c309a9ad50575..dd4287921d2c37 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -1,3 +1,7 @@
+from typing import Optional
+
+import torch
+
 from ..utils import is_torch_greater_or_equal


@@ -5,12 +9,23 @@
     from torch.nn.attention.flex_attention import flex_attention


-def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs):
+def flex_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    scaling: Optional[float] = None,
+    softcap: Optional[float] = None,
+    **kwargs,
+):
     causal_mask = attention_mask
     if causal_mask is not None:
         causal_mask = causal_mask[:, :, :, : key.shape[-2]]

     def causal_mod(score, b, h, q_idx, kv_idx):
+        if softcap is not None:
+            score = softcap * torch.tanh(score / softcap)
         if causal_mask is not None:
             score += causal_mask[b][0][q_idx][kv_idx]
         return score
@@ -21,8 +36,9 @@ def causal_mod(score, b, h, q_idx, kv_idx):
         value,
         score_mod=causal_mod,
         enable_gqa=True,
-        scale=module.scaling,
+        scale=scaling,
         return_lse=True,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
+
     return attn_output, attention_weights
diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py
index 0cf58f035ea9c1..3a90ef9af2824e 100644
--- a/src/transformers/integrations/sdpa_attention.py
+++ b/src/transformers/integrations/sdpa_attention.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch


@@ -13,7 +15,16 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs):
+def sdpa_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    **kwargs,
+):
     key = repeat_kv(key, module.num_key_value_groups)
     value = repeat_kv(value, module.num_key_value_groups)

@@ -31,9 +42,10 @@ def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kw
         key,
         value,
         attn_mask=causal_mask,
-        dropout_p=module.config.attention_dropout if module.training else 0.0,
+        dropout_p=dropout,
+        scale=scaling,
         is_causal=is_causal,
-        scale=module.scaling,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
+
     return attn_output, None
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6a8b59703b9e5b..2cb4c38247d2e2 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -30,7 +30,7 @@
 from dataclasses import dataclass
 from functools import partial, wraps
 from threading import Thread
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 from zipfile import is_zipfile

 import torch
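The three integration backends above (flash, flex, SDPA) are reworked to share one calling convention: each receives the attention module itself plus query/key/value/attention_mask, takes per-call options such as dropout, scaling, sliding_window and softcap as keyword arguments instead of reading them from module.config, and returns an (attn_output, attn_weights) pair. Below is a minimal sketch of an eager-style function written against that convention; it assumes num_attention_heads == num_key_value_heads (so the repeat_kv step is skipped), and the function name is hypothetical, nothing in this diff registers it anywhere.

from typing import Optional, Tuple

import torch


def toy_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,  # (batch, num_heads, q_len, head_dim)
    key: torch.Tensor,    # (batch, num_heads, kv_len, head_dim); GQA/repeat_kv not handled in this sketch
    value: torch.Tensor,  # (batch, num_heads, kv_len, head_dim)
    attention_mask: Optional[torch.Tensor] = None,
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if scaling is None:
        scaling = query.shape[-1] ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # same convention as the integrations above: slice the mask to the key length
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    # callers expect (batch, q_len, num_heads, head_dim) back, before the output projection
    return attn_output.transpose(1, 2).contiguous(), attn_weights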
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index 3eda0afde0f6c4..dc18a832420778 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -588,7 +588,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index 6b19d178341fbb..c328690efc8c43 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -659,10 +659,6 @@ def __init__(self, config: Cohere2Config):
             [Cohere2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
-
-        self.gradient_checkpointing = False
-        if getattr(config, "pretraining_tp", 1) != 1:
-            logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
         self.rotary_emb = Cohere2RotaryEmbedding(config=config)

         # Initialize weights and apply final processing
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 12b6cd81451be2..ced8b374ba04e7 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -237,7 +237,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
@@ -275,6 +275,51 @@ def forward(
         return outputs


+GEMMA_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`GemmaConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+""" + + +@add_start_docstrings( + "The bare Gemma Model outputting raw hidden-states without any specific head on top.", + GEMMA_START_DOCSTRING, +) +class GemmaPreTrainedModel(PreTrainedModel): + config_class = GemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["GemmaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + class GemmaRotaryEmbedding(nn.Module): def __init__( self, @@ -340,51 +385,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -GEMMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`GemmaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare Gemma Model outputting raw hidden-states without any specific head on top.", - GEMMA_START_DOCSTRING, -) -class GemmaPreTrainedModel(PreTrainedModel): - config_class = GemmaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["GemmaDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - GEMMA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index ee0829f857dc00..69972013ab2867 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -32,6 +32,7 @@ LlamaForSequenceClassification, LlamaForTokenClassification, LlamaModel, + LlamaPreTrainedModel, ) from ..llama.tokenization_llama import LlamaTokenizer diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 73536bd0f8404d..3f66f2254cf0e5 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -50,6 +50,8 @@ logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "google/gemma2-7b" +_CONFIG_FOR_DOC = "Gemma2Config" class Gemma2RMSNorm(nn.Module): @@ -198,18 +200,30 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor], + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.scaling + + if module.attn_logit_softcapping is not None: + attn_weights = attn_weights / module.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * module.attn_logit_softcapping + if mask is not None: # no matter the length, we just slice it + causal_mask = mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask + # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, 
+    attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attn_weights
@@ -224,7 +238,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
         self.layer_idx = layer_idx
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
-        self.scaling = self.head_dim**-0.5
+        self.scaling = config.query_pre_attn_scalar**-0.5

         self.q_proj = nn.Linear(
             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
@@ -238,6 +252,10 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
+        self.attn_logit_softcapping = self.config.attn_logit_softcapping
+        self.attention_dropout = self.config.attention_dropout
+        self.is_causal = True
+        self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None

     def forward(
         self,
@@ -271,6 +289,10 @@ def forward(
             query_states,
             key_states,
             value_states,
+            dropout=self.attention_dropout if self.training else 0.0,
+            scaling=self.scaling,
+            sliding_window=self.sliding_window,
+            softcap=self.attn_logit_softcapping,
             **kwargs,
         )

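The rewritten eager_attention_forward above folds Gemma2's attention-logit soft-capping into the shared attention path: logits are divided by attn_logit_softcapping, passed through tanh, then scaled back up, which bounds them to (-softcap, +softcap). A small self-contained illustration of that transform, with made-up numbers and `softcap` standing in for config.attn_logit_softcapping:

import torch

softcap = 50.0  # stands in for config.attn_logit_softcapping
logits = torch.tensor([-200.0, -10.0, 0.0, 10.0, 200.0])
capped = softcap * torch.tanh(logits / softcap)
print(capped)  # ~[-49.97, -9.87, 0.00, 9.87, 49.97]: small logits pass nearly unchanged, large ones saturate near ±softcap

The same cap is applied in the flex-attention backend as part of its score_mod and is forwarded to flash attention through the softcap keyword, so all backends stay numerically consistent.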
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index cb7295c2adf59b..eecc2f88610984 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -22,13 +22,15 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, HybridCache
 from ...configuration_utils import PretrainedConfig
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
 from ...utils import (
     is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal,
     is_torch_greater_or_equal,
     logging,
 )
@@ -42,15 +44,16 @@
     GemmaPreTrainedModel,
     GemmaRMSNorm,
     GemmaRotaryEmbedding,
+    apply_rotary_pos_emb,
     repeat_kv,
 )


 if is_flash_attn_2_available():
-    from ...modeling_flash_attention_utils import _flash_attention_forward
+    pass

 if is_torch_greater_or_equal("2.5"):
-    from torch.nn.attention.flex_attention import flex_attention
+    pass


 _CHECKPOINT_FOR_DOC = "google/gemma2-7b"
@@ -214,140 +217,86 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
     pass


-def gemma_eager_attention_forward(
-    config: Gemma2Config,
+def eager_attention_forward(
+    module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     mask: Optional[torch.Tensor],
-    **_kwargs,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    key_states = repeat_kv(key, config.num_key_value_groups)
-    value_states = repeat_kv(value, config.num_key_value_groups)
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)

-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.scaling

-    if config.attn_logit_softcapping is not None:
-        attn_weights = attn_weights / config.attn_logit_softcapping
+    if module.attn_logit_softcapping is not None:
+        attn_weights = attn_weights / module.attn_logit_softcapping
         attn_weights = torch.tanh(attn_weights)
-        attn_weights = attn_weights * config.attn_logit_softcapping
+        attn_weights = attn_weights * module.attn_logit_softcapping
     if mask is not None:  # no matter the length, we just slice it
         causal_mask = mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask

     # upcast attention to fp32
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training)
+    attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attn_weights


-def gemma_flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
-    if mask is not None:
-        seq_len = mask.shape[1]
-        query = query[:, :, :seq_len]
-        value = value[:, :, :seq_len]
-
-    # TODO: These transpose are quite inefficient but Flash Attention requires the layout
-    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding
-    query_states = query.transpose(1, 2)
-    key_states = key.transpose(1, 2)
-    value_states = value.transpose(1, 2)
-
-    dropout_rate = config.attention_dropout if config.training else 0.0
-
-    input_dtype = query_states.dtype
-    if input_dtype == torch.float32:
-        query_states = query_states.to(target_dtype)
-        key_states = key_states.to(target_dtype)
-        value_states = value_states.to(target_dtype)
-
-    attn_output = _flash_attention_forward(
-        query_states,
-        key_states,
-        value_states,
-        mask,
-        seq_len,
-        dropout=dropout_rate,
-        softmax_scale=config.scaling,
-        is_causal=config.is_causal,
-        sliding_window=config.sliding_window,
-        use_top_left_mask=config._flash_attn_uses_top_left_mask,
-        softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
-    )
-
-    return attn_output, None
-
-
-def gemma_flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
-    def tanh_softcap(score, b, h, q_idx, kv_idx):
-        soft_cap = config.attn_logit_softcapping
-        score = soft_cap * torch.tanh(score / soft_cap)
-        if mask is not None:
-            return score + mask[b][0][q_idx][kv_idx]
-        return score
-
-    attn_output = flex_attention(
-        query,
-        key,
-        value,
-        score_mod=tanh_softcap,
-        enable_gqa=True,
-        scale=config.scaling,
-        return_lse=output_attentions,
-    )
-    if not output_attentions:
-        attn_weights = None
-    else:
-        attn_output, attn_weights = attn_output
-
-    attn_output = attn_output.transpose(1, 2).contiguous()
-    return attn_output, attn_weights
-
-
-def gemma_sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
-    key = repeat_kv(key, config.num_key_value_groups)
-    value = repeat_kv(value, config.num_key_value_groups)
-
-    causal_mask = mask
-    if mask is not None:
-        causal_mask = causal_mask[:, :, :, : key.shape[-2]]
-
-    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-    # Reference: https://github.com/pytorch/pytorch/issues/112577.
-    if query.device.type == "cuda" and causal_mask is not None:
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-
-    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-    is_causal = True if causal_mask is None and query.shape[1] > 1 else False
-
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query,
-        key,
-        value,
-        attn_mask=causal_mask,
-        dropout_p=config.attention_dropout if config.training else 0.0,
-        is_causal=is_causal,
-        scale=config.scaling,
-    )
-    attn_output = attn_output.transpose(1, 2).contiguous()
-    return attn_output, None
-
-
-ALL_ATTENTION_FUNCTION = {
-    "gemma_flash_attention_2": gemma_flash_attention_forward,
-    "gemma_flex_attention": gemma_flex_attention_forward,
-    "gemma_eager": gemma_eager_attention_forward,
-    "gemma_sdpa": gemma_sdpa_attention_forward,
-}
+class Gemma2Attention(GemmaAttention):
+    def __init__(self, config: Gemma2Config, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.attn_logit_softcapping = self.config.attn_logit_softcapping
+        self.attention_dropout = self.config.attention_dropout
+        self.is_causal = True
+        self.scaling = config.query_pre_attn_scalar**-0.5
+        self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            dropout=self.attention_dropout if self.training else 0.0,
+            scaling=self.scaling,
+            sliding_window=self.sliding_window,
+            softcap=self.attn_logit_softcapping,
+            **kwargs,
+        )

-class Gemma2Attention(GemmaAttention):
-    pass
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights


 class Gemma2DecoderLayer(GemmaDecoderLayer):
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index c7f343b36dc75e..0f14c7140efe8d 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -282,7 +282,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 0c362dd8691da7..9726d22fed37df 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -307,7 +307,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 9b707a27b40943..2a4dcbefc3447e 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -281,7 +281,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index f8f6a884c09181..f634c52f2fc79c 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -247,7 +247,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index dcba03eafe4178..24c1c6ce3e6550 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -237,7 +237,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py
index 4f55ed765d270f..edadaa6e1188e7 100644
--- a/src/transformers/models/qwen2/modular_qwen2.py
+++ b/src/transformers/models/qwen2/modular_qwen2.py
@@ -1,32 +1,30 @@
-from typing import Callable, List, Optional, Tuple, Union, Unpack
+from typing import Callable, Optional, Tuple, Unpack

 import torch
-from torch import nn
 import torch.utils.checkpoint

-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...cache_utils import Cache
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...cache_utils import Cache, SlidingWindowCache, StaticCache
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...utils import logging
 from ..llama.modeling_llama import (
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaForCausalLM,
     LlamaForQuestionAnswering,
     LlamaForSequenceClassification,
     LlamaForTokenClassification,
     LlamaModel,
-    LlamaForCausalLM,
-    LlamaAttention,
-    LlamaDecoderLayer,
-    eager_attention_forward,
     apply_rotary_pos_emb,
+    eager_attention_forward,
 )
 from .configuration_qwen2 import Qwen2Config


 logger = logging.get_logger(__name__)


-class Qwen2Attention(LlamaAttention):
+class Qwen2Attention(LlamaAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -75,7 +73,7 @@ def forward(
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights

-    
+
 class Qwen2DecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: Qwen2Config, layer_idx: int):
         super().__init__()
@@ -85,17 +83,22 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
             "unexpected results may be encountered."
         )

+
 class Qwen2Model(LlamaModel):
     pass

+
 class Qwen2ForCausalLM(LlamaForCausalLM):
     pass

+
 class Qwen2ForSequenceClassification(LlamaForSequenceClassification):
     pass

+
 class Qwen2ForTokenClassification(LlamaForTokenClassification):
     pass

+
 class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering):
-    pass
\ No newline at end of file
+    pass
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 263d77e6ca16b2..5ca59d1257f0c0 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -294,7 +294,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
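Taken together, the per-model attention classes touched above (Gemma2Attention, Qwen2Attention, ...) keep eager_attention_forward as the in-file default and look every other implementation up in ALL_ATTENTION_FUNCTIONS keyed by config._attn_implementation, passing model-specific options (dropout, scaling, sliding_window, softcap) as keyword arguments that backends which do not use them simply absorb via **kwargs. A condensed sketch of that dispatch step follows; the eager fallback is passed in explicitly here because it lives in each model file, and treating ALL_ATTENTION_FUNCTIONS as an importable registry is an assumption based on the imports in this diff, not a documented public API.

from typing import Callable

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def select_attention(config, eager_fn: Callable) -> Callable:
    # Mirror of the pattern in the refactored modules: "eager" stays local,
    # anything else ("sdpa", "flash_attention_2", "flex_attention", ...) comes from the registry.
    if config._attn_implementation != "eager":
        return ALL_ATTENTION_FUNCTIONS[config._attn_implementation]
    return eager_fn

A module's forward then calls the selected function as attention_interface(self, query_states, key_states, value_states, attention_mask, dropout=..., scaling=..., **kwargs), exactly as in the Gemma2 and Qwen2 hunks above.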