[Assistant Generation] Improve Encoder Decoder (#26701)
* [Assistant Generation] Improve enc dec

* save more

* Fix logit processor checks

* Clean

* make style

* fix deprecation

* fix generation test

* Apply suggestions from code review

* fix biogpt

* make style
patrickvonplaten authored Oct 11, 2023
1 parent 5334796 commit da69de1
Showing 4 changed files with 59 additions and 18 deletions.
18 changes: 18 additions & 0 deletions src/transformers/generation/configuration_utils.py
@@ -227,6 +227,20 @@ class GenerationConfig(PushToHubMixin):
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
> Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192)
num_assistant_tokens (`int`, *optional*, defaults to 5):
Defines the number of _speculative tokens_ that shall be generated by the assistant model before being
checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation
more _speculative_: if the assistant model is performant, larger speed-ups can be reached; if the assistant
model requires many corrections, lower speed-ups are reached.
num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`):
Defines the schedule at which the maximum number of assistant tokens shall be changed during inference.
- `"heuristic"`: When all _speculative_ tokens are correct, increase `num_assistant_tokens` by 2; otherwise,
reduce it by 1
- `"constant"`: `num_assistant_tokens` stays unchanged during generation
> Wild card
generation_kwargs:
@@ -294,6 +308,10 @@ def __init__(self, **kwargs):
self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

# Assistant generation
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})

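For reference, the two new options might be used as in the minimal sketch below (the `facebook/opt-1.3b` / `facebook/opt-125m` checkpoints are illustrative assumptions; any target/assistant pair that shares a tokenizer should work). As the `assisted_decoding` changes further down show, both values are read from the assistant model's `generation_config`.

```python
# A minimal sketch of the new assistant-generation options (checkpoints are illustrative).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
assistant = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# The options now live on the generation config rather than on the assistant model instance.
assistant.generation_config.num_assistant_tokens = 8
assistant.generation_config.num_assistant_tokens_schedule = "heuristic"  # or "constant"

inputs = tokenizer("Alice and Bob", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
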
39 changes: 23 additions & 16 deletions src/transformers/generation/utils.py
@@ -1241,6 +1241,10 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
model_args |= {f"decoder_{x}" for x in decoder_model_args}

# allow assistant_encoder_outputs to be passed if we're doing assisted generation
if "assistant_encoder_outputs" in model_kwargs:
model_args |= {"assistant_encoder_outputs"}

for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)
@@ -1612,7 +1616,7 @@ def generate(
raise ValueError("assisted generate requires `use_cache=True`")

# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
if assistant_model.config.is_encoder_decoder:
if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs:
assistant_model_kwargs = copy.deepcopy(model_kwargs)
inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs(
inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs
@@ -4347,8 +4351,14 @@ def assisted_decoding(
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
```"""
# Assistant: initialize assistant-related variables
if not hasattr(assistant_model, "max_assistant_tokens"):
assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
if hasattr(assistant_model, "num_assistant_tokens"):
warnings.warn(
"Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be removed in v.37. Make sure to set `num_assistant_tokens` via the generation_config instead.",
FutureWarning,
)
num_assistant_tokens = assistant_model.num_assistant_tokens
else:
num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
@@ -4421,26 +4431,23 @@
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
# need access to the assistant cache to secure strong speedups.
candidate_input_ids = input_ids
for _ in range(int(assistant_model.max_assistant_tokens)):
for _ in range(int(num_assistant_tokens)):
# 1.1. use the assistant model to obtain the next candidate logits
if "assistant_past_key_values" in model_kwargs:
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
assist_inputs = candidate_input_ids[:, -new_token_len:]
assist_attn = torch.ones_like(candidate_input_ids)
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
if assistant_model.config.is_encoder_decoder:
assistant_model_outputs = assistant_model(
decoder_input_ids=assist_inputs,
decoder_attention_mask=assist_attn,
past_key_values=model_kwargs["assistant_past_key_values"],
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
)
else:
assistant_model_outputs = assistant_model(
assist_inputs,
attention_mask=assist_attn,
past_key_values=model_kwargs["assistant_past_key_values"],
)
else:
@@ -4495,18 +4502,18 @@
# 2.3. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
if len(logits_processor) > 0:
for i in range(candidate_length):
for i in range(candidate_length + 1):
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
if len(logits_warper) > 0:
for i in range(candidate_length):
for i in range(candidate_length + 1):
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

# 3. Obtain the next tokens from the original model logits.
if do_sample:
probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
probs = new_logits.softmax(dim=-1)
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
selected_tokens = new_logits.argmax(dim=-1)

# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
@@ -4540,13 +4547,13 @@
# 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
# cost of forecasting incorrect assistant tokens.
if n_matches == int(assistant_model.max_assistant_tokens):
assistant_model.max_assistant_tokens += 2.0
else:
assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic":
if n_matches == int(num_assistant_tokens):
num_assistant_tokens += 2.0
else:
num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0)

# Assistant: main logic end

if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need

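To make the acceptance logic and the `"heuristic"` schedule above easier to follow, here is a simplified, self-contained sketch (the `greedy_accept` name and the tensor shapes are illustrative assumptions; the real `assisted_decoding` additionally handles sampling, caches, attention masks, and stopping criteria):

```python
import torch


def greedy_accept(candidate_new_tokens, target_logits, num_assistant_tokens):
    """Illustrative only.
    candidate_new_tokens: (1, k) tokens proposed by the assistant model.
    target_logits: (1, k + 1, vocab) target-model logits for those k positions plus one bonus position.
    """
    # What the target model would have generated at each position.
    selected_tokens = target_logits.argmax(dim=-1)

    # Keep assistant tokens up to (not including) the first mismatch with the target's choices.
    n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum().item()

    # The matched tokens are accepted, plus one "free" token from the extra target position.
    accepted_tokens = selected_tokens[:, : n_matches + 1]

    # "heuristic" schedule: grow the speculation window when every candidate matched, shrink it otherwise.
    if n_matches == int(num_assistant_tokens):
        num_assistant_tokens += 2.0
    else:
        num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0)

    return accepted_tokens, num_assistant_tokens
```

With `num_assistant_tokens_schedule="constant"`, the final adjustment step is simply skipped.
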
6 changes: 5 additions & 1 deletion src/transformers/models/biogpt/modeling_biogpt.py
@@ -544,7 +544,11 @@ def forward(
inputs_embeds = self.embed_tokens(input) * self.embed_scale

if attention_mask is None:
attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
attention_mask = torch.ones(
(inputs_embeds.shape[0], inputs_embeds.shape[1] + past_key_values_length),
dtype=torch.bool,
device=inputs_embeds.device,
)
elif attention_mask.shape[1] != past_key_values_length + input_shape[1]:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
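As a standalone illustration of the fix (the tensor sizes are made up): when decoding with a cache, the default attention mask has to cover the cached positions as well as the freshly embedded tokens.

```python
import torch

batch_size, new_tokens, past_key_values_length, hidden = 2, 1, 7, 16
inputs_embeds = torch.randn(batch_size, new_tokens, hidden)

# The default mask spans past_key_values_length + new_tokens positions,
# not just the positions of the freshly embedded tokens.
attention_mask = torch.ones(
    (inputs_embeds.shape[0], inputs_embeds.shape[1] + past_key_values_length),
    dtype=torch.bool,
    device=inputs_embeds.device,
)
assert attention_mask.shape == (batch_size, past_key_values_length + new_tokens)
```
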
14 changes: 13 additions & 1 deletion tests/generation/test_utils.py
@@ -2953,7 +2953,8 @@ def forward(self, input_ids, foo=False, **kwargs):

return outs

def prepare_inputs_for_generation(self, *args, foo=False, **kwargs):
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)

inputs["foo"] = foo
@@ -2992,3 +2993,14 @@ def prepare_inputs_for_generation(self, *args, foo=False, **kwargs):
assistant_model=assistant,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())

# Check that passing encoder_outputs directly also works as expected
encoder_outputs = assistant.get_encoder()(input_ids)

outputs_assisted = model.generate(
foo=True,
assistant_model=assistant,
encoder_outputs=encoder_outputs,
assistant_encoder_outputs=encoder_outputs,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
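
Outside the dummy models used in the test, the same pattern might look like the sketch below (the `t5-base` / `t5-small` checkpoints are illustrative assumptions; the point is only that `assistant_encoder_outputs` can be pre-computed once and passed through `generate`):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
assistant = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")

with torch.no_grad():
    # Run the assistant's encoder once up front so assisted decoding can skip that step.
    assistant_encoder_outputs = assistant.get_encoder()(**inputs)

outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    assistant_encoder_outputs=assistant_encoder_outputs,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
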
