diff --git a/docs/source/en/model_doc/xlm-roberta.md b/docs/source/en/model_doc/xlm-roberta.md index 58540015232e9d..0e08d3db3c38ac 100644 --- a/docs/source/en/model_doc/xlm-roberta.md +++ b/docs/source/en/model_doc/xlm-roberta.md @@ -55,6 +55,14 @@ This model was contributed by [stefan-it](https://huggingface.co/stefan-it). The language from the input ids. - Uses RoBERTa tricks on the XLM approach, but does not use the translation language modeling objective. It only uses masked language modeling on sentences coming from one language. +### Expected speedups + +Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the `FacebookAI/xlm-roberta-base` checkpoint and the Flash Attention 2 version of the model. + +
+ +
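The speedups in the diagram depend on hardware, batch shape, and sequence length. As a rough illustration of how such a comparison can be reproduced, the sketch below times forward passes with the eager implementation against the Flash Attention 2 one. It is a minimal sketch, assuming a CUDA GPU with `flash-attn` installed; the batch contents and iteration counts are illustrative and not the exact setup used for the diagram.

```python
# Minimal benchmarking sketch (illustrative only): eager vs. Flash Attention 2 forward-pass latency.
# Assumes a CUDA device and flash-attn installed; numbers will differ from the diagram above.
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "FacebookAI/xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(["Hello world!"] * 8, padding=True, return_tensors="pt").to("cuda")


def time_forward(attn_implementation):
    model = AutoModel.from_pretrained(
        checkpoint, torch_dtype=torch.float16, attn_implementation=attn_implementation
    ).to("cuda")
    model.eval()
    # Warmup so that kernel compilation and caching do not skew the measurement.
    with torch.no_grad():
        for _ in range(3):
            model(**inputs)
    torch.cuda.synchronize()
    # CUDA events make the timing wait for kernel completion.
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    with torch.no_grad():
        for _ in range(10):
            model(**inputs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 10  # milliseconds per forward pass


print(f"eager:             {time_forward('eager'):.2f} ms")
print(f"flash_attention_2: {time_forward('flash_attention_2'):.2f} ms")
```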
+ + ## Resources A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with XLM-RoBERTa. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource. @@ -113,6 +121,35 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h This implementation is the same as RoBERTa. Refer to the [documentation of RoBERTa](roberta) for usage examples as well as the information relative to the inputs and outputs. +## Combining XLM-RoBERTa and Flash Attention 2 + +First, make sure to install the latest version of Flash Attention 2: + +```bash +pip install -U flash-attn --no-build-isolation +``` + +Also make sure that your hardware is compatible with Flash Attention 2; you can read more about it in the official documentation of the [flash-attn](https://github.com/Dao-AILab/flash-attention) repository. Finally, make sure to load your model in half-precision (e.g. `torch.float16`). + +To load and run a model using Flash Attention 2, refer to the snippet below: + +```python +>>> import torch +>>> from transformers import AutoTokenizer, AutoModel + +>>> device = "cuda" # the device to load the model onto + +>>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base") +>>> model = AutoModel.from_pretrained("FacebookAI/xlm-roberta-base", torch_dtype=torch.float16, attn_implementation="flash_attention_2") + +>>> text = "Replace me by any text you'd like." + +>>> encoded_input = tokenizer(text, return_tensors="pt").to(device) +>>> model.to(device) + +>>> output = model(**encoded_input) +``` + ## XLMRobertaConfig [[autodoc]] XLMRobertaConfig diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 899e5b52f002ce..4cd1f2b23aea1c 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -54,6 +54,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel) * [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model) * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) +* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel) You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
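Because the Flash Attention 2 path unpads and repads variable-length sequences internally, padded batches can be passed exactly as with the eager implementation. The sketch below illustrates this with a padded two-sentence batch; it is a minimal sketch assuming a CUDA GPU with `flash-attn` installed, and the sentences and mean-pooling step are illustrative only.

```python
# Illustrative sketch: padded-batch inference with the Flash Attention 2 integration.
# Assumes a CUDA device and flash-attn installed.
import torch
from transformers import AutoModel, AutoTokenizer

device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
model = AutoModel.from_pretrained(
    "FacebookAI/xlm-roberta-base",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to(device)

sentences = ["A short sentence.", "A slightly longer sentence that forces padding in the batch."]
batch = tokenizer(sentences, padding=True, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**batch)

# Mean-pool over non-padding tokens using the attention mask.
mask = batch["attention_mask"].unsqueeze(-1).to(outputs.last_hidden_state.dtype)
embeddings = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
print(embeddings.shape)  # (2, hidden_size)
```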
@@ -393,4 +394,4 @@ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable outputs = model.generate(**inputs) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) -``` +``` \ No newline at end of file diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 95ea2e7dca7bd1..d81ab6ee53575a 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -19,6 +19,7 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -40,12 +41,18 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_xlm_roberta import XLMRobertaConfig +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "xlm-roberta-base" @@ -62,6 +69,19 @@ ] +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->XLMRoberta class XLMRobertaEmbeddings(nn.Module): """ @@ -152,7 +172,6 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): return position_ids.unsqueeze(0).expand(input_shape) -# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->XLMRoberta class XLMRobertaSelfAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() @@ -165,6 +184,7 @@ def __init__(self, config, position_embedding_type=None): self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.config = config self.query = nn.Linear(config.hidden_size, self.all_head_size) self.key = nn.Linear(config.hidden_size, self.all_head_size) @@ -180,11 +200,13 @@ def __init__(self, config, position_embedding_type=None): self.is_decoder = config.is_decoder + # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) + # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.forward with Bert->XLMRoberta def forward( self, hidden_states: torch.Tensor, @@ -302,14 +324,15 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->XLMRoberta class 
XLMRobertaAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() self.self = XLMRobertaSelfAttention(config, position_embedding_type=position_embedding_type) self.output = XLMRobertaSelfOutput(config) self.pruned_heads = set() + self.is_causal = False + # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads def prune_heads(self, heads): if len(heads) == 0: return @@ -328,6 +351,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + # Copied from transformers.models.bert.modeling_bert.BertAttention.forward def forward( self, hidden_states: torch.Tensor, @@ -352,6 +376,220 @@ def forward( return outputs +class XLMRobertaFlashAttention2(XLMRobertaAttention): + """ + XLMRoberta flash attention module. This module inherits from `XLMRobertaAttention` as the weights of the module + stays untouched. The only required change would be on the forward pass where it needs to correctly call the public + API of flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.self.query(hidden_states) + bsz, q_len, _ = hidden_states.size() + + def reshape(x: torch.Tensor) -> torch.Tensor: + """separate heads""" + return x.view(bsz, -1, self.self.num_attention_heads, self.self.attention_head_size) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.self.transpose_for_scores(self.self.key(encoder_hidden_states)) + value_layer = self.self.transpose_for_scores(self.self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.self.transpose_for_scores(self.self.key(hidden_states)) + value_layer = self.self.transpose_for_scores(self.self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.self.transpose_for_scores(self.self.key(hidden_states)) + value_layer = self.self.transpose_for_scores(self.self.value(hidden_states)) + + query_layer = self.self.transpose_for_scores(mixed_query_layer) + + attn_dropout = self.self.dropout.p if self.training else 0.0 + + if self.self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query_layer.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.self.config, "_pre_quantization_dtype"): + target_dtype = self.self.config._pre_quantization_dtype + else: + target_dtype = self.self.query.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_layer = query_layer.to(target_dtype) + key_layer = key_layer.to(target_dtype) + value_layer = value_layer.to(target_dtype) + + attn_weights = self._flash_attention_forward( + query_layer, key_layer, value_layer, attention_mask, q_len, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_weights.reshape(bsz, q_len, self.self.all_head_size) + attn_output = self.output.dense(attn_weights_reshaped) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + + if self.self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward with causal=True->causal=False + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->self.num_attention_heads + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.self.num_attention_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +XLM_ROBERTA_ATTENTION_CLASSES = { + "eager": XLMRobertaAttention, + "flash_attention_2": XLMRobertaFlashAttention2, +} + + # Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate with Roberta->XLMRoberta class XLMRobertaIntermediate(nn.Module): def __init__(self, config): @@ -383,22 +621,24 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->XLMRoberta class XLMRobertaLayer(nn.Module): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = XLMRobertaAttention(config) + self.attention = XLM_ROBERTA_ATTENTION_CLASSES[config._attn_implementation](config) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = XLMRobertaAttention(config, position_embedding_type="absolute") + self.crossattention = XLM_ROBERTA_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type="absolute" + ) self.intermediate = XLMRobertaIntermediate(config) self.output = XLMRobertaOutput(config) + # Copied from transformers.models.bert.modeling_bert.BertLayer.forward def forward( self, hidden_states: torch.Tensor, @@ -464,6 +704,7 @@ def forward( return outputs + # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) @@ -580,7 +821,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return pooled_output -# Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel with Roberta->XLMRoberta class XLMRobertaPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -590,6 +830,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel): config_class = XLMRobertaConfig base_model_prefix = "roberta" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True _no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention"] # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights @@ -680,7 +921,6 @@ def _init_weights(self, module): "The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", XLM_ROBERTA_START_DOCSTRING, ) -# Copied from transformers.models.roberta.modeling_roberta.RobertaModel with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA class XLMRobertaModel(XLMRobertaPreTrainedModel): """ @@ -697,25 +937,28 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel): """ - # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->XLMRoberta def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config self.embeddings = XLMRobertaEmbeddings(config) self.encoder = 
XLMRobertaEncoder(config) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None # Initialize weights and apply final processing self.post_init() + # Copied from transformers.models.bert.modeling_bert.BertModel.get_input_embeddings def get_input_embeddings(self): return self.embeddings.word_embeddings + # Copied from transformers.models.bert.modeling_bert.BertModel.set_input_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value + # Copied from transformers.models.bert.modeling_bert.BertModel._prune_heads def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base @@ -730,7 +973,6 @@ class PreTrainedModel output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from transformers.models.bert.modeling_bert.BertModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -795,7 +1037,11 @@ def forward( past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): diff --git a/tests/models/xlm_roberta/test_modeling_xlm_roberta.py b/tests/models/xlm_roberta/test_modeling_xlm_roberta.py index ca9db17270dcea..53353034de8eed 100644 --- a/tests/models/xlm_roberta/test_modeling_xlm_roberta.py +++ b/tests/models/xlm_roberta/test_modeling_xlm_roberta.py @@ -12,18 +12,44 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import tempfile +import unittest +import pytest -import unittest +from transformers import XLMRobertaConfig, is_torch_available +from transformers.testing_utils import ( + require_flash_attn, + require_sentencepiece, + require_tokenizers, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) -from transformers import is_torch_available -from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask +from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): import torch - from transformers import XLMRobertaModel + from transformers import ( + XLMRobertaForCausalLM, + XLMRobertaForMaskedLM, + XLMRobertaForMultipleChoice, + XLMRobertaForQuestionAnswering, + XLMRobertaForSequenceClassification, + XLMRobertaForTokenClassification, + XLMRobertaModel, + ) + from transformers.models.xlm_roberta.modeling_xlm_roberta import ( + XLMRobertaEmbeddings, + create_position_ids_from_input_ids, + ) @require_sentencepiece @@ -67,3 +93,574 @@ def test_xlm_roberta_large(self): self.assertEqual(output.shape, expected_output_shape) # compare the actual values for a slice of last dim self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3)) + + +class XLMRobertaModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = 
ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return XLMRobertaConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + ) + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = XLMRobertaModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + result = model(input_ids, token_type_ids=token_type_ids) + result = model(input_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = XLMRobertaModel(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + ) + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = XLMRobertaForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + 
encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + config.add_cross_attention = True + model = XLMRobertaForCausalLM(config=config).to(torch_device).eval() + + # make sure that ids don't start with pad token + mask = input_ids.ne(config.pad_token_id).long() + input_ids = input_ids * mask + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + + # make sure that ids don't start with pad token + mask = next_tokens.ne(config.pad_token_id).long() + next_tokens = next_tokens * mask + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_for_masked_lm( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = XLMRobertaForMaskedLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_for_token_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = XLMRobertaForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + + def create_and_check_for_multiple_choice( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_choices = self.num_choices + model = XLMRobertaForMultipleChoice(config=config) + model.to(torch_device) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_input_mask = 
input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + result = model( + multiple_choice_inputs_ids, + attention_mask=multiple_choice_input_mask, + token_type_ids=multiple_choice_token_type_ids, + labels=choice_labels, + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices)) + + def create_and_check_for_question_answering( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = XLMRobertaForQuestionAnswering(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + start_positions=sequence_labels, + end_positions=sequence_labels, + ) + self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length)) + self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class XLMRobertaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + XLMRobertaForCausalLM, + XLMRobertaForMaskedLM, + XLMRobertaModel, + XLMRobertaForSequenceClassification, + XLMRobertaForTokenClassification, + XLMRobertaForMultipleChoice, + XLMRobertaForQuestionAnswering, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (XLMRobertaForCausalLM,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": XLMRobertaModel, + "fill-mask": XLMRobertaForMaskedLM, + "question-answering": XLMRobertaForQuestionAnswering, + "text-classification": XLMRobertaForSequenceClassification, + "text-generation": XLMRobertaForCausalLM, + "token-classification": XLMRobertaForTokenClassification, + "zero-shot": XLMRobertaForSequenceClassification, + } + if is_torch_available() + else {} + ) + + def setUp(self): + self.model_tester = XLMRobertaModelTester(self) + self.config_tester = ConfigTester(self, config_class=XLMRobertaConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_as_decoder(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + + def test_model_as_decoder_with_default_input_mask(self): + # This regression test was failing with PyTorch < 1.3 + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) = self.model_tester.prepare_config_and_inputs_for_decoder() + + input_mask = None + + self.model_tester.create_and_check_model_as_decoder( + config, + input_ids, + token_type_ids, 
+ input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def test_for_causal_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs_relative_pos_emb(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + config_and_inputs[0].position_embedding_type = "relative_key" + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) + + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + + def test_create_position_ids_respects_padding_index(self): + """Ensure that the default position ids only assign a sequential . This is a regression + test for https://github.com/huggingface/transformers/issues/1761 + + The position ids should be masked with the embedding object's padding index. Therefore, the + first available non-padding position index is XLMRobertaEmbeddings.padding_idx + 1 + """ + config = self.model_tester.prepare_config_and_inputs()[0] + model = XLMRobertaEmbeddings(config=config) + + input_ids = torch.as_tensor([[12, 31, 13, model.padding_idx]]) + expected_positions = torch.as_tensor( + [[0 + model.padding_idx + 1, 1 + model.padding_idx + 1, 2 + model.padding_idx + 1, model.padding_idx]] + ) + + position_ids = create_position_ids_from_input_ids(input_ids, model.padding_idx) + self.assertEqual(position_ids.shape, expected_positions.shape) + self.assertTrue(torch.all(torch.eq(position_ids, expected_positions))) + + def test_create_position_ids_from_inputs_embeds(self): + """Ensure that the default position ids only assign a sequential . This is a regression + test for https://github.com/huggingface/transformers/issues/1761 + + The position ids should be masked with the embedding object's padding index. 
Therefore, the + first available non-padding position index is XLMRobertaEmbeddings.padding_idx + 1 + """ + config = self.model_tester.prepare_config_and_inputs()[0] + embeddings = XLMRobertaEmbeddings(config=config) + + inputs_embeds = torch.empty(2, 4, 30) + expected_single_positions = [ + 0 + embeddings.padding_idx + 1, + 1 + embeddings.padding_idx + 1, + 2 + embeddings.padding_idx + 1, + 3 + embeddings.padding_idx + 1, + ] + expected_positions = torch.as_tensor([expected_single_positions, expected_single_positions]) + position_ids = embeddings.create_position_ids_from_inputs_embeds(inputs_embeds) + self.assertEqual(position_ids.shape, expected_positions.shape) + self.assertTrue(torch.all(torch.eq(position_ids, expected_positions))) + + # Because XLMRobertaForMultipleChoice requires inputs with different shapes we need to override this test. + @require_flash_attn + @require_torch_accelerator + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference(self): + import torch + + for model_class in self.all_model_classes: + dummy_input = torch.LongTensor( + [ + [1, 2, 3, 4], + [1, 2, 8, 9], + [1, 2, 11, 12], + [1, 2, 13, 14], + ] + ).to(torch_device) + dummy_attention_mask = torch.LongTensor( + [ + [0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + ] + ).to(torch_device) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + logits = model(dummy_input, output_hidden_states=True).hidden_states[-1] + logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)) + + output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits_fa = output_fa.hidden_states[-1] + + output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits = output.hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)) + + # Because XLMRobertaForMultipleChoice requires inputs with different shapes we need to override this test. 
+ @require_flash_attn + @require_torch_accelerator + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_padding_right(self): + import torch + + for model_class in self.all_model_classes: + dummy_input = torch.LongTensor( + [ + [1, 2, 3, 4], + [1, 2, 8, 9], + [1, 2, 11, 12], + [1, 2, 13, 14], + ] + ).to(torch_device) + dummy_attention_mask = torch.LongTensor( + [ + [1, 1, 1, 0], + [1, 1, 1, 0], + [1, 1, 1, 0], + [1, 1, 1, 0], + ] + ).to(torch_device) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.bfloat16, + ) + model.to(torch_device) + + logits = model(dummy_input, output_hidden_states=True).hidden_states[-1] + logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)) + + output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits_fa = output_fa.hidden_states[-1] + + output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits = output.hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2))
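For reference, the unpadding bookkeeping that the varlen flash-attention path relies on can be inspected in isolation. The sketch below re-implements the `_get_unpad_data` helper from the diff with plain PyTorch and prints what it returns for a toy right-padded mask; the mask values are made up for illustration.

```python
# Standalone illustration of the unpadding bookkeeping used by the varlen flash-attention path.
# The helper mirrors _get_unpad_data from the diff; only PyTorch is required.
import torch
import torch.nn.functional as F


def get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)  # real tokens per sequence
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # flat positions of real tokens
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # cumulative offsets
    return indices, cu_seqlens, max_seqlen_in_batch


# Toy batch of 2 sequences, right-padded to length 4 (1 = real token, 0 = padding).
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
indices, cu_seqlens, max_seqlen = get_unpad_data(mask)
print(indices)     # tensor([0, 1, 2, 4, 5]) -> positions of real tokens in the flattened batch
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32) -> per-sequence offsets for flash_attn_varlen_func
print(max_seqlen)  # 3
```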