diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py
index 1449304cb93451..e19488bbb47447 100644
--- a/src/transformers/models/camembert/modeling_camembert.py
+++ b/src/transformers/models/camembert/modeling_camembert.py
@@ -24,10 +24,6 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN, gelu
-from ...modeling_attn_mask_utils import (
-    _prepare_4d_attention_mask_for_sdpa,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -795,7 +791,6 @@ class CamembertModel(CamembertPreTrainedModel):
         output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -867,43 +862,12 @@ def forward(
             else:
                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
-        embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            past_key_values_length=past_key_values_length,
-        )
-
         if attention_mask is None:
             attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
 
-        use_sdpa_attention_masks = (
-            self._use_sdpa
-            and self.position_embedding_type == "absolute"
-            and head_mask is None
-            and not output_attentions
-        )
-
-        # Expand the attention mask
-        if use_sdpa_attention_masks:
-            # Expand the attention mask for SDPA.
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            if self.config.is_decoder:
-                extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask,
-                    input_shape,
-                    embedding_output,
-                    past_key_values_length,
-                )
-            else:
-                extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                    attention_mask, embedding_output.dtype, tgt_len=seq_length
-                )
-        else:
-            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
-            # ourselves in which case we just need to make it broadcastable to all heads.
-            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -913,14 +877,7 @@ def forward(
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
 
-            if use_sdpa_attention_masks:
-                # Expand the attention mask for SDPA.
-                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
-                )
-            else:
-                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
 
@@ -931,6 +888,13 @@ def forward(
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask=extended_attention_mask,
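
Note on the replacement path: with the SDPA-specific branches gone, the self-attention and cross-attention masks go through get_extended_attention_mask / invert_attention_mask, which turn a 2D padding mask into an additive 4D bias that is broadcastable over all heads. The sketch below is a minimal, hypothetical illustration of that expansion for the non-decoder case; it is not the library implementation, and the helper name expand_attention_mask is invented for this example.

import torch

def expand_attention_mask(attention_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # [bsz, src_len] -> [bsz, 1, 1, src_len], broadcastable over heads and query positions.
    extended = attention_mask[:, None, None, :].to(dtype)
    # Attended positions become 0.0; masked positions become a large negative additive bias.
    return (1.0 - extended) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0]])        # last token is padding
print(expand_attention_mask(mask).shape)   # torch.Size([1, 1, 1, 4])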