fix assisted decoding assistant model inputs (#27503)
* fix assisted decoding attention_cat

* fix attention_mask for assisted decoding

* fix attention_mask len

* fix attn len

* Use a cleaner way to prepare assistant model inputs

* fix param meaning

* fix param name

* fix assistant model inputs

* update token type ids

* fix assistant kwargs copy

* add encoder-decoder tests of assisted decoding

* check if assistant kwargs contains updated keys

* revert test

* fix whisper tests

* fix assistant kwargs

* revert whisper test

* delete _extend funcs
jiqing-feng authored Nov 27, 2023
1 parent 307cf3a commit 1d7f406
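For context, the code path touched by this commit is assisted (speculative) decoding, where a smaller assistant model drafts candidate tokens that the main model then verifies. A minimal usage sketch, not part of this commit; the checkpoints below are placeholders for any compatible main/assistant pair that shares a tokenizer:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")   # main model
assistant = AutoModelForCausalLM.from_pretrained("gpt2")     # smaller draft model

inputs = tokenizer("Assisted decoding lets a small model draft tokens", return_tensors="pt")
# Passing `assistant_model` routes generation through `assisted_decoding`,
# the function reworked in the diff below.
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))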
Showing 4 changed files with 86 additions and 103 deletions.
177 changes: 86 additions & 91 deletions src/transformers/generation/utils.py
@@ -1391,43 +1391,6 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
UserWarning,
)

def _extend_attention_mask(self, model_kwargs: Dict[str, Any], new_mask_length: int) -> Dict[str, Any]:
if self.config.is_encoder_decoder:
key = "decoder_attention_mask"
else:
key = "attention_mask"

if key not in model_kwargs:
return model_kwargs

mask = model_kwargs[key]
mask_extension_length = new_mask_length - mask.shape[1]

if mask_extension_length < 0:
raise ValueError("Cannot extend attention mask to a length less than it already is")

model_kwargs[key] = torch.cat(
[mask, mask.new_ones((mask.shape[0], mask_extension_length))],
dim=-1,
)

return model_kwargs

def _extend_token_type_ids(self, model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
return model_kwargs

token_type_ids = model_kwargs["token_type_ids"]
final_token_type = token_type_ids[:, -1].unsqueeze(-1)
extension_length = new_length - token_type_ids.shape[1]
token_type_copies = final_token_type.repeat(1, extension_length)
model_kwargs["token_type_ids"] = torch.cat(
[model_kwargs["token_type_ids"], token_type_copies],
dim=-1,
)

return model_kwargs

@torch.no_grad()
def generate(
self,
@@ -4505,11 +4468,6 @@ def assisted_decoding(
else:
num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

# check if assistant model accepts encoder_outputs
assistant_accepts_encoder_outputs = "encoder_outputs" in set(
inspect.signature(assistant_model.forward).parameters.keys()
)

# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
@@ -4547,20 +4505,32 @@
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)

# prepare assistant model's keys of inputs
assistant_kwargs = copy.copy(model_kwargs)
if assistant_model.config.is_encoder_decoder:
# both are encoder-decoder
input_ids_key = "decoder_input_ids"
attention_key = "decoder_attention_mask"
assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs")
elif "assistant_encoder_outputs" in assistant_kwargs:
# special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
input_ids_key = "input_ids"
attention_key = "attention_mask"
assistant_kwargs["attention_mask"] = assistant_kwargs.get(
"decoder_attention_mask",
torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
)
assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs")
else:
# both are decoder-only
input_ids_key = "input_ids"
attention_key = "attention_mask"

# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

# other auxiliary variables
max_len = stopping_criteria[0].max_length
assistant_kv_indexing = (
1
if "bloom" in assistant_model.__class__.__name__.lower()
or (
assistant_model.config.architectures is not None
and "bloom" in assistant_model.config.architectures[0].lower()
)
else 0
)

this_peer_finished = False # used by synced_gpus only
while True:
@@ -4582,52 +4552,35 @@
# need access to the assistant cache to secure strong speedups.
candidate_input_ids = input_ids
for _ in range(int(num_assistant_tokens)):
# 1.1. use the assistant model to obtain the next candidate logits
if "assistant_past_key_values" in model_kwargs:
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
assist_inputs = candidate_input_ids[:, -new_token_len:]
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
if assistant_model.config.is_encoder_decoder:
assistant_model_outputs = assistant_model(
decoder_input_ids=assist_inputs,
past_key_values=model_kwargs["assistant_past_key_values"],
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
)
else:
encoder_kwargs = {}

if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]

assistant_model_outputs = assistant_model(
assist_inputs, past_key_values=model_kwargs["assistant_past_key_values"], **encoder_kwargs
)
else:
if assistant_model.config.is_encoder_decoder:
assistant_model_outputs = assistant_model(
decoder_input_ids=candidate_input_ids,
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
)
else:
encoder_kwargs = {}
# 1.1 prepare assistant model inputs
assistant_inputs = assistant_model.prepare_inputs_for_generation(
candidate_input_ids,
**assistant_kwargs,
)

if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
# 1.2. check if the input ids length is correct
has_past_key_values = assistant_inputs.get("past_key_values", None) is not None
if has_past_key_values and assistant_inputs[input_ids_key].shape[-1] not in (1, 2):
raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")

assistant_model_outputs = assistant_model(candidate_input_ids, **encoder_kwargs)
# 1.3. use the assistant model to obtain the next candidate logits
assistant_model_outputs = assistant_model(**assistant_inputs)

# 1.2. greedily select the next candidate token
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
# 1.4. greedily select the next candidate token
if len(logits_processor) > 0:
assistant_model_outputs.logits[:, -1, :] = logits_processor(
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
)
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)

# 1.3. stop assistant generation on EOS
# 1.5. update assistant model inputs
if assistant_kwargs.get(attention_key, None) is not None:
mask = assistant_kwargs[attention_key]
assistant_kwargs[attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)
assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values

# 1.6. stop assistant generation on EOS
if eos_token_id_tensor is not None:
last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
last_assistant_token_is_eos = (
@@ -4646,8 +4599,10 @@

# 2.1. Prepare the model inputs
candidate_kwargs = copy.copy(model_kwargs)
candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
candidate_kwargs = _prepare_attention_mask(
candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
)
candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])

model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)

@@ -4699,8 +4654,8 @@ def assisted_decoding(
# 5.3. Discard past key values relative to unused assistant tokens
new_cache_size = new_cur_len - 1
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1
assistant_kwargs["past_key_values"] = _crop_past_key_values(
assistant_model, assistant_kwargs["past_key_values"], new_cache_size - 1
) # the assistant does not have the token after the last match, hence the -1

# 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
@@ -4761,6 +4716,12 @@ def assisted_decoding(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)

# Update assistant_kwargs for the assistant's next round of generations
assistant_kwargs = _prepare_attention_mask(
assistant_kwargs, new_cur_len, assistant_model.config.is_encoder_decoder
)
assistant_kwargs = _prepare_token_type_ids(assistant_kwargs, new_cur_len)

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
@@ -4938,3 +4899,37 @@ def _ranking_fast(
contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K]
_, selected_idx = contrastive_score.max(dim=-1) # [B]
return selected_idx


def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
"""Expands or crops the model's mask for decoding purposes, to the defined length"""

mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
if mask_key not in model_kwargs:
return model_kwargs

mask = model_kwargs[mask_key]
mask_length_diff = new_length - mask.shape[1]

if mask_length_diff < 0:
model_kwargs[mask_key] = mask[:, :mask_length_diff]
elif mask_length_diff > 0:
model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
return model_kwargs


def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
"""Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
return model_kwargs

token_type_ids = model_kwargs["token_type_ids"]
final_token_type = token_type_ids[:, -1].unsqueeze(-1)
type_length_diff = new_length - token_type_ids.shape[1]

if type_length_diff < 0:
token_type_ids = token_type_ids[:, :type_length_diff]
elif type_length_diff > 0:
token_type_copies = final_token_type.repeat(1, type_length_diff)
model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
return model_kwargs
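For illustration only (not part of the commit): a minimal sketch of how the two new module-level helpers behave, assuming they can be imported from transformers.generation.utils as defined above. The attention mask is padded with ones when the target length is longer and cropped from the right when it is shorter; token_type_ids is extended by repeating its final value.

import torch
from transformers.generation.utils import _prepare_attention_mask, _prepare_token_type_ids

kwargs = {
    "attention_mask": torch.ones(1, 5, dtype=torch.long),
    "token_type_ids": torch.zeros(1, 5, dtype=torch.long),
}

# Expand both tensors from length 5 to length 8.
kwargs = _prepare_attention_mask(kwargs, new_length=8, is_encoder_decoder=False)
kwargs = _prepare_token_type_ids(kwargs, new_length=8)
print(kwargs["attention_mask"].shape)  # torch.Size([1, 8])
print(kwargs["token_type_ids"].shape)  # torch.Size([1, 8])

# Crop the attention mask back down to length 6 (trimmed on the right).
kwargs = _prepare_attention_mask(kwargs, new_length=6, is_encoder_decoder=False)
print(kwargs["attention_mask"].shape)  # torch.Size([1, 6])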
4 changes: 0 additions & 4 deletions tests/models/nllb_moe/test_modeling_nllb_moe.py
@@ -348,10 +348,6 @@ def test_get_loss(self):
self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])

@unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it")
def test_assisted_decoding_sample(self):
pass


@require_torch
@require_sentencepiece
4 changes: 0 additions & 4 deletions tests/models/switch_transformers/test_modeling_switch_transformers.py
@@ -726,10 +726,6 @@ def test_generate_with_head_masking(self):
def test_disk_offload(self):
pass

@unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it")
def test_assisted_decoding_sample(self):
pass


class SwitchTransformersEncoderOnlyModelTester:
def __init__(
4 changes: 0 additions & 4 deletions tests/models/t5/test_modeling_t5.py
@@ -1036,10 +1036,6 @@ def test_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)

@unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it")
def test_assisted_decoding_sample(self):
pass


def use_task_specific_params(model, task):
model.config.update(model.config.task_specific_params[task])
