From da69de17e86501b95396086a5b6479f645e8f70e Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 11 Oct 2023 15:52:20 +0200
Subject: [PATCH] [Assistant Generation] Improve Encoder Decoder (#26701)

* [Assistant Generation] Improve enc dec
* save more
* Fix logit processor checks
* Clean
* make style
* fix deprecation
* fix generation test
* Apply suggestions from code review
* fix biogpt
* make style
---
 .../generation/configuration_utils.py | 18 +++++++++
 src/transformers/generation/utils.py  | 39 +++++++++++--------
 .../models/biogpt/modeling_biogpt.py  |  6 ++-
 tests/generation/test_utils.py        | 14 ++++++-
 4 files changed, 59 insertions(+), 18 deletions(-)

diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 18ccdb2835b411..3bd85568dcb714 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -227,6 +227,20 @@ class GenerationConfig(PushToHubMixin):
         decoder_start_token_id (`int`, *optional*):
             If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
 
+        > Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192)
+
+        num_assistant_tokens (`int`, *optional*, defaults to 5):
+            Defines the number of _speculative tokens_ that are generated by the assistant model before being
+            checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the
+            generation more _speculative_: if the assistant model is performant, larger speed-ups can be reached;
+            if the assistant model requires many corrections, the speed-up is smaller.
+        num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`):
+            Defines the schedule used to adjust the maximum number of assistant tokens during inference.
+            - `"heuristic"`: When all _speculative_ tokens are correct, increase `num_assistant_tokens` by 2,
+              otherwise reduce it by 1
+            - `"constant"`: `num_assistant_tokens` stays unchanged during generation
+
         > Wild card
 
         generation_kwargs:
@@ -294,6 +308,10 @@ def __init__(self, **kwargs):
         self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
         self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
 
+        # Assistant generation
+        self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
+        self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
+
         # Wild card
         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 49b213cc5e92cd..a104113af891ff 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1241,6 +1241,10 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
                 decoder_model_args = set(inspect.signature(decoder.forward).parameters)
                 model_args |= {f"decoder_{x}" for x in decoder_model_args}
 
+        # allow assistant_encoder_outputs to be passed if we're doing assisted generation
+        if "assistant_encoder_outputs" in model_kwargs:
+            model_args |= {"assistant_encoder_outputs"}
+
         for key, value in model_kwargs.items():
             if value is not None and key not in model_args:
                 unused_model_args.append(key)
@@ -1612,7 +1616,7 @@ def generate(
                 raise ValueError("assisted generate requires `use_cache=True`")
 
             # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
-            if assistant_model.config.is_encoder_decoder:
+            if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs:
                 assistant_model_kwargs = copy.deepcopy(model_kwargs)
                 inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs(
                     inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs
@@ -4347,8 +4351,14 @@
         ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
         ```"""
         # Assistant: initialize assistant-related variables
-        if not hasattr(assistant_model, "max_assistant_tokens"):
-            assistant_model.max_assistant_tokens = 5  # this value, which will be updated, persists across calls
+        if hasattr(assistant_model, "num_assistant_tokens"):
+            warnings.warn(
+                "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be removed in v4.37. Make sure to set `num_assistant_tokens` via the generation_config instead.",
+                FutureWarning,
+            )
+            num_assistant_tokens = assistant_model.num_assistant_tokens
+        else:
+            num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
 
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
@@ -4421,26 +4431,23 @@
            # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
            # need access to the assistant cache to secure strong speedups.
            candidate_input_ids = input_ids
-            for _ in range(int(assistant_model.max_assistant_tokens)):
+            for _ in range(int(num_assistant_tokens)):
                # 1.1. use the assistant model to obtain the next candidate logits
                if "assistant_past_key_values" in model_kwargs:
                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
                    # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
                    new_token_len = candidate_input_ids.shape[1] - prev_seq_len
                    assist_inputs = candidate_input_ids[:, -new_token_len:]
-                    assist_attn = torch.ones_like(candidate_input_ids)
                    # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
                    if assistant_model.config.is_encoder_decoder:
                        assistant_model_outputs = assistant_model(
                            decoder_input_ids=assist_inputs,
-                            decoder_attention_mask=assist_attn,
                            past_key_values=model_kwargs["assistant_past_key_values"],
                            encoder_outputs=model_kwargs["assistant_encoder_outputs"],
                        )
                    else:
                        assistant_model_outputs = assistant_model(
                            assist_inputs,
-                            attention_mask=assist_attn,
                            past_key_values=model_kwargs["assistant_past_key_values"],
                        )
                else:
@@ -4495,18 +4502,18 @@
            # 2.3. Process the new logits
            new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
            if len(logits_processor) > 0:
-                for i in range(candidate_length):
+                for i in range(candidate_length + 1):
                    new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
            if len(logits_warper) > 0:
-                for i in range(candidate_length):
+                for i in range(candidate_length + 1):
                    new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
 
            # 3. Obtain the next tokens from the original model logits.
            if do_sample:
-                probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
+                probs = new_logits.softmax(dim=-1)
                selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
            else:
-                selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
+                selected_tokens = new_logits.argmax(dim=-1)
 
            # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
            # the assistant forecasted tokens until the first mismatch, or until the max length is reached.
@@ -4540,13 +4547,13 @@
            # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
            # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
            # cost of forecasting incorrect assistant tokens.
-            if n_matches == int(assistant_model.max_assistant_tokens):
-                assistant_model.max_assistant_tokens += 2.0
-            else:
-                assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
+            if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic":
+                if n_matches == int(num_assistant_tokens):
+                    num_assistant_tokens += 2.0
+                else:
+                    num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0)
 
            # Assistant: main logic end
-
            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need
 
diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py
index d1c471aa8090c9..ca084db5c7d0b9 100755
--- a/src/transformers/models/biogpt/modeling_biogpt.py
+++ b/src/transformers/models/biogpt/modeling_biogpt.py
@@ -544,7 +544,11 @@ def forward(
            inputs_embeds = self.embed_tokens(input) * self.embed_scale
 
        if attention_mask is None:
-            attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
+            attention_mask = torch.ones(
+                (inputs_embeds.shape[0], inputs_embeds.shape[1] + past_key_values_length),
+                dtype=torch.bool,
+                device=inputs_embeds.device,
+            )
        elif attention_mask.shape[1] != past_key_values_length + input_shape[1]:
            raise ValueError(
                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 8e3079f748dfcc..175861fd149e5e 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2953,7 +2953,8 @@ def forward(self, input_ids, foo=False, **kwargs):
 
                return outs
 
-            def prepare_inputs_for_generation(self, *args, foo=False, **kwargs):
+            def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
+                kwargs["encoder_outputs"] = encoder_outputs
                inputs = super().prepare_inputs_for_generation(*args, **kwargs)
                inputs["foo"] = foo
 
@@ -2992,3 +2993,14 @@ def prepare_inputs_for_generation(self, *args, foo=False, **kwargs):
            assistant_model=assistant,
        )
        self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
+
+        # Check that passing encoder_outputs directly also works as expected
+        encoder_outputs = assistant.get_encoder()(input_ids)
+
+        outputs_assisted = model.generate(
+            foo=True,
+            assistant_model=assistant,
+            encoder_outputs=encoder_outputs,
+            assistant_encoder_outputs=encoder_outputs,
+        )
+        self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
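
Note (not part of the diff): a minimal usage sketch of the new assistant-generation knobs follows. The Pythia checkpoint names are only an assumed target/assistant pair that shares a tokenizer; any compatible pair works, and the defaults (greedy decoding with `use_cache=True`) satisfy the assisted-generation check shown in the `generate()` hunk above.

```python
# Minimal usage sketch (not part of this patch). The checkpoints below are assumptions:
# any target/assistant pair that shares a tokenizer behaves the same way.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b-deduped")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1.4b-deduped")
assistant = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m-deduped")

# After this patch the speculative budget lives on the assistant's GenerationConfig,
# replacing the ad-hoc `assistant_model.max_assistant_tokens` attribute (now deprecated).
assistant.generation_config.num_assistant_tokens = 8
assistant.generation_config.num_assistant_tokens_schedule = "heuristic"  # or "constant"

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```

Because `assisted_decoding` now reads `assistant_model.generation_config.num_assistant_tokens` into a local variable, the heuristic schedule adjusts that local copy per call instead of mutating a model attribute that persisted across `generate()` calls.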