Skip to content

Commit

Permalink
add sdpa mbart (#32033)
Browse files Browse the repository at this point in the history
* add sdpa mbart

useful for donut

* update sdpa docs

* formatting

* add self._use_sdpa in mbartencoder

* use self.config to check attn

* retrigger checks

* [run-slow] mbart
  • Loading branch information
nbroad1881 authored Sep 7, 2024
1 parent a70286f commit 66bc4de
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [mBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
Expand Down
151 changes: 144 additions & 7 deletions src/transformers/models/mbart/modeling_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Expand Down Expand Up @@ -405,8 +410,116 @@ def forward(
return attn_output, attn_weights, past_key_value


# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MBart
class MBartSdpaAttention(MBartAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"MBartModel is using MBartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None

bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

query_states = self._shape(query_states, tgt_len, bsz)

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False

# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)

if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2)

# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

return attn_output, None, past_key_value


MBART_ATTENTION_CLASSES = {
"eager": MBartAttention,
"sdpa": MBartSdpaAttention,
"flash_attention_2": MBartFlashAttention2,
}

Expand Down Expand Up @@ -632,6 +745,7 @@ class MBartPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
_supports_flash_attn_2 = True
_supports_sdpa = True

def _init_weights(self, module):
std = self.config.init_std
Expand Down Expand Up @@ -841,7 +955,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N
embed_dim,
)
self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.config = config
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.layer_norm = nn.LayerNorm(config.d_model)

Expand Down Expand Up @@ -929,9 +1043,13 @@ def forward(

# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
if self._use_flash_attention_2:
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if 0 in attention_mask else None
elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
Expand Down Expand Up @@ -1021,7 +1139,8 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N
config.d_model,
)
self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.config = config

self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.layer_norm = nn.LayerNorm(config.d_model)

Expand Down Expand Up @@ -1141,9 +1260,18 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if self._use_flash_attention_2:
if self.config._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self.config._attn_implementation == "sdpa" and not output_attentions and cross_attn_head_mask is None:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
Expand All @@ -1152,8 +1280,17 @@ def forward(

# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self._use_flash_attention_2:
if self.config._attn_implementation == "flash_attention_2":
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
Expand Down

0 comments on commit 66bc4de

Please sign in to comment.