Fix contrastive search to correctly handle input with padding (huggingface#33507)

* fix: handle padding in contrastive search for decoder-only models

* fix: handle padding in contrastive search for encoder-decoder models

* tests: move padding contrastive test to test_util, add t5 test

* fix: handle if model_kwargs["decoder_attention_mask"] is None

* refactor: improve padding input contrastive search generation tests

* chore: _ranking_fast to use LongTensor for cosine_matrix_mask
ducviet00 authored Sep 20, 2024
1 parent c0c6815 commit dc8b6ea
Showing 2 changed files with 158 additions and 1 deletion.
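
In short: contrastive search re-ranks the top-k candidate tokens by their maximum cosine similarity to all previous hidden states (the degeneration penalty). Before this change, hidden states at padding positions were included in that maximum, so the same prompt could generate different text depending on how much left padding it carried. A minimal sketch of the problem with toy tensors (illustrative only, not the library code):

import torch

# Hidden states for one candidate: 2 padding positions followed by 2 real prompt tokens.
torch.manual_seed(0)
context_hidden = torch.randn(1, 4, 8)          # [B*K, S, H]
next_hidden = torch.randn(1, 1, 8)             # [B*K, 1, H]
attention_mask = torch.tensor([[0, 0, 1, 1]])  # 0 = padding, 1 = real token

norm_ctx = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine = torch.matmul(norm_ctx, norm_next.transpose(1, 2)).squeeze(-1)  # [B*K, S]

penalty_all = cosine.max(dim=-1).values             # includes padding positions (old behaviour)
penalty_real = cosine[attention_mask.bool()].max()  # real tokens only (what this fix enforces)
# If a padding position happens to be the most similar, the two penalties differ
# and the selected token can change with the amount of padding.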
24 changes: 23 additions & 1 deletion src/transformers/generation/utils.py
@@ -2604,6 +2604,15 @@ def _contrastive_search(
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

# Create cosine_matrix_mask based on the attention_mask
cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
if self.config.is_encoder_decoder:
if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None:
cosine_matrix_mask = model_kwargs["decoder_attention_mask"]
else:
cosine_matrix_mask = model_kwargs["attention_mask"]
cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0)

this_peer_finished = False

while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
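
For reference, the new cosine_matrix_mask starts out as the (decoder) attention mask and is repeated top_k times along the batch dimension, so each of the k candidate continuations carries its prompt's padding information in the same [batch * top_k, seq_len] layout as the candidate hidden states. A toy sketch of that expansion (made-up values):

import torch

attention_mask = torch.tensor([[0, 0, 1, 1],   # left-padded prompt
                               [1, 1, 1, 1]])  # full-length prompt
top_k = 2

# repeat_interleave duplicates each row top_k times, row by row.
cosine_matrix_mask = attention_mask.repeat_interleave(top_k, dim=0)
# tensor([[0, 0, 1, 1],
#         [0, 0, 1, 1],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])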
@@ -2771,7 +2780,12 @@ def _contrastive_search(
# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
# model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't
# introduce (noticeable) slowdowns on single-device runs.
selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
selected_idx = _ranking_fast(
context_hidden, next_hidden, top_k_probs, cosine_matrix_mask, penalty_alpha, top_k
)
cosine_matrix_mask = torch.cat(
[cosine_matrix_mask, cosine_matrix_mask.new_ones((cosine_matrix_mask.shape[0], 1))], dim=-1
)
selected_idx = selected_idx.to("cpu")

# This will be used instead of the previous inefficient torch.stack(torch.split())
@@ -4283,6 +4297,7 @@ def _ranking_fast(
context_hidden: torch.FloatTensor,
next_hidden: torch.FloatTensor,
next_top_k_probs: torch.FloatTensor,
cosine_matrix_mask: torch.LongTensor,
alpha: float,
beam_width: int,
) -> torch.FloatTensor:
@@ -4294,6 +4309,13 @@
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S]

# Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions)
# Using a large negative value for masked positions
cosine_matrix_mask = cosine_matrix_mask.to(dtype=cosine_matrix.dtype)
cosine_matrix_mask = (1 - cosine_matrix_mask) * torch.finfo(cosine_matrix.dtype).min
cosine_matrix = cosine_matrix + cosine_matrix_mask

degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1) # [B*K]
next_top_k_probs = next_top_k_probs.view(-1) # [B*K]
contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
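
The masking inside _ranking_fast is additive: entries where the mask is 0 receive a very large negative offset, so a padded position can never win the max that defines the degeneration penalty, while real positions are left untouched. A tiny worked example with made-up similarity values:

import torch

cosine_matrix = torch.tensor([[0.9, 0.2, 0.4]])  # similarities to 3 context positions
cosine_matrix_mask = torch.tensor([[0, 1, 1]])   # position 0 is padding

additive = (1 - cosine_matrix_mask.to(cosine_matrix.dtype)) * torch.finfo(cosine_matrix.dtype).min
degeneration_penalty, _ = (cosine_matrix + additive).max(dim=-1)
# degeneration_penalty -> tensor([0.4]); without the mask it would have been 0.9.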
135 changes: 135 additions & 0 deletions tests/generation/test_utils.py
@@ -44,6 +44,7 @@

if is_torch_available():
import torch
import torch.nn.functional as F

from transformers import (
AutoModelForCausalLM,
@@ -59,6 +60,7 @@
GPT2Tokenizer,
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
T5ForConditionalGeneration,
)
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
from transformers.generation import (
@@ -3644,6 +3646,139 @@ def test_init_static_cache_multi_gpu(self):
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))

@slow
def test_padding_input_contrastive_search_gpt2(self):
# Load the pre-trained GPT-2 model and tokenizer
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
model.to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True)

# Set the tokenizer to left-pad the sequences
tokenizer.padding_side = "left"

# Define the PAD token as the EOS token
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = model.generation_config.eos_token_id

# Define the input prompt
prompt_text = "The whispered legends of the haunted mansion spoke"

# Tokenize the input prompt
encoded_prompt = tokenizer(prompt_text, return_tensors="pt", padding=True)
input_ids = encoded_prompt.input_ids.to(torch_device)
attention_mask = encoded_prompt.attention_mask.to(torch_device)

# Define the contrastive search params
penalty_alpha = 0.6
top_k = 4

# Define the padding length to add to the input IDs and attention mask
padding_length = 10

# Generate text without padding
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
do_sample=False,
penalty_alpha=penalty_alpha,
top_k=top_k,
max_new_tokens=64,
)
generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Pad the input IDs and attention mask on the left
padded_input_ids = F.pad(
input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
)
padded_attention_mask = F.pad(attention_mask, (padding_length, 0), "constant", value=0)

# Generate text with padded inputs
outputs_with_padding = model.generate(
input_ids=padded_input_ids,
attention_mask=padded_attention_mask,
do_sample=False,
penalty_alpha=penalty_alpha,
top_k=top_k,
max_new_tokens=64,
)
generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)

# Assert that the generated texts are identical for padded and non-padded inputs
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
self.assertEqual(
generated_text_with_padding,
'The whispered legends of the haunted mansion spoke of the "souls of the dead" who were "falling '
'out of the sky" and "falling into the sea."\n\nThe ghostly apparitions were said to have been '
'created by the spirits of the dead, who were "falling out of the sky" and "falling into the sea',
)

@slow
def test_padding_input_contrastive_search_t5(self):
# Load the pre-trained T5 model and tokenizer
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
model.to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", clean_up_tokenization_spaces=True)

# Define the input prompt
prompt_text = "translate English to German: I need to finish this task before the end of the day."

# Tokenize the input prompt
encoded_prompt = tokenizer(prompt_text, return_tensors="pt")
input_ids = encoded_prompt.input_ids.to(torch_device)
attention_mask = encoded_prompt.attention_mask.to(torch_device)

# Define the decoder prompt
decoder_prompt_text = "Ich muss diese Aufgabe"
encoded_decoder_prompt = tokenizer(decoder_prompt_text, add_special_tokens=False, return_tensors="pt")
decoder_input_ids = encoded_decoder_prompt.input_ids.to(torch_device)
decoder_attention_mask = encoded_decoder_prompt.attention_mask.to(torch_device)

# Define the contrastive search params
penalty_alpha = 0.6
top_k = 4

# Generate text without padding
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
do_sample=False,
penalty_alpha=penalty_alpha,
top_k=top_k,
max_new_tokens=64,
)
generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Define the padding length to add to the input IDs and attention mask
padding_length = 10

# Pad the decoder input IDs and attention mask on the left
padded_decoder_input_ids = F.pad(
decoder_input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
)
padded_decoder_attention_mask = F.pad(decoder_attention_mask, (padding_length, 0), "constant", value=0)
# Since the decoder_start_token_id is the same as the pad_token_id,
# the last padded token represents the decoder start token.
# Set the attention mask for the decoder_start_token_id to True (1).
padded_decoder_attention_mask[:, padding_length - 1] = 1
# Generate text with padded inputs
outputs_with_padding = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=padded_decoder_input_ids,
decoder_attention_mask=padded_decoder_attention_mask,
do_sample=False,
penalty_alpha=penalty_alpha,
top_k=top_k,
max_new_tokens=64,
)
generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)

# Assert that the generated texts are identical for padded and non-padded inputs
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")


@require_torch
class TokenHealingTestCase(unittest.TestCase):

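The tests above pad a single prompt by hand; in practice, left padding usually appears when prompts of different lengths are tokenized together as a batch. A hedged usage sketch of that scenario (the model and first prompt come from the GPT-2 test above, the second prompt is a hypothetical shorter one; outputs are not asserted here):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
model.generation_config.pad_token_id = model.generation_config.eos_token_id

prompts = [
    "The whispered legends of the haunted mansion spoke",
    "Hello",  # hypothetical shorter prompt: it gets left-padded to match the first one
]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# With this fix, the left-padded shorter prompt should generate the same text in
# the batch as it does when generated alone without padding.
outputs = model.generate(
    **inputs, do_sample=False, penalty_alpha=0.6, top_k=4, max_new_tokens=32
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))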