Add sdpa for Beit #34941

Merged
merged 10 commits on Dec 17, 2024
Changes from 7 commits
37 changes: 37 additions & 0 deletions docs/source/en/model_doc/beit.md
@@ -71,6 +71,43 @@ alt="drawing" width="600"/>

<small> BEiT pre-training. Taken from the <a href="https://arxiv.org/abs/2106.08254">original paper.</a> </small>

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```python
import torch
from transformers import BeitForImageClassification

model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.5.1, OS Ubuntu 20.04) with `float16` and the
`microsoft/beit-base-patch16-224` model, we saw the following improvements during training and inference:

#### Training

| num_training_steps | batch_size | image_size | is_cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|--------------------|------------|--------------|---------|----------------------------|---------------------------|-------------|----------------------|--------------------|----------------|
| 50 | 2 | (1048, 640) | True | 0.984 | 0.746 | 31.975 | 6738.915 | 4319.886 | 55.998 |

#### Inference

| Image batch size | Eager (s/iter) | Eager CI, % | Eager memory (bytes) | SDPA (s/iter) | SDPA CI, % | SDPA memory (bytes) | SDPA speedup | SDPA memory saved (%) |
|-------------------:|-----------------:|:--------------|---------------------:|----------------:|:-------------|--------------------:|---------------:|----------------------:|
| 1 | 0.012 | ±0.3% | 3.76657e+08 | 0.011 | ±0.5% | 3.75739e+08 | 1.05 | 0.244 |
| 4 | 0.013 | ±0.1% | 4.03147e+08 | 0.011 | ±0.2% | 3.90554e+08 | 1.178 | 3.225 |
| 16 | 0.045 | ±0.1% | 4.96697e+08 | 0.035 | ±0.1% | 4.51232e+08 | 1.304 | 10.076 |
| 32 | 0.088 | ±0.1% | 6.24417e+08 | 0.066 | ±0.1% | 5.33488e+08 | 1.325 | 17.044 |
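
The benchmarking script itself is not part of this diff. As a rough illustration, the inference numbers above could be approximated along the following lines; the image size, warm-up count, and iteration count are assumptions rather than the exact settings used for the tables.

```python
import torch
from transformers import BeitForImageClassification


def benchmark(attn_implementation: str, batch_size: int, num_iters: int = 50):
    device = "cuda"
    model = (
        BeitForImageClassification.from_pretrained(
            "microsoft/beit-base-patch16-224",
            attn_implementation=attn_implementation,
            torch_dtype=torch.float16,
        )
        .to(device)
        .eval()
    )
    pixel_values = torch.randn(batch_size, 3, 224, 224, dtype=torch.float16, device=device)

    with torch.no_grad():
        # Warm up, then time the forward passes with CUDA events and track peak memory.
        for _ in range(5):
            model(pixel_values)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(num_iters):
            model(pixel_values)
        end.record()
        torch.cuda.synchronize()

    seconds_per_iter = start.elapsed_time(end) / 1000 / num_iters  # elapsed_time is in milliseconds
    peak_mem_bytes = torch.cuda.max_memory_allocated(device)
    return seconds_per_iter, peak_mem_bytes


for impl in ("eager", "sdpa"):
    t, mem = benchmark(impl, batch_size=16)
    print(f"{impl}: {t:.4f} s/iter, {mem / 2**20:.1f} MiB peak")
```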

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with BEiT.
40 changes: 40 additions & 0 deletions docs/source/en/model_doc/data2vec.md
@@ -48,6 +48,46 @@ The original code for vision can be found [here](https://github.com/facebookrese
- For Data2VecText, preprocessing is identical to [`RobertaModel`], including tokenization.
- For Data2VecVision, preprocessing is identical to [`BeitModel`], including feature extraction.

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

The SDPA implementation is currently available for the Data2VecAudio and Data2VecVision models.

```python
import torch
from transformers import Data2VecVisionForImageClassification

model = Data2VecVisionForImageClassification.from_pretrained("facebook/data2vec-vision-base", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
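
The same flag applies to the Data2VecAudio models. A minimal sketch, assuming a standard audio checkpoint (the checkpoint name below is only an example):

```python
import torch
from transformers import Data2VecAudioModel

model = Data2VecAudioModel.from_pretrained(
    "facebook/data2vec-audio-base-960h",  # example checkpoint
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
)
```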

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

For the Data2VecVision model, on a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.5.1, OS Ubuntu 20.04)
with `float16` and the `facebook/data2vec-vision-base` model, we saw the following improvements during training
and inference:

#### Training

| num_training_steps | batch_size | image_size | is_cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|--------------------|------------|--------------|---------|----------------------------|---------------------------|-------------|----------------------|--------------------|----------------|
| 50 | 2 | (1048, 640) | True | 0.996 | 0.754 | 32.147 | 6722.198 | 4264.653 | 57.626 |

#### Inference

| Image batch size | Eager (s/iter) | Eager CI, % | Eager memory (bytes) | SDPA (s/iter) | SDPA CI, % | SDPA memory (bytes) | SDPA speedup | SDPA memory saved (%) |
|-------------------:|-----------------:|:--------------|---------------------:|----------------:|:-------------|--------------------:|---------------:|----------------------:|
| 1 | 0.011 | ±0.3% | 3.76143e+08 | 0.01 | ±0.3% | 3.74397e+08 | 1.101 | 0.466 |
| 4 | 0.014 | ±0.1% | 4.02756e+08 | 0.012 | ±0.2% | 3.91373e+08 | 1.219 | 2.909 |
| 16 | 0.046 | ±0.3% | 4.96482e+08 | 0.035 | ±0.2% | 4.51017e+08 | 1.314 | 10.081 |
| 32 | 0.088 | ±0.1% | 6.23903e+08 | 0.067 | ±0.1% | 5.32974e+08 | 1.33 | 17.061 |
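
A quick way to sanity-check the new backend (not part of this diff) is to confirm that the eager and SDPA paths produce matching outputs. A minimal sketch in float32, comparing the base model's hidden states on random input:

```python
import torch
from transformers import Data2VecVisionModel

pixel_values = torch.randn(1, 3, 224, 224)

hidden_states = {}
for impl in ("eager", "sdpa"):
    model = Data2VecVisionModel.from_pretrained(
        "facebook/data2vec-vision-base", attn_implementation=impl
    ).eval()
    with torch.no_grad():
        hidden_states[impl] = model(pixel_values).last_hidden_state

# Both attention backends should agree up to small floating point differences.
print(torch.allclose(hidden_states["eager"], hidden_states["sdpa"], atol=1e-4))
```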

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Data2Vec.
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -218,6 +218,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel)
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Beit](https://huggingface.co/docs/transformers/model_doc/beit#transformers.BeitModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [BioGpt](https://huggingface.co/docs/transformers/model_doc/biogpt#transformers.BioGptModel)
* [CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert#transformers.CamembertModel)
@@ -226,6 +227,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [GLM](https://huggingface.co/docs/transformers/model_doc/glm#transformers.GLMModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [data2vec_vision](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecVisionModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
71 changes: 70 additions & 1 deletion src/transformers/models/beit/modeling_beit.py
@@ -361,6 +361,68 @@ def forward(
return outputs


class BeitSdpaSelfAttention(BeitSelfAttention):
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`BeitSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not "
"support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, "
"but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
relative_position_bias=relative_position_bias,
interpolate_pos_encoding=interpolate_pos_encoding,
resolution=resolution,
)

mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)

attn_bias = None
if self.relative_position_bias is not None:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
attn_bias = self.relative_position_bias(
window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
)

# Add shared relative position bias if provided.
if relative_position_bias is not None:
if attn_bias is None:
attn_bias = relative_position_bias
else:
attn_bias += relative_position_bias

scaling = 1 / math.sqrt(self.attention_head_size)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attn_bias,
dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=scaling,
)
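# Reshape from (batch_size, num_heads, seq_len, head_dim) back to (batch_size, seq_len, all_head_size).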
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer, None


class BeitSelfOutput(nn.Module):
"""
The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the
@@ -379,10 +441,16 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma
return hidden_states


BEIT_SELF_ATTENTION_CLASSES = {
"eager": BeitSelfAttention,
"sdpa": BeitSdpaSelfAttention,
}


class BeitAttention(nn.Module):
def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
super().__init__()
self.attention = BeitSelfAttention(config, window_size=window_size)
self.attention = BEIT_SELF_ATTENTION_CLASSES[config._attn_implementation](config, window_size=window_size)
self.output = BeitSelfOutput(config)
self.pruned_heads = set()

@@ -700,6 +768,7 @@ class BeitPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["BeitLayer"]
_keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
_supports_sdpa = True

def _init_weights(self, module):
"""Initialize the weights"""
76 changes: 74 additions & 2 deletions src/transformers/models/data2vec/modeling_data2vec_vision.py
@@ -362,6 +362,69 @@ def forward(
return outputs


# Copied from transformers.models.beit.modeling_beit.BeitSdpaSelfAttention with Beit->Data2VecVision
class Data2VecVisionSdpaSelfAttention(Data2VecVisionSelfAttention):
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`Data2VecVisionSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not "
"support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, "
"but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
relative_position_bias=relative_position_bias,
interpolate_pos_encoding=interpolate_pos_encoding,
resolution=resolution,
)

mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)

attn_bias = None
if self.relative_position_bias is not None:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
attn_bias = self.relative_position_bias(
window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
)

# Add shared relative position bias if provided.
if relative_position_bias is not None:
if attn_bias is None:
attn_bias = relative_position_bias
else:
attn_bias += relative_position_bias

scaling = 1 / math.sqrt(self.attention_head_size)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attn_bias,
dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=scaling,
)
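# Reshape from (batch_size, num_heads, seq_len, head_dim) back to (batch_size, seq_len, all_head_size).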
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer, None


# Copied from transformers.models.beit.modeling_beit.BeitSelfOutput with Beit->Data2VecVision
class Data2VecVisionSelfOutput(nn.Module):
"""
@@ -381,11 +444,19 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma
return hidden_states


# Copied from transformers.models.beit.modeling_beit.BeitAttention with Beit->Data2VecVision
DATA2VEC_VISION_SELF_ATTENTION_CLASSES = {
"eager": Data2VecVisionSelfAttention,
"sdpa": Data2VecVisionSdpaSelfAttention,
}


# Copied from transformers.models.beit.modeling_beit.BeitAttention with Beit->Data2VecVision, BEIT->DATA2VEC_VISION
class Data2VecVisionAttention(nn.Module):
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
super().__init__()
self.attention = Data2VecVisionSelfAttention(config, window_size=window_size)
self.attention = DATA2VEC_VISION_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, window_size=window_size
)
self.output = Data2VecVisionSelfOutput(config)
self.pruned_heads = set()

@@ -711,6 +782,7 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Data2VecVisionLayer"]
_keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
_supports_sdpa = True

def _init_weights(self, module):
"""Initialize the weights"""