From 03bf755485e4424ab75617a30132ea58dab33eaa Mon Sep 17 00:00:00 2001 From: "JB (Don)" <1557853+hackyon@users.noreply.github.com> Date: Wed, 28 Aug 2024 16:26:00 +0800 Subject: [PATCH] [RoBERTa-based] Add support for sdpa (#30510) * Adding SDPA support for RoBERTa-based models * add not is_cross_attention * fix copies * fix test * add minimal test for camembert and xlm_roberta as their test class does not inherit from ModelTesterMixin * address some review comments * use copied from * style * consistency * fix lists --------- Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/perf_infer_gpu_one.md | 27 ++- .../bridgetower/modeling_bridgetower.py | 2 +- .../models/camembert/modeling_camembert.py | 177 +++++++++++++-- .../models/roberta/modeling_roberta.py | 189 ++++++++++++++-- .../modeling_roberta_prelayernorm.py | 1 - .../xlm_roberta/modeling_xlm_roberta.py | 188 ++++++++++++++-- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 201 ++++++++++++++++-- .../camembert/test_modeling_camembert.py | 32 ++- .../xlm_roberta/test_modeling_xlm_roberta.py | 27 ++- .../test_modeling_xlm_roberta_xl.py | 81 ++++++- utils/check_support_list.py | 3 +- 11 files changed, 828 insertions(+), 100 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index a84d8da8a54d3d..3517d93bfc8a29 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -74,7 +74,6 @@ FlashAttention-2 is currently supported for the following architectures: * [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel) * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel) * [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model) -* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) * [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model) @@ -86,6 +85,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel) * [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel) * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel) +* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) @@ -204,9 +204,11 @@ For now, Transformers supports SDPA inference and training for the following arc * [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel) * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) * [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel) +* [CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert#transformers.CamembertModel) * [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon) * [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel) * [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel) +* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel) * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) * [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel) * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader) @@ -216,10 +218,16 @@ For now, Transformers supports SDPA inference and training for the following arc * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) * [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel) +* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel) +* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel) * [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel) * [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel) * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel) * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) +* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) +* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) +* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel) * [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration) * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel) @@ -233,6 +241,14 @@ For now, Transformers supports SDPA inference and training for the following arc * [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model) * [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder) * [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel) +* [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) +* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel) +* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) +* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) +* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) +* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) +* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) +* [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) * [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) @@ -243,12 +259,9 @@ For now, Transformers supports SDPA inference and training for the following arc * [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel) * [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModell) * [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model) -* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel) -* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel) -* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) -* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel) -* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) -* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) +* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) +* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel) +* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel) * [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 91cbda9b72edbb..81785e147db956 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -1063,7 +1063,7 @@ class PreTrainedModel for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.forward + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index f050b56d11786c..03a60a2856be2e 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -20,10 +20,15 @@ import torch import torch.utils.checkpoint +from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -40,6 +45,7 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + get_torch_version, logging, replace_return_docstrings, ) @@ -294,6 +300,108 @@ def forward( return outputs +# Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->Camembert +class CamembertSdpaSelfAttention(CamembertSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from CamembertSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "CamembertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. + is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->Camembert class CamembertSelfOutput(nn.Module): def __init__(self, config): @@ -311,6 +419,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to CAMEMBERT_SELF_ATTENTION_CLASSES = { "eager": CamembertSelfAttention, + "sdpa": CamembertSdpaSelfAttention, } @@ -603,6 +712,7 @@ class CamembertPreTrainedModel(PreTrainedModel): config_class = CamembertConfig base_model_prefix = "roberta" supports_gradient_checkpointing = True + _supports_sdpa = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -749,7 +859,7 @@ class CamembertModel(CamembertPreTrainedModel): _no_split_modules = [] - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Camembert + # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.__init__ with Roberta->Camembert def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -759,6 +869,9 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = CamembertPooler(config) if add_pooling_layer else None + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + # Initialize weights and apply final processing self.post_init() @@ -782,7 +895,7 @@ class PreTrainedModel output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward + # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -846,9 +959,6 @@ def forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] @@ -857,9 +967,43 @@ def forward( else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -868,7 +1012,15 @@ def forward( encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -879,13 +1031,6 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index c15636f4b65ccb..f1f83147527dfb 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -20,10 +20,15 @@ import torch import torch.utils.checkpoint +from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -40,6 +45,7 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + get_torch_version, logging, replace_return_docstrings, ) @@ -276,6 +282,108 @@ def forward( return outputs +# Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->Roberta +class RobertaSdpaSelfAttention(RobertaSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from RobertaSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "RobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. + is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + # Copied from transformers.models.bert.modeling_bert.BertSelfOutput class RobertaSelfOutput(nn.Module): def __init__(self, config): @@ -293,6 +401,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to ROBERTA_SELF_ATTENTION_CLASSES = { "eager": RobertaSelfAttention, + "sdpa": RobertaSdpaSelfAttention, } @@ -585,7 +694,8 @@ class RobertaPreTrainedModel(PreTrainedModel): config_class = RobertaConfig base_model_prefix = "roberta" supports_gradient_checkpointing = True - _no_split_modules = ["RobertaEmbeddings", "RobertaSelfAttention"] + _no_split_modules = ["RobertaEmbeddings", "RobertaSelfAttention", "RobertaSdpaSelfAttention"] + _supports_sdpa = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -676,23 +786,22 @@ def _init_weights(self, module): "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", ROBERTA_START_DOCSTRING, ) +# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->Roberta, BERT->ROBERTA class RobertaModel(RobertaPreTrainedModel): """ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in *Attention is - all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz - Kaiser and Illia Polosukhin. + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. - - .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762 - """ - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Roberta + _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"] + def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -702,6 +811,9 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = RobertaPooler(config) if add_pooling_layer else None + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + # Initialize weights and apply final processing self.post_init() @@ -725,7 +837,6 @@ class PreTrainedModel output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -789,9 +900,6 @@ def forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] @@ -800,9 +908,43 @@ def forward( else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -811,7 +953,15 @@ def forward( encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -822,13 +972,6 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 3e592f38776801..95657c260dc7a4 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -569,7 +569,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return pooled_output -# Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm class RobertaPreLayerNormPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 5fa4812d350dec..3ac94e75f92fe4 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -20,10 +20,15 @@ import torch import torch.utils.checkpoint +from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -40,6 +45,7 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + get_torch_version, logging, replace_return_docstrings, ) @@ -277,6 +283,108 @@ def forward( return outputs +# Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->XLMRoberta +class XLMRobertaSdpaSelfAttention(XLMRobertaSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from XLMRobertaSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. + is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->XLMRoberta class XLMRobertaSelfOutput(nn.Module): def __init__(self, config): @@ -294,6 +402,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to XLM_ROBERTA_SELF_ATTENTION_CLASSES = { "eager": XLMRobertaSelfAttention, + "sdpa": XLMRobertaSdpaSelfAttention, } @@ -587,7 +696,8 @@ class XLMRobertaPreTrainedModel(PreTrainedModel): config_class = XLMRobertaConfig base_model_prefix = "roberta" supports_gradient_checkpointing = True - _no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention"] + _no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention", "XLMRobertaSdpaSelfAttention"] + _supports_sdpa = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -682,19 +792,17 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel): """ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in *Attention is - all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz - Kaiser and Illia Polosukhin. + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. - - .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762 - """ - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRoberta + _no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaLayer"] + def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -704,6 +812,9 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + # Initialize weights and apply final processing self.post_init() @@ -727,7 +838,6 @@ class PreTrainedModel output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -791,9 +901,6 @@ def forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] @@ -802,9 +909,43 @@ def forward( else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -813,7 +954,15 @@ def forward( encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -824,13 +973,6 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 86fdadeaef0259..f66a32291794f1 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -19,10 +19,15 @@ import torch import torch.utils.checkpoint +from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -39,6 +44,7 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + get_torch_version, logging, replace_return_docstrings, ) @@ -274,6 +280,108 @@ def forward( return outputs +# Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->XLMRobertaXL +class XLMRobertaXLSdpaSelfAttention(XLMRobertaXLSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from XLMRobertaXLSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "XLMRobertaXLSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. + is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + class XLMRobertaXLSelfOutput(nn.Module): def __init__(self, config): super().__init__() @@ -287,11 +395,19 @@ def forward(self, hidden_states, input_tensor): return hidden_states +XLMROBERTAXL_SELF_ATTENTION_CLASSES = { + "eager": XLMRobertaXLSelfAttention, + "sdpa": XLMRobertaXLSdpaSelfAttention, +} + + class XLMRobertaXLAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self = XLMRobertaXLSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = XLMROBERTAXL_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = XLMRobertaXLSelfOutput(config) self.pruned_heads = set() @@ -573,6 +689,7 @@ class XLMRobertaXLPreTrainedModel(PreTrainedModel): config_class = XLMRobertaXLConfig base_model_prefix = "roberta" _no_split_modules = ["XLMRobertaXLEmbeddings", "XLMRobertaXLLayer"] + _supports_sdpa = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -651,18 +768,22 @@ def _init_weights(self, module): "The bare XLM-RoBERTa-XL Model transformer outputting raw hidden-states without any specific head on top.", XLM_ROBERTA_XL_START_DOCSTRING, ) +# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->XLMRobertaXL, BERT->XLM_ROBERTA_XL class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel): """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in *Attention is - all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz - Kaiser and Illia Polosukhin. To behave as an decoder the model needs to be initialized with the `is_decoder` - argument of the configuration set to `True`. To be used in a Seq2Seq model, the model needs to initialized with - both `is_decoder` argument and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as - an input to the forward pass. .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762 + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. """ - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRobertaXL + _no_split_modules = ["XLMRobertaXLEmbeddings", "XLMRobertaXLLayer"] + def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -672,6 +793,9 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = XLMRobertaXLPooler(config) if add_pooling_layer else None + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + # Initialize weights and apply final processing self.post_init() @@ -695,7 +819,6 @@ class PreTrainedModel output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -759,9 +882,6 @@ def forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] @@ -770,9 +890,43 @@ def forward( else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -781,7 +935,15 @@ def forward( encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -792,13 +954,6 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, diff --git a/tests/models/camembert/test_modeling_camembert.py b/tests/models/camembert/test_modeling_camembert.py index f2fba59496da4f..f779c3a80909b9 100644 --- a/tests/models/camembert/test_modeling_camembert.py +++ b/tests/models/camembert/test_modeling_camembert.py @@ -16,7 +16,14 @@ import unittest from transformers import is_torch_available -from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device +from transformers.testing_utils import ( + require_sentencepiece, + require_tokenizers, + require_torch, + require_torch_sdpa, + slow, + torch_device, +) if is_torch_available(): @@ -31,7 +38,7 @@ class CamembertModelIntegrationTest(unittest.TestCase): @slow def test_output_embeds_base_model(self): - model = CamembertModel.from_pretrained("almanach/camembert-base") + model = CamembertModel.from_pretrained("almanach/camembert-base", attn_implementation="eager") model.to(torch_device) input_ids = torch.tensor( @@ -54,3 +61,24 @@ def test_output_embeds_base_model(self): # expected_slice = roberta.model.forward(input_ids)[0][:, :3, :3].detach() self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) + + @slow + @require_torch_sdpa + def test_output_embeds_base_model_sdpa(self): + input_ids = torch.tensor( + [[5, 121, 11, 660, 16, 730, 25543, 110, 83, 6]], + device=torch_device, + dtype=torch.long, + ) # J'aime le camembert ! + + expected_slice = torch.tensor( + [[[-0.0254, 0.0235, 0.1027], [0.0606, -0.1811, -0.0418], [-0.1561, -0.1127, 0.2687]]], + device=torch_device, + dtype=torch.float, + ) + + model = CamembertModel.from_pretrained("almanach/camembert-base", attn_implementation="sdpa").to(torch_device) + with torch.no_grad(): + output = model(input_ids)["last_hidden_state"].detach() + + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) diff --git a/tests/models/xlm_roberta/test_modeling_xlm_roberta.py b/tests/models/xlm_roberta/test_modeling_xlm_roberta.py index d9b69bb9ab5f2a..f8ec1f5b7671a5 100644 --- a/tests/models/xlm_roberta/test_modeling_xlm_roberta.py +++ b/tests/models/xlm_roberta/test_modeling_xlm_roberta.py @@ -17,7 +17,13 @@ import unittest from transformers import is_torch_available -from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow +from transformers.testing_utils import ( + require_sentencepiece, + require_tokenizers, + require_torch, + require_torch_sdpa, + slow, +) if is_torch_available(): @@ -32,7 +38,7 @@ class XLMRobertaModelIntegrationTest(unittest.TestCase): @slow def test_xlm_roberta_base(self): - model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base") + model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base", attn_implementation="eager") input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]]) # The dog is cute and lives in the garden house @@ -49,6 +55,23 @@ def test_xlm_roberta_base(self): # compare the actual values for a slice of last dim self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3)) + @require_torch_sdpa + def test_xlm_roberta_base_sdpa(self): + input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]]) + # The dog is cute and lives in the garden house + + expected_output_shape = torch.Size((1, 12, 768)) # batch_size, sequence_length, embedding_vector_dim + expected_output_values_last_dim = torch.tensor( + [[-0.0101, 0.1218, -0.0803, 0.0801, 0.1327, 0.0776, -0.1215, 0.2383, 0.3338, 0.3106, 0.0300, 0.0252]] + ) + + model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base", attn_implementation="sdpa") + with torch.no_grad(): + output = model(input_ids)["last_hidden_state"].detach() + self.assertEqual(output.shape, expected_output_shape) + # compare the actual values for a slice of last dim + self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3)) + @slow def test_xlm_roberta_large(self): model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-large") diff --git a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py index 22663db27c811c..a73f5618ff7eda 100644 --- a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py +++ b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py @@ -14,10 +14,11 @@ # limitations under the License. +import tempfile import unittest from transformers import XLMRobertaXLConfig, is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_sdpa, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -515,6 +516,84 @@ def test_create_position_ids_from_inputs_embeds(self): self.assertEqual(position_ids.shape, expected_positions.shape) self.assertTrue(torch.all(torch.eq(position_ids, expected_positions))) + # TODO: Remove this and use the parent method (in common tests) once XLM RoBERTa XL supports low_cpu_mem_usage=True. + @require_torch_sdpa + @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_generate + def test_eager_matches_sdpa_generate(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + max_new_tokens = 30 + + if len(self.all_generative_model_classes) == 0: + self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test") + + for model_class in self.all_generative_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + # Ignore copy + model_sdpa = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=False, + ).to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + # Ignore copy + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=False, + attn_implementation="eager", + ).to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + + # Just test that a large cache works as expected + res_eager = model_eager.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + res_sdpa = model_sdpa.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + self.assertTrue(torch.allclose(res_eager, res_sdpa)) + @require_torch class XLMRobertaModelXLIntegrationTest(unittest.TestCase): diff --git a/utils/check_support_list.py b/utils/check_support_list.py index 89e1bcf9d6a4c0..55d93611f4ce5f 100644 --- a/utils/check_support_list.py +++ b/utils/check_support_list.py @@ -70,6 +70,7 @@ def check_sdpa_support_list(): "For now, Transformers supports SDPA inference and training for the following architectures:" )[1] doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0] + doctext = doctext.lower() patterns = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_*.py")) patterns_tf = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_tf_*.py")) @@ -85,7 +86,7 @@ def check_sdpa_support_list(): archs_supporting_sdpa.append(model_name) for arch in archs_supporting_sdpa: - if arch not in doctext and arch not in doctext.replace("-", "_"): + if not any(term in doctext for term in [arch, arch.replace("_", "-"), arch.replace("_", " ")]): raise ValueError( f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation." )