Added SDPA support for RoBERTa #32391

Open · wants to merge 2 commits into base: main
20 changes: 20 additions & 0 deletions docs/source/en/model_doc/roberta.md
@@ -65,6 +65,26 @@ This model was contributed by [julien-c](https://huggingface.co/julien-c). The o
* Byte-level BPE vocabulary: Uses BPE with bytes as a subunit instead of characters, accommodating Unicode characters.
- [CamemBERT](camembert) is a wrapper around RoBERTa. Refer to its model page for usage examples.

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
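
To see what the operator does in isolation, here is a minimal sketch that calls it directly on random tensors shaped `(batch, num_heads, seq_len, head_dim)`; the shapes and dtype below are illustrative assumptions, not values taken from RoBERTa.

```py
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
query = torch.randn(2, 12, 16, 64)
key = torch.randn(2, 12, 16, 64)
value = torch.randn(2, 12, 16, 64)

# PyTorch dispatches to an available backend (FlashAttention, memory-efficient, or math)
# based on the inputs and the hardware.
out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 12, 16, 64])
```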

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```py
import torch
from transformers import RobertaModel

model = RobertaModel.from_pretrained("roberta-base", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
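
As a rough end-to-end sketch (the checkpoint, example sentence, and device handling here are illustrative assumptions), a half-precision inference pass with SDPA looks like this:

```py
import torch
from transformers import AutoTokenizer, RobertaModel

# Use float16 on GPU; fall back to float32 on CPU, where half precision is often unsupported.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = RobertaModel.from_pretrained(
    "roberta-base", torch_dtype=dtype, attn_implementation="sdpa"
).to(device)

inputs = tokenizer("Hello, world!", return_tensors="pt").to(device)
with torch.inference_mode():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch_size, seq_len, hidden_size)
```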

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with RoBERTa. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -241,6 +241,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
* [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel)


<Tip>
91 changes: 91 additions & 0 deletions src/transformers/models/roberta/modeling_roberta.py
@@ -20,6 +20,7 @@

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -40,9 +41,11 @@
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    get_torch_version,
    logging,
    replace_return_docstrings,
)

from .configuration_roberta import RobertaConfig


@@ -275,6 +278,92 @@ def forward(
        outputs = outputs + (past_key_value,)
        return outputs

class RobertaSdpaSelfAttention(RobertaSelfAttention):
    """RoBERTa self-attention that dispatches to `torch.nn.functional.scaled_dot_product_attention` and falls
    back to the eager implementation whenever SDPA cannot handle the requested options."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type=position_embedding_type)
        self.dropout_prob = config.attention_probs_dropout_prob
        # SDPA's memory-efficient backend can give wrong results in torch<2.2.0 when fed non-contiguous inputs
        # together with a custom attn_mask, so q/k/v must be made contiguous on those versions (see below).
        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
            logger.warning_once(
                "RobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
                "the manual attention implementation, but specifying the manual implementation will be required from "
                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
            )

        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse the cached cross-attention key/value states
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # Cache the current key/value states so that subsequent decoding steps can reuse them.
            past_key_value = (key_layer, value_layer)

        bsz, tgt_len, _ = hidden_states.size()

        # Workaround for torch<2.2.0: make q/k/v contiguous so the memory-efficient SDPA backend is not fed
        # non-contiguous inputs together with a custom attn_mask.
        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
            query_layer = query_layer.contiguous()
            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()

        # Only let SDPA build its own causal mask when no explicit attention mask is given. The `tgt_len > 1`
        # check mirrors AttentionMaskConverter.to_causal_4d, which does not create a causal mask for a single
        # query token.
        is_causal = (
            True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
        )

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

        outputs = (attn_output,)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class RobertaSelfOutput(nn.Module):
@@ -293,6 +382,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to

ROBERTA_SELF_ATTENTION_CLASSES = {
    "eager": RobertaSelfAttention,
    "sdpa": RobertaSdpaSelfAttention,
}


@@ -586,6 +676,7 @@ class RobertaPreTrainedModel(PreTrainedModel):
    base_model_prefix = "roberta"
    supports_gradient_checkpointing = True
    _no_split_modules = ["RobertaEmbeddings", "RobertaSelfAttention"]
    _supports_sdpa = True

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):