diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 82bafc6b429561..dd3433f2cd4862 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -239,6 +239,7 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
 * [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
 * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
+* [mBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
 * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
 * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
 * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 6cad7b08f994a4..9455f21b2073ff 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -24,7 +24,12 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask,
+    _prepare_4d_attention_mask_for_sdpa,
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -405,8 +410,116 @@ def forward(
         return attn_output, attn_weights, past_key_value
 
 
+# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MBart
+class MBartSdpaAttention(MBartAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+        if output_attentions or layer_head_mask is not None:
+            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "MBartModel is using MBartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states)
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        query_states = self._shape(query_states, tgt_len, bsz)
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
+        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
+        # but we are fine here as `_shape` does call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
 MBART_ATTENTION_CLASSES = {
     "eager": MBartAttention,
+    "sdpa": MBartSdpaAttention,
     "flash_attention_2": MBartFlashAttention2,
 }
 
@@ -632,6 +745,7 @@ class MBartPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
     _supports_flash_attn_2 = True
+    _supports_sdpa = True
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -841,7 +955,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N
             embed_dim,
         )
         self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self.config = config
         self.layernorm_embedding = nn.LayerNorm(embed_dim)
         self.layer_norm = nn.LayerNorm(config.d_model)
 
@@ -929,9 +1043,13 @@ def forward(
 
         # expand attention_mask
         if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            if self._use_flash_attention_2:
+            if self.config._attn_implementation == "flash_attention_2":
                 attention_mask = attention_mask if 0 in attention_mask else None
+            elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions:
+                # output_attentions=True & head_mask cannot be supported when using SDPA, fall back to
+                # the manual implementation that requires a 4D attention mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
             else:
                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                 attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
@@ -1021,7 +1139,8 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N
             config.d_model,
         )
         self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self.config = config
+
         self.layernorm_embedding = nn.LayerNorm(config.d_model)
         self.layer_norm = nn.LayerNorm(config.d_model)
 
@@ -1141,9 +1260,18 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if self._use_flash_attention_2:
+        if self.config._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self.config._attn_implementation == "sdpa" and not output_attentions and cross_attn_head_mask is None:
+            # output_attentions=True & cross_attn_head_mask cannot be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                input_shape,
+                inputs_embeds,
+                past_key_values_length,
+            )
         else:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
@@ -1152,8 +1280,17 @@
 
         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
-            if self._use_flash_attention_2:
+            if self.config._attn_implementation == "flash_attention_2":
                 encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+            elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
+                # output_attentions=True & cross_attn_head_mask cannot be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D attention mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask,
+                    inputs_embeds.dtype,
+                    tgt_len=input_shape[-1],
+                )
             else:
                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                 encoder_attention_mask = _prepare_4d_attention_mask(
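With this change, loading an MBart checkpoint with `attn_implementation="sdpa"` dispatches attention through the new `MBartSdpaAttention` class. The snippet below is a minimal usage sketch and is not part of the diff: the checkpoint name, language codes, and CUDA device are illustrative assumptions, and SDPA also runs on CPU and in full precision.

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Illustrative checkpoint; any MBart checkpoint should behave the same way.
model_id = "facebook/mbart-large-50-many-to-many-mmt"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.src_lang = "en_XX"  # source language code used by this checkpoint

# attn_implementation="sdpa" routes attention through MBartSdpaAttention added above.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).to("cuda")  # assumes a CUDA device is available

inputs = tokenizer("UN Chief says there is no military solution in Syria", return_tensors="pt").to("cuda")
generated = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.lang_code_to_id["ro_RO"],  # illustrative target language
    max_new_tokens=32,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```

Note that passing `output_attentions=True` or a head mask falls back to the eager attention path, as the warning in `MBartSdpaAttention` and the mask-preparation branches above indicate.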