diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 077bc16aff8bc9..424eb4fa7e5015 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1391,43 +1391,6 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de UserWarning, ) - def _extend_attention_mask(self, model_kwargs: Dict[str, Any], new_mask_length: int) -> Dict[str, Any]: - if self.config.is_encoder_decoder: - key = "decoder_attention_mask" - else: - key = "attention_mask" - - if key not in model_kwargs: - return model_kwargs - - mask = model_kwargs[key] - mask_extension_length = new_mask_length - mask.shape[1] - - if mask_extension_length < 0: - raise ValueError("Cannot extend attention mask to a length less than it already is") - - model_kwargs[key] = torch.cat( - [mask, mask.new_ones((mask.shape[0], mask_extension_length))], - dim=-1, - ) - - return model_kwargs - - def _extend_token_type_ids(self, model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]: - if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None: - return model_kwargs - - token_type_ids = model_kwargs["token_type_ids"] - final_token_type = token_type_ids[:, -1].unsqueeze(-1) - extension_length = new_length - token_type_ids.shape[1] - token_type_copies = final_token_type.repeat(1, extension_length) - model_kwargs["token_type_ids"] = torch.cat( - [model_kwargs["token_type_ids"], token_type_copies], - dim=-1, - ) - - return model_kwargs - @torch.no_grad() def generate( self, @@ -4505,11 +4468,6 @@ def assisted_decoding( else: num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens - # check if assistant model accepts encoder_outputs - assistant_accepts_encoder_outputs = "encoder_outputs" in set( - inspect.signature(assistant_model.forward).parameters.keys() - ) - # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() @@ -4547,20 +4505,32 @@ def assisted_decoding( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) + # prepare assistant model's keys of inputs + assistant_kwargs = copy.copy(model_kwargs) + if assistant_model.config.is_encoder_decoder: + # both are encoder-decoder + input_ids_key = "decoder_input_ids" + attention_key = "decoder_attention_mask" + assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs") + elif "assistant_encoder_outputs" in assistant_kwargs: + # special case for encoder-decoder with decoder-only assistant (like DistilWhisper) + input_ids_key = "input_ids" + attention_key = "attention_mask" + assistant_kwargs["attention_mask"] = assistant_kwargs.get( + "decoder_attention_mask", + torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long), + ) + assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs") + else: + # both are decoder-only + input_ids_key = "input_ids" + attention_key = "attention_mask" + # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) # other auxiliary variables max_len = stopping_criteria[0].max_length - assistant_kv_indexing = ( - 1 - if "bloom" in assistant_model.__class__.__name__.lower() - or ( - assistant_model.config.architectures is not None - and "bloom" in assistant_model.config.architectures[0].lower() - ) - else 0 - ) this_peer_finished = False # used by synced_gpus only while True: @@ -4582,44 +4552,21 @@ def assisted_decoding( # need access to the assistant cache to secure strong speedups. candidate_input_ids = input_ids for _ in range(int(num_assistant_tokens)): - # 1.1. use the assistant model to obtain the next candidate logits - if "assistant_past_key_values" in model_kwargs: - prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2] - # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model) - new_token_len = candidate_input_ids.shape[1] - prev_seq_len - assist_inputs = candidate_input_ids[:, -new_token_len:] - # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2 - if assistant_model.config.is_encoder_decoder: - assistant_model_outputs = assistant_model( - decoder_input_ids=assist_inputs, - past_key_values=model_kwargs["assistant_past_key_values"], - encoder_outputs=model_kwargs["assistant_encoder_outputs"], - ) - else: - encoder_kwargs = {} - - if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs: - encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] - - assistant_model_outputs = assistant_model( - assist_inputs, past_key_values=model_kwargs["assistant_past_key_values"], **encoder_kwargs - ) - else: - if assistant_model.config.is_encoder_decoder: - assistant_model_outputs = assistant_model( - decoder_input_ids=candidate_input_ids, - encoder_outputs=model_kwargs["assistant_encoder_outputs"], - ) - else: - encoder_kwargs = {} + # 1.1 prepare assistant model inputs + assistant_inputs = assistant_model.prepare_inputs_for_generation( + candidate_input_ids, + **assistant_kwargs, + ) - if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs: - encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] + # 1.2. check if the input ids length is correct + has_past_key_values = assistant_inputs.get("past_key_values", None) is not None + if has_past_key_values and assistant_inputs[input_ids_key].shape[-1] not in (1, 2): + raise ValueError("The length of the input ids in assistant inputs should be 1 or 2") - assistant_model_outputs = assistant_model(candidate_input_ids, **encoder_kwargs) + # 1.3. use the assistant model to obtain the next candidate logits + assistant_model_outputs = assistant_model(**assistant_inputs) - # 1.2. greedily select the next candidate token - model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values + # 1.4. greedily select the next candidate token if len(logits_processor) > 0: assistant_model_outputs.logits[:, -1, :] = logits_processor( candidate_input_ids, assistant_model_outputs.logits[:, -1, :] @@ -4627,7 +4574,13 @@ def assisted_decoding( new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) - # 1.3. stop assistant generation on EOS + # 1.5. update assistant model inputs + if assistant_kwargs.get(attention_key, None) is not None: + mask = assistant_kwargs[attention_key] + assistant_kwargs[attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1) + assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values + + # 1.6. stop assistant generation on EOS if eos_token_id_tensor is not None: last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1) last_assistant_token_is_eos = ( @@ -4646,8 +4599,10 @@ def assisted_decoding( # 2.1. Prepare the model inputs candidate_kwargs = copy.copy(model_kwargs) - candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1]) - candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) + candidate_kwargs = _prepare_attention_mask( + candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder + ) + candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) @@ -4699,8 +4654,8 @@ def assisted_decoding( # 5.3. Discard past key values relative to unused assistant tokens new_cache_size = new_cur_len - 1 outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) - model_kwargs["assistant_past_key_values"] = _crop_past_key_values( - assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1 + assistant_kwargs["past_key_values"] = _crop_past_key_values( + assistant_model, assistant_kwargs["past_key_values"], new_cache_size - 1 ) # the assistant does not have the token after the last match, hence the -1 # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, @@ -4761,6 +4716,12 @@ def assisted_decoding( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) + # Update assistant_kwargs for the assistant's next round of generations + assistant_kwargs = _prepare_attention_mask( + assistant_kwargs, new_cur_len, assistant_model.config.is_encoder_decoder + ) + assistant_kwargs = _prepare_token_type_ids(assistant_kwargs, new_cur_len) + # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( @@ -4938,3 +4899,37 @@ def _ranking_fast( contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] _, selected_idx = contrastive_score.max(dim=-1) # [B] return selected_idx + + +def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]: + """Expands or crops the model's mask for decoding purposes, to the defined length""" + + mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask" + if mask_key not in model_kwargs: + return model_kwargs + + mask = model_kwargs[mask_key] + mask_length_diff = new_length - mask.shape[1] + + if mask_length_diff < 0: + model_kwargs[mask_key] = mask[:, :mask_length_diff] + elif mask_length_diff > 0: + model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1) + return model_kwargs + + +def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]: + """Expands or crops the model's token_type_ids for decoding purposes, to the defined length""" + if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None: + return model_kwargs + + token_type_ids = model_kwargs["token_type_ids"] + final_token_type = token_type_ids[:, -1].unsqueeze(-1) + type_length_diff = new_length - token_type_ids.shape[1] + + if type_length_diff < 0: + token_type_ids = token_type_ids[:, :type_length_diff] + elif type_length_diff > 0: + token_type_copies = final_token_type.repeat(1, type_length_diff) + model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1) + return model_kwargs diff --git a/tests/models/nllb_moe/test_modeling_nllb_moe.py b/tests/models/nllb_moe/test_modeling_nllb_moe.py index 1109948e0e7092..2e8ba30ce675a6 100644 --- a/tests/models/nllb_moe/test_modeling_nllb_moe.py +++ b/tests/models/nllb_moe/test_modeling_nllb_moe.py @@ -348,10 +348,6 @@ def test_get_loss(self): self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1]) self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0]) - @unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it") - def test_assisted_decoding_sample(self): - pass - @require_torch @require_sentencepiece diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 5458b566667903..aa226f82ae3606 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -726,10 +726,6 @@ def test_generate_with_head_masking(self): def test_disk_offload(self): pass - @unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it") - def test_assisted_decoding_sample(self): - pass - class SwitchTransformersEncoderOnlyModelTester: def __init__( diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index fe098304735931..68b9f45e155b53 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -1036,10 +1036,6 @@ def test_model_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) - @unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it") - def test_assisted_decoding_sample(self): - pass - def use_task_specific_params(model, task): model.config.update(model.config.task_specific_params[task])