From e304f9769ca0bd9ef6fdb63a0568ecbd74af4434 Mon Sep 17 00:00:00 2001
From: Apoorv Saxena
Date: Sat, 13 Jan 2024 22:45:58 +0530
Subject: [PATCH] Adding Prompt lookup decoding (#27775)

* MVP

* fix ci

* more ci

* remove redundant kwarg

* added and wired up PromptLookupCandidateGenerator

* rebased with main, working

* removed print

* style fixes

* fix test

* fixed tests

* added test for prompt lookup decoding

* fixed circleci

* fixed test issue

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante

* Update src/transformers/generation/candidate_generator.py

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Joao Gante
Co-authored-by: Joao Gante
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 .../generation/candidate_generator.py | 92 +++++++++++++++++++
 .../generation/configuration_utils.py |  3 +
 src/transformers/generation/utils.py  | 24 +++--
 tests/generation/test_utils.py        | 60 ++++++++++++
 4 files changed, 170 insertions(+), 9 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 4e43cae224bb5e..ad5289f120ea19 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -226,6 +226,98 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
                 self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
 
 
+class PromptLookupCandidateGenerator(CandidateGenerator):
+    """
+    `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
+    likely continuations in the provided prompt (input_ids) itself.
+    Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding
+
+    Args:
+        max_matching_ngram_size (`int`):
+            The maximum ngram size to be considered for matching in the prompt
+        num_output_tokens (`int`):
+            The number of tokens to be output as candidate tokens.
+    """
+
+    def __init__(
+        self,
+        num_output_tokens: int = 10,
+        max_matching_ngram_size: int = 2,
+    ):
+        self.num_output_tokens = num_output_tokens
+        self.max_matching_ngram_size = max_matching_ngram_size
+
+        if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
+            raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
+
+    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
+        """
+        Fetches the candidates to be tried for the current input.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
+
+        Return:
+            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
+        """
+        input_length = input_ids.size(1)
+
+        chosen_ids = None
+        match_found = False
+        for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
+            # Create sliding windows of size ngram_size
+            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
+
+            # Convert ngram to a tensor for comparison
+            ngram_tensor = input_ids[0, -ngram_size:]
+
+            # Find where the windows match the ngram
+            matches = (windows == ngram_tensor).all(dim=2)
+
+            # Get the indices of matches
+            match_indices = matches.nonzero(as_tuple=True)[1]
+
+            # Iterate through match indices to find a valid continuation
+            for idx in match_indices:
+                start_idx = idx + ngram_size
+                end_idx = start_idx + self.num_output_tokens
+                end_idx = min(end_idx, input_length)
+
+                if start_idx < end_idx:
+                    chosen_ids = input_ids[0, start_idx:end_idx]
+                    match_found = True
+                    break
+            if match_found:
+                break
+
+        if chosen_ids is None or len(chosen_ids) == 0:
+            # Need to make a dummy tensor to avoid errors
+            chosen_ids = torch.zeros((1), dtype=torch.long, device=input_ids.device)
+
+        # Now we need to extend input_ids with chosen_ids
+        chosen_ids = chosen_ids.unsqueeze(0)
+        candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
+        # assisted_generation expects logits as well, but we don't have those here, so returning None
+        return candidate_input_ids, None
+
+    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
+        """
+        Updates the candidate generation strategy based on the outcomes.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
+            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
+                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
+                using beam search or log softmax for each vocabulary token when using beam search.
+            num_matches (`int`):
+                The number of matches between the candidate sequences and the model predictions.
+        """
+        # Currently does nothing
+        return
+
+
 def _crop_past_key_values(model, past_key_values, maximum_length):
     """Crops the past key values up to a certain maximum length."""
     new_past = []
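
For illustration only (not part of the patch): a minimal sketch of the lookup that `get_candidates` performs, run directly on a toy batch-size-1 tensor. It assumes a transformers install that already contains this patch; the token values are made up.

import torch

from transformers.generation.candidate_generator import PromptLookupCandidateGenerator

# Toy prompt: the trailing bigram (5, 6) also occurs earlier in the sequence, so prompt lookup
# proposes the tokens that followed it there (7, 8, 9) as the candidate continuation.
input_ids = torch.tensor([[1, 2, 5, 6, 7, 8, 9, 3, 4, 5, 6]])

candidate_generator = PromptLookupCandidateGenerator(num_output_tokens=3, max_matching_ngram_size=2)
candidate_ids, candidate_logits = candidate_generator.get_candidates(input_ids)

print(candidate_ids)     # input_ids extended with the proposed continuation [7, 8, 9]
print(candidate_logits)  # None -- prompt lookup produces no candidate logits
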
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 21fe916a7aabd2..4353a113223870 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -320,6 +320,9 @@ def __init__(self, **kwargs):
         self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
         self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
 
+        # Prompt lookup decoding
+        self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
+
         # Wild card
         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
 
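
A quick sketch of the new field (again assuming a build with this patch applied): the kwarg is simply stored on `GenerationConfig` and defaults to `None`, so prompt lookup decoding stays off unless explicitly requested.

from transformers import GenerationConfig

config = GenerationConfig()
print(config.prompt_lookup_num_tokens)  # None -- prompt lookup decoding is disabled by default

config = GenerationConfig(prompt_lookup_num_tokens=10)
print(config.prompt_lookup_num_tokens)  # 10 -- up to 10 candidate tokens are copied from the prompt per step
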
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 43675aed40a940..ef9e19c8b11057 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -40,6 +40,7 @@
 from .candidate_generator import (
     AssistedCandidateGenerator,
     CandidateGenerator,
+    PromptLookupCandidateGenerator,
     _crop_past_key_values,
     _prepare_attention_mask,
     _prepare_token_type_ids,
@@ -908,14 +909,19 @@ def _get_candidate_generator(
         """
         Returns the candidate generator to be used in `assisted_generation`
         """
-        candidate_generator = AssistedCandidateGenerator(
-            input_ids=input_ids,
-            assistant_model=assistant_model,
-            generation_config=generation_config,
-            logits_processor=logits_processor,
-            model_kwargs=model_kwargs,
-            inputs_tensor=inputs_tensor,
-        )
+        if generation_config.prompt_lookup_num_tokens is not None:
+            candidate_generator = PromptLookupCandidateGenerator(
+                num_output_tokens=generation_config.prompt_lookup_num_tokens,
+            )
+        else:
+            candidate_generator = AssistedCandidateGenerator(
+                input_ids=input_ids,
+                assistant_model=assistant_model,
+                generation_config=generation_config,
+                logits_processor=logits_processor,
+                model_kwargs=model_kwargs,
+                inputs_tensor=inputs_tensor,
+            )
         return candidate_generator
 
     def _get_logits_warper(
@@ -995,7 +1001,7 @@ def _get_generation_mode(
                 generation_mode = GenerationMode.BEAM_SEARCH
 
         # Assisted generation may extend some generation modes
-        if assistant_model is not None:
+        if assistant_model is not None or generation_config.prompt_lookup_num_tokens is not None:
            if generation_mode in ("greedy_search", "sample"):
                generation_mode = GenerationMode.ASSISTED_GENERATION
            else:
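
End to end, the only user-facing change is the new `generate()` kwarg: setting `prompt_lookup_num_tokens` routes generation into the assisted-generation loop with a `PromptLookupCandidateGenerator`, and no assistant model is required. A hedged sketch; the checkpoint name is only an example.

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "gpt2"  # any decoder-only checkpoint with caching works; "gpt2" is just an example
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

prompt = "The quick brown fox jumps over the lazy dog. The quick brown"
inputs = tokenizer(prompt, return_tensors="pt")

# Repetitive prompts benefit most, since candidate continuations are copied straight from the prompt.
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False, prompt_lookup_num_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
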
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 973f54f0039701..bfae5e882778a8 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1569,6 +1569,66 @@ def test_assisted_decoding_matches_greedy_search(self):
             for output in (output_greedy, output_assisted):
                 self._check_outputs(output, input_ids, model.config, use_cache=True)
 
+    @is_flaky()
+    def test_prompt_lookup_decoding_matches_greedy_search(self):
+        # This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
+        # This test is mostly a copy of test_assisted_decoding_matches_greedy_search
+
+        for model_class in self.all_generative_model_classes:
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+                self.skipTest("Won't fix: old model with different cache format")
+            if any(
+                model_name in model_class.__name__.lower()
+                for model_name in [
+                    "bigbirdpegasus",
+                    "led",
+                    "mega",
+                    "speech2text",
+                    "git",
+                    "prophetnet",
+                    "seamlessm4t",
+                    "clvp",
+                ]
+            ):
+                self.skipTest("May fix in the future: need model-specific fixes")
+
+            # enable cache
+            config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
+
+            # NOTE: assisted generation only works with cache on at the moment.
+            if not hasattr(config, "use_cache"):
+                self.skipTest("This model doesn't support caching")
+
+            config.use_cache = True
+            config.is_decoder = True
+            model = model_class(config).to(torch_device).eval()
+            # Sets assisted generation arguments such that:
+            # a) no EOS is generated, to ensure generation doesn't break early
+            # b) the prompt lookup tries to give the model 2 tokens, to ensure the input preparation of
+            #    prompt lookup is correct
+            # c) there are at least two forward passes in the main model, to ensure the input preparation of
+            #    the main model is correct
+            generation_kwargs = {
+                "eos_token_id": -1,  # see a)
+                "max_new_tokens": 4,  # see c)
+                "num_beams": 1,
+                "do_sample": False,
+                "output_scores": True,
+                "output_hidden_states": True,
+                "output_attentions": True,
+                "return_dict_in_generate": True,
+            }
+
+            output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
+
+            generation_kwargs.update({"prompt_lookup_num_tokens": 2})  # see b)
+            output_prompt_lookup = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
+
+            # The two outputs must match and their shape must be as expected
+            self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist())
+            for output in (output_greedy, output_prompt_lookup):
+                self._check_outputs(output, input_ids, model.config, use_cache=True)
+
     def test_assisted_decoding_sample(self):
         # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
         # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with