Corrections to XLM_Roberta flash attention
DavidAfonsoValente committed Jan 30, 2024
1 parent 09f4695 commit 9cff1b0
Showing 2 changed files with 17 additions and 15 deletions.
11 changes: 7 additions & 4 deletions src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -403,6 +403,11 @@ def forward(
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
mixed_query_layer = self.self.query(hidden_states)
bsz, q_len, _ = hidden_states.size()

def reshape(x: torch.Tensor) -> torch.Tensor:
"""separate heads"""
return x.view(bsz, -1, self.self.num_attention_heads, self.self.attention_head_size)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
@@ -427,8 +432,6 @@ def forward(
key_layer = self.self.transpose_for_scores(self.self.key(hidden_states))
value_layer = self.self.transpose_for_scores(self.self.value(hidden_states))

bsz, q_len, _ = hidden_states.size()

query_layer = self.self.transpose_for_scores(mixed_query_layer)

if self.self.is_decoder:
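
For context on the two hunks above: flash attention kernels expect query, key and value in (batch, seq_len, num_heads, head_dim) layout, while the eager path's transpose_for_scores produces (batch, num_heads, seq_len, head_dim). A minimal sketch of the two layouts with made-up dimensions (not taken from this file):

import torch

bsz, q_len, num_heads, head_dim = 2, 8, 12, 64
hidden_states = torch.randn(bsz, q_len, num_heads * head_dim)

# Eager attention layout, as produced by transpose_for_scores: (bsz, num_heads, q_len, head_dim)
eager_layout = hidden_states.view(bsz, q_len, num_heads, head_dim).permute(0, 2, 1, 3)

# Layout produced by the new reshape helper, which flash attention expects: (bsz, q_len, num_heads, head_dim)
flash_layout = hidden_states.view(bsz, -1, num_heads, head_dim)

print(eager_layout.shape)  # torch.Size([2, 12, 8, 64])
print(flash_layout.shape)  # torch.Size([2, 8, 12, 64])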
@@ -628,7 +631,7 @@ def __init__(self, config):
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = XLM_ROBERTA_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type="absolute"
config, position_embedding_type="absolute", is_cross_attention=True
)
self.intermediate = XLMRobertaIntermediate(config)
self.output = XLMRobertaOutput(config)
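
The extra is_cross_attention=True argument tells the attention class at construction time that its keys and values will come from the encoder. A rough sketch of how such a flag could be consumed; the class name and structure below are hypothetical, not the actual code behind XLM_ROBERTA_ATTENTION_CLASSES:

import torch.nn as nn

class SketchCrossAttention(nn.Module):
    """Hypothetical simplification of an attention module that honours is_cross_attention."""

    def __init__(self, config, position_embedding_type=None, is_cross_attention=False):
        super().__init__()
        self.is_cross_attention = is_cross_attention
        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states, encoder_hidden_states=None):
        # When built for cross-attention, keys/values are projected from the encoder output.
        kv_source = encoder_hidden_states if self.is_cross_attention else hidden_states
        return self.query(hidden_states), self.key(kv_source), self.value(kv_source)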
@@ -1032,7 +1035,7 @@ def forward(
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

if attention_mask is None:
if self.use_flash_attention_2:
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
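
The rename matches the attribute other Transformers models of this vintage define: the underscore-prefixed flag _use_flash_attention_2, set once in __init__ from config._attn_implementation, so the un-prefixed use_flash_attention_2 would most likely not exist on the model. A hedged sketch of that pattern (not the verbatim XLMRobertaModel code):

import torch.nn as nn

class SketchEncoderModel(nn.Module):
    """Sketch of the flag pattern common around transformers v4.37; details may differ."""

    def __init__(self, config):
        super().__init__()
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    def prepare_mask(self, attention_mask):
        if self._use_flash_attention_2:
            # Flash attention takes the 2D padding mask as-is, or None when nothing is masked.
            return attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        # The eager/sdpa paths would expand this into a 4D additive mask instead (omitted here).
        return attention_mask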
21 changes: 10 additions & 11 deletions tests/models/xlm_roberta/test_modeling_xlm_roberta.py
@@ -184,6 +184,15 @@ def create_and_check_xlm_roberta_model(
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

def create_and_check_xlm_roberta_for_masked_lm(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = XLMRobertaForMaskedLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

def create_and_check_xlm_roberta_for_causal_lm(
self,
config,
input_ids,
@@ -195,20 +204,10 @@ def create_and_check_xlm_roberta_for_masked_lm(
encoder_hidden_states,
encoder_attention_mask,
):
model = XLMRobertaForMaskedLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

def create_and_check_xlm_roberta_for_causal_lm(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.is_decoder = True
model = XLMRobertaForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

def create_and_check_xlm_roberta_for_question_answering(
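
With these test helpers in place, the flash-attention path can also be exercised end to end outside the test suite. A hedged usage sketch (assumes a CUDA GPU, the flash-attn package installed, and a transformers build that includes this commit):

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
model = AutoModelForMaskedLM.from_pretrained(
    "xlm-roberta-base",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")

inputs = tokenizer("Flash attention is <mask>.", return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)  # (batch_size, sequence_length, vocab_size)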
