diff --git a/docs/source/en/model_doc/xlm-roberta.md b/docs/source/en/model_doc/xlm-roberta.md
index 58540015232e9d..0e08d3db3c38ac 100644
--- a/docs/source/en/model_doc/xlm-roberta.md
+++ b/docs/source/en/model_doc/xlm-roberta.md
@@ -55,6 +55,14 @@ This model was contributed by [stefan-it](https://huggingface.co/stefan-it). The
language from the input ids.
- Uses RoBERTa tricks on the XLM approach, but does not use the translation language modeling objective. It only uses masked language modeling on sentences coming from one language.
+### Expected speedups
+
+Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers, using the `FacebookAI/xlm-roberta-base` checkpoint, and the Flash Attention 2 version of the model.
+
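+If you want to reproduce such a comparison on your own hardware, a rough sketch like the one below can be used. It is only illustrative: the prompt, batch size, iteration count and the `FacebookAI/xlm-roberta-base` checkpoint are assumptions, and the exact numbers depend on your GPU and flash-attn version.
+
+```python
+import time
+
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+checkpoint = "FacebookAI/xlm-roberta-base"
+tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+
+# a small batch of dummy sentences, padded to the same length
+texts = ["Hello world, this is a sentence used for benchmarking."] * 8
+inputs = tokenizer(texts, padding=True, return_tensors="pt").to("cuda")
+
+
+def benchmark(attn_implementation: str, n_iters: int = 50) -> float:
+    """Return the average forward-pass latency in seconds for the given attention implementation."""
+    model = AutoModel.from_pretrained(
+        checkpoint, torch_dtype=torch.float16, attn_implementation=attn_implementation
+    ).to("cuda")
+    model.eval()
+    with torch.no_grad():
+        for _ in range(5):  # warmup
+            model(**inputs)
+        torch.cuda.synchronize()
+        start = time.perf_counter()
+        for _ in range(n_iters):
+            model(**inputs)
+        torch.cuda.synchronize()
+    return (time.perf_counter() - start) / n_iters
+
+
+eager = benchmark("eager")
+flash = benchmark("flash_attention_2")
+print(f"eager: {eager * 1e3:.1f} ms | flash_attention_2: {flash * 1e3:.1f} ms | speedup: {eager / flash:.2f}x")
+```
+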
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with XLM-RoBERTa. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
@@ -113,6 +121,35 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
This implementation is the same as RoBERTa. Refer to the [documentation of RoBERTa](roberta) for usage examples as well as the information relative to the inputs and outputs.
+## Combining XLM-RoBERTa and Flash Attention 2
+
+First, make sure to install the latest version of Flash Attention 2:
+
+```bash
+pip install -U flash-attn --no-build-isolation
+```
+
+Also make sure that your hardware is compatible with Flash Attention 2; you can find more details in the official documentation of the [flash-attn](https://github.com/Dao-AILab/flash-attention) repository. Finally, make sure to load your model in half-precision (e.g. `torch.float16`).
+
+To load and run a model using Flash Attention 2, refer to the snippet below:
+
+```python
+>>> import torch
+>>> from transformers import AutoTokenizer, AutoModel
+
+>>> device = "cuda" # the device to load the model onto
+
+>>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
+>>> model = AutoModel.from_pretrained("FacebookAI/xlm-roberta-base", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
+
+>>> text = "Replace me by any text you'd like."
+
+>>> encoded_input = tokenizer(text, return_tensors='pt').to(device)
+>>> model.to(device)
+
+>>> output = model(**encoded_input)
+```
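+
+Flash Attention 2 also works with batched, padded inputs: the implementation unpads the sequences internally before calling the flash attention kernels and pads the outputs back afterwards. As a purely illustrative sketch (the sentences are made up), batching looks exactly like it does with the eager implementation:
+
+```python
+>>> sentences = ["Hello world!", "Bonjour le monde, ceci est une phrase un peu plus longue."]
+
+>>> encoded_batch = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
+>>> with torch.no_grad():
+...     batch_output = model(**encoded_batch)
+```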
+
## XLMRobertaConfig
[[autodoc]] XLMRobertaConfig
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 899e5b52f002ce..4cd1f2b23aea1c 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -54,6 +54,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
+* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
index 95ea2e7dca7bd1..d81ab6ee53575a 100644
--- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -19,6 +19,7 @@
from typing import List, Optional, Tuple, Union
import torch
+import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -40,12 +41,18 @@
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_xlm_roberta import XLMRobertaConfig
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "xlm-roberta-base"
@@ -62,6 +69,19 @@
]
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->XLMRoberta
class XLMRobertaEmbeddings(nn.Module):
"""
@@ -152,7 +172,6 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
return position_ids.unsqueeze(0).expand(input_shape)
-# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->XLMRoberta
class XLMRobertaSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
@@ -165,6 +184,7 @@ def __init__(self, config, position_embedding_type=None):
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.config = config
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
@@ -180,11 +200,13 @@ def __init__(self, config, position_embedding_type=None):
self.is_decoder = config.is_decoder
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.forward with Bert->XLMRoberta
def forward(
self,
hidden_states: torch.Tensor,
@@ -302,14 +324,15 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
return hidden_states
-# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->XLMRoberta
class XLMRobertaAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = XLMRobertaSelfAttention(config, position_embedding_type=position_embedding_type)
self.output = XLMRobertaSelfOutput(config)
self.pruned_heads = set()
+ self.is_causal = False
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
def prune_heads(self, heads):
if len(heads) == 0:
return
@@ -328,6 +351,7 @@ def prune_heads(self, heads):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -352,6 +376,220 @@ def forward(
return outputs
+class XLMRobertaFlashAttention2(XLMRobertaAttention):
+ """
+    XLMRoberta flash attention module. This module inherits from `XLMRobertaAttention` as the weights of the module
+    stay untouched. The only required change is in the forward pass, where it needs to correctly call the public
+    API of flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ mixed_query_layer = self.self.query(hidden_states)
+ bsz, q_len, _ = hidden_states.size()
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_layer = past_key_value[0]
+ value_layer = past_key_value[1]
+ attention_mask = encoder_attention_mask
+ elif is_cross_attention:
+ key_layer = self.self.transpose_for_scores(self.self.key(encoder_hidden_states))
+ value_layer = self.self.transpose_for_scores(self.self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.self.transpose_for_scores(self.self.key(hidden_states))
+ value_layer = self.self.transpose_for_scores(self.self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.self.transpose_for_scores(self.self.key(hidden_states))
+ value_layer = self.self.transpose_for_scores(self.self.value(hidden_states))
+
+ query_layer = self.self.transpose_for_scores(mixed_query_layer)
+
+ attn_dropout = self.self.dropout.p if self.training else 0.0
+
+ if self.self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_layer, value_layer)
+
+        # In PEFT, we usually cast the layer norms to float32 for training stability reasons,
+        # therefore the input hidden states get silently cast to float32. Hence, we cast them
+        # back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended to not cast the
+        # LayerNorms in fp32.
+
+ if query_layer.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.self.config, "_pre_quantization_dtype"):
+ target_dtype = self.self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.self.query.weight.dtype
+
+            logger.warning_once(
+                "The input hidden states seem to be silently cast to float32; this might be related to the fact"
+                " that you have upcasted embedding or layer norm layers in float32. We will cast back the input"
+                f" to {target_dtype}."
+            )
+
+ query_layer = query_layer.to(target_dtype)
+ key_layer = key_layer.to(target_dtype)
+ value_layer = value_layer.to(target_dtype)
+
+        # Flash Attention expects inputs of shape (batch, seq_len, num_heads, head_dim), so undo the
+        # (batch, num_heads, seq_len, head_dim) layout produced by `transpose_for_scores`.
+        query_layer = query_layer.transpose(1, 2)
+        key_layer = key_layer.transpose(1, 2)
+        value_layer = value_layer.transpose(1, 2)
+
+        attn_weights = self._flash_attention_forward(
+            query_layer, key_layer, value_layer, attention_mask, q_len, dropout=attn_dropout
+        )
+
+        attn_weights_reshaped = attn_weights.reshape(bsz, q_len, self.self.all_head_size)
+        attn_output = self.output(attn_weights_reshaped, hidden_states)
+
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
+
+ if self.self.is_decoder:
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+ def _flash_attention_forward(
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+ ):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`int`, *optional*):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+ )
+
+ return attn_output
+
+    # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.self.num_attention_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+XLM_ROBERTA_ATTENTION_CLASSES = {
+ "eager": XLMRobertaAttention,
+ "flash_attention_2": XLMRobertaFlashAttention2,
+}
+
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate with Roberta->XLMRoberta
class XLMRobertaIntermediate(nn.Module):
def __init__(self, config):
@@ -383,22 +621,24 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
return hidden_states
-# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->XLMRoberta
class XLMRobertaLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
- self.attention = XLMRobertaAttention(config)
+ self.attention = XLM_ROBERTA_ATTENTION_CLASSES[config._attn_implementation](config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
- self.crossattention = XLMRobertaAttention(config, position_embedding_type="absolute")
+ self.crossattention = XLM_ROBERTA_ATTENTION_CLASSES[config._attn_implementation](
+ config, position_embedding_type="absolute"
+ )
self.intermediate = XLMRobertaIntermediate(config)
self.output = XLMRobertaOutput(config)
+ # Copied from transformers.models.bert.modeling_bert.BertLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -464,6 +704,7 @@ def forward(
return outputs
+ # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
@@ -580,7 +821,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return pooled_output
-# Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel with Roberta->XLMRoberta
class XLMRobertaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -590,6 +830,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
config_class = XLMRobertaConfig
base_model_prefix = "roberta"
supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
_no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention"]
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
@@ -680,7 +921,6 @@ def _init_weights(self, module):
"The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
XLM_ROBERTA_START_DOCSTRING,
)
-# Copied from transformers.models.roberta.modeling_roberta.RobertaModel with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
class XLMRobertaModel(XLMRobertaPreTrainedModel):
"""
@@ -697,25 +937,28 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
"""
- # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->XLMRoberta
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.embeddings = XLMRobertaEmbeddings(config)
self.encoder = XLMRobertaEncoder(config)
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
# Initialize weights and apply final processing
self.post_init()
+ # Copied from transformers.models.bert.modeling_bert.BertModel.get_input_embeddings
def get_input_embeddings(self):
return self.embeddings.word_embeddings
+ # Copied from transformers.models.bert.modeling_bert.BertModel.set_input_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
+ # Copied from transformers.models.bert.modeling_bert.BertModel._prune_heads
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
@@ -730,7 +973,6 @@ class PreTrainedModel
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
- # Copied from transformers.models.bert.modeling_bert.BertModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
@@ -795,7 +1037,11 @@ def forward(
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
- attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+ if self._use_flash_attention_2:
+                # Flash Attention 2 consumes the 2d mask directly; when no mask is provided we can simply keep it as `None`
+                attention_mask = None
+ else:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
diff --git a/tests/models/xlm_roberta/test_modeling_xlm_roberta.py b/tests/models/xlm_roberta/test_modeling_xlm_roberta.py
index ca9db17270dcea..53353034de8eed 100644
--- a/tests/models/xlm_roberta/test_modeling_xlm_roberta.py
+++ b/tests/models/xlm_roberta/test_modeling_xlm_roberta.py
@@ -12,18 +12,44 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import tempfile
+import unittest
+import pytest
-import unittest
+from transformers import XLMRobertaConfig, is_torch_available
+from transformers.testing_utils import (
+ require_flash_attn,
+ require_sentencepiece,
+ require_tokenizers,
+ require_torch,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
-from transformers import is_torch_available
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
- from transformers import XLMRobertaModel
+ from transformers import (
+ XLMRobertaForCausalLM,
+ XLMRobertaForMaskedLM,
+ XLMRobertaForMultipleChoice,
+ XLMRobertaForQuestionAnswering,
+ XLMRobertaForSequenceClassification,
+ XLMRobertaForTokenClassification,
+ XLMRobertaModel,
+ )
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import (
+ XLMRobertaEmbeddings,
+ create_position_ids_from_input_ids,
+ )
@require_sentencepiece
@@ -67,3 +93,574 @@ def test_xlm_roberta_large(self):
self.assertEqual(output.shape, expected_output_shape)
# compare the actual values for a slice of last dim
self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))
+
+
+class XLMRobertaModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_token_type_ids=True,
+ use_labels=True,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.scope = scope
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ token_type_ids = None
+ if self.use_token_type_ids:
+ token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+ sequence_labels = None
+ token_labels = None
+ choice_labels = None
+ if self.use_labels:
+ sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+ token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+ choice_labels = ids_tensor([self.batch_size], self.num_choices)
+
+ config = self.get_config()
+
+ return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+ def get_config(self):
+ return XLMRobertaConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ initializer_range=self.initializer_range,
+ )
+
+ def prepare_config_and_inputs_for_decoder(self):
+ (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = self.prepare_config_and_inputs()
+
+ config.is_decoder = True
+ encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
+ encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+
+ return (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+
+ def create_and_check_model(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ model = XLMRobertaModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
+ result = model(input_ids, token_type_ids=token_type_ids)
+ result = model(input_ids)
+
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def create_and_check_model_as_decoder(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ ):
+ config.add_cross_attention = True
+ model = XLMRobertaModel(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def create_and_check_for_causal_lm(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ ):
+ model = XLMRobertaForCausalLM(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+ def create_and_check_decoder_model_past_large_inputs(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ ):
+ config.is_decoder = True
+ config.add_cross_attention = True
+ model = XLMRobertaForCausalLM(config=config).to(torch_device).eval()
+
+ # make sure that ids don't start with pad token
+ mask = input_ids.ne(config.pad_token_id).long()
+ input_ids = input_ids * mask
+
+ # first forward pass
+ outputs = model(
+ input_ids,
+ attention_mask=input_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=True,
+ )
+ past_key_values = outputs.past_key_values
+
+        # create hypothetical multiple next tokens and extend to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+
+ # make sure that ids don't start with pad token
+ mask = next_tokens.ne(config.pad_token_id).long()
+ next_tokens = next_tokens * mask
+ next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+        # append to next input_ids and attention_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
+
+ output_from_no_past = model(
+ next_input_ids,
+ attention_mask=next_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_hidden_states=True,
+ )["hidden_states"][0]
+ output_from_past = model(
+ next_tokens,
+ attention_mask=next_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ )["hidden_states"][0]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_for_masked_lm(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ model = XLMRobertaForMaskedLM(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+ def create_and_check_for_token_classification(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ config.num_labels = self.num_labels
+ model = XLMRobertaForTokenClassification(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
+
+ def create_and_check_for_multiple_choice(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ config.num_choices = self.num_choices
+ model = XLMRobertaForMultipleChoice(config=config)
+ model.to(torch_device)
+ model.eval()
+ multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ result = model(
+ multiple_choice_inputs_ids,
+ attention_mask=multiple_choice_input_mask,
+ token_type_ids=multiple_choice_token_type_ids,
+ labels=choice_labels,
+ )
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
+
+ def create_and_check_for_question_answering(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ model = XLMRobertaForQuestionAnswering(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ start_positions=sequence_labels,
+ end_positions=sequence_labels,
+ )
+ self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+ self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class XLMRobertaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ all_model_classes = (
+ (
+ XLMRobertaForCausalLM,
+ XLMRobertaForMaskedLM,
+ XLMRobertaModel,
+ XLMRobertaForSequenceClassification,
+ XLMRobertaForTokenClassification,
+ XLMRobertaForMultipleChoice,
+ XLMRobertaForQuestionAnswering,
+ )
+ if is_torch_available()
+ else ()
+ )
+ all_generative_model_classes = (XLMRobertaForCausalLM,) if is_torch_available() else ()
+ pipeline_model_mapping = (
+ {
+ "feature-extraction": XLMRobertaModel,
+ "fill-mask": XLMRobertaForMaskedLM,
+ "question-answering": XLMRobertaForQuestionAnswering,
+ "text-classification": XLMRobertaForSequenceClassification,
+ "text-generation": XLMRobertaForCausalLM,
+ "token-classification": XLMRobertaForTokenClassification,
+ "zero-shot": XLMRobertaForSequenceClassification,
+ }
+ if is_torch_available()
+ else {}
+ )
+
+ def setUp(self):
+ self.model_tester = XLMRobertaModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=XLMRobertaConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_various_embeddings(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ for type in ["absolute", "relative_key", "relative_key_query"]:
+ config_and_inputs[0].position_embedding_type = type
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_as_decoder(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+ self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
+
+ def test_model_as_decoder_with_default_input_mask(self):
+ # This regression test was failing with PyTorch < 1.3
+ (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ ) = self.model_tester.prepare_config_and_inputs_for_decoder()
+
+ input_mask = None
+
+ self.model_tester.create_and_check_model_as_decoder(
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+
+ def test_for_causal_lm(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+ self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
+
+ def test_decoder_model_past_with_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+ self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+ config_and_inputs[0].position_embedding_type = "relative_key"
+ self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_for_masked_lm(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
+
+ def test_for_token_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
+
+ def test_for_multiple_choice(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
+
+ def test_for_question_answering(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
+
+ def test_create_position_ids_respects_padding_index(self):
+ """Ensure that the default position ids only assign a sequential . This is a regression
+ test for https://github.com/huggingface/transformers/issues/1761
+
+ The position ids should be masked with the embedding object's padding index. Therefore, the
+ first available non-padding position index is XLMRobertaEmbeddings.padding_idx + 1
+ """
+ config = self.model_tester.prepare_config_and_inputs()[0]
+ model = XLMRobertaEmbeddings(config=config)
+
+ input_ids = torch.as_tensor([[12, 31, 13, model.padding_idx]])
+ expected_positions = torch.as_tensor(
+ [[0 + model.padding_idx + 1, 1 + model.padding_idx + 1, 2 + model.padding_idx + 1, model.padding_idx]]
+ )
+
+ position_ids = create_position_ids_from_input_ids(input_ids, model.padding_idx)
+ self.assertEqual(position_ids.shape, expected_positions.shape)
+ self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
+
+ def test_create_position_ids_from_inputs_embeds(self):
+ """Ensure that the default position ids only assign a sequential . This is a regression
+ test for https://github.com/huggingface/transformers/issues/1761
+
+ The position ids should be masked with the embedding object's padding index. Therefore, the
+ first available non-padding position index is XLMRobertaEmbeddings.padding_idx + 1
+ """
+ config = self.model_tester.prepare_config_and_inputs()[0]
+ embeddings = XLMRobertaEmbeddings(config=config)
+
+ inputs_embeds = torch.empty(2, 4, 30)
+ expected_single_positions = [
+ 0 + embeddings.padding_idx + 1,
+ 1 + embeddings.padding_idx + 1,
+ 2 + embeddings.padding_idx + 1,
+ 3 + embeddings.padding_idx + 1,
+ ]
+ expected_positions = torch.as_tensor([expected_single_positions, expected_single_positions])
+ position_ids = embeddings.create_position_ids_from_inputs_embeds(inputs_embeds)
+ self.assertEqual(position_ids.shape, expected_positions.shape)
+ self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
+
+ # Because XLMRobertaForMultipleChoice requires inputs with different shapes we need to override this test.
+ @require_flash_attn
+ @require_torch_accelerator
+ @pytest.mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_inference(self):
+ import torch
+
+ for model_class in self.all_model_classes:
+ dummy_input = torch.LongTensor(
+ [
+ [1, 2, 3, 4],
+ [1, 2, 8, 9],
+ [1, 2, 11, 12],
+ [1, 2, 13, 14],
+ ]
+ ).to(torch_device)
+ dummy_attention_mask = torch.LongTensor(
+ [
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ ]
+ ).to(torch_device)
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
+ logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
+
+ self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
+
+ output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+ logits_fa = output_fa.hidden_states[-1]
+
+ output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+ logits = output.hidden_states[-1]
+
+ self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
+
+ # Because XLMRobertaForMultipleChoice requires inputs with different shapes we need to override this test.
+ @require_flash_attn
+ @require_torch_accelerator
+ @pytest.mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_inference_padding_right(self):
+ import torch
+
+ for model_class in self.all_model_classes:
+ dummy_input = torch.LongTensor(
+ [
+ [1, 2, 3, 4],
+ [1, 2, 8, 9],
+ [1, 2, 11, 12],
+ [1, 2, 13, 14],
+ ]
+ ).to(torch_device)
+ dummy_attention_mask = torch.LongTensor(
+ [
+                    [1, 1, 1, 0],
+                    [1, 1, 1, 0],
+                    [1, 1, 1, 0],
+                    [1, 1, 1, 0],
+ ]
+ ).to(torch_device)
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.bfloat16,
+ )
+ model.to(torch_device)
+
+ logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
+ logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
+
+ self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
+
+ output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+ logits_fa = output_fa.hidden_states[-1]
+
+ output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+ logits = output.hidden_states[-1]
+
+ self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2))