Skip to content

Commit

Permalink
Adding Prompt lookup decoding (#27775)
Browse files Browse the repository at this point in the history
* MVP

* fix ci

* more ci

* remove redundant kwarg

* added and wired up PromptLookupCandidateGenerator

* rebased with main, working

* removed print

* style fixes

* fix test

* fixed tests

* added test for prompt lookup decoding

* fixed circleci

* fixed test issue

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante <[email protected]>

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante <[email protected]>

* Update src/transformers/generation/candidate_generator.py

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Arthur <[email protected]>

---------

Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Arthur <[email protected]>
  • Loading branch information
4 people authored Jan 13, 2024
1 parent 29a2b14 commit e304f97
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 9 deletions.
92 changes: 92 additions & 0 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,98 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)


class PromptLookupCandidateGenerator(CandidateGenerator):
"""
`CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
likely continuations in the provided prompt (input_ids) itself.
Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding
Args:
max_matching_ngram_size (`int`):
The maximum ngram size to be considered for matching in the prompt
num_output_tokens (`int`):
The number of tokens to be output as candidate tokens.
"""

def __init__(
self,
num_output_tokens: int = 10,
max_matching_ngram_size: int = 2,
):
self.num_output_tokens = num_output_tokens
self.max_matching_ngram_size = max_matching_ngram_size

if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
"""
input_length = input_ids.size(1)

chosen_ids = None
match_found = False
for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
# Create sliding windows of size ngram_size
windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

# Convert ngram to a tensor for comparison
ngram_tensor = input_ids[0, -ngram_size:]

# Find where the windows match the ngram
matches = (windows == ngram_tensor).all(dim=2)

# Get the indices of matches
match_indices = matches.nonzero(as_tuple=True)[1]

# Iterate through match indices to find a valid continuation
for idx in match_indices:
start_idx = idx + ngram_size
end_idx = start_idx + self.num_output_tokens
end_idx = min(end_idx, input_length)

if start_idx < end_idx:
chosen_ids = input_ids[0, start_idx:end_idx]
match_found = True
break
if match_found:
break

if chosen_ids is None or len(chosen_ids) == 0:
# Need to make a dummy tensor to avoid errors
chosen_ids = torch.zeros((1), dtype=torch.long, device=input_ids.device)

# Now need extend input_ids with chosen_ids
chosen_ids = chosen_ids.unsqueeze(0)
candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
# assisted_generation expects logits as well, but we don't have those here, so returning None
return candidate_input_ids, None

def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
"""
Updates the candidate generation strategy based on the outcomes.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
beam search or log softmax for each vocabulary token when using beam search
num_matches (`int`):
The number of matches between the candidate sequences and the model predictions.
"""
# Currently does nothing
return


def _crop_past_key_values(model, past_key_values, maximum_length):
"""Crops the past key values up to a certain maximum length."""
new_past = []
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,9 @@ def __init__(self, **kwargs):
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)

# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})

Expand Down
24 changes: 15 additions & 9 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .candidate_generator import (
AssistedCandidateGenerator,
CandidateGenerator,
PromptLookupCandidateGenerator,
_crop_past_key_values,
_prepare_attention_mask,
_prepare_token_type_ids,
Expand Down Expand Up @@ -908,14 +909,19 @@ def _get_candidate_generator(
"""
Returns the candidate generator to be used in `assisted_generation`
"""
candidate_generator = AssistedCandidateGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
generation_config=generation_config,
logits_processor=logits_processor,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
)
if generation_config.prompt_lookup_num_tokens is not None:
candidate_generator = PromptLookupCandidateGenerator(
num_output_tokens=generation_config.prompt_lookup_num_tokens,
)
else:
candidate_generator = AssistedCandidateGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
generation_config=generation_config,
logits_processor=logits_processor,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
)
return candidate_generator

def _get_logits_warper(
Expand Down Expand Up @@ -995,7 +1001,7 @@ def _get_generation_mode(
generation_mode = GenerationMode.BEAM_SEARCH

# Assisted generation may extend some generation modes
if assistant_model is not None:
if assistant_model is not None or generation_config.prompt_lookup_num_tokens is not None:
if generation_mode in ("greedy_search", "sample"):
generation_mode = GenerationMode.ASSISTED_GENERATION
else:
Expand Down
60 changes: 60 additions & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,66 @@ def test_assisted_decoding_matches_greedy_search(self):
for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)

@is_flaky()
def test_prompt_lookup_decoding_matches_greedy_search(self):
# This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
# This test is mostly a copy of test_assisted_decoding_matches_greedy_search

for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest("Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
"bigbirdpegasus",
"led",
"mega",
"speech2text",
"git",
"prophetnet",
"seamlessm4t",
"clvp",
]
):
self.skipTest("May fix in the future: need model-specific fixes")

# enable cache
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")

config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
# a) no EOS is generated, to ensure generation doesn't break early
# b) the prompt lookup tries to give the model 2 tokens, to ensure the input preparation of
# prompt lookup is correct
# c) there are at least two forward passes in the main model, to ensure the input preparation of
# the main model is correct
generation_kwargs = {
"eos_token_id": -1, # see a)
"max_new_tokens": 4, # see c)
"num_beams": 1,
"do_sample": False,
"output_scores": True,
"output_hidden_states": True,
"output_attentions": True,
"return_dict_in_generate": True,
}

output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

generation_kwargs.update({"prompt_lookup_num_tokens": 2}) # see b)
output_prompt_lookup = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

# The two outputs must match and their shape must be as expected
self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist())
for output in (output_greedy, output_prompt_lookup):
self._check_outputs(output, input_ids, model.config, use_cache=True)

def test_assisted_decoding_sample(self):
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
Expand Down

0 comments on commit e304f97

Please sign in to comment.