From 5c7f159670ca0353371e1b78f237549977b8e4e4 Mon Sep 17 00:00:00 2001 From: Taylor Jackle Spriggs Date: Fri, 3 May 2024 18:01:51 +0000 Subject: [PATCH] Fix decoder-only generation with embeddings for gpt2 Co-authored-by: Taylor Jackle Spriggs Co-authored-by: Soila Kavulya --- .../habana/transformers/generation/utils.py | 137 ++++++++++++++---- .../models/blip/modeling_blip_text.py | 5 +- .../models/bloom/modeling_bloom.py | 5 +- .../models/codegen/modeling_codegen.py | 6 +- .../models/decilm/modeling_decilm.py | 5 +- .../models/falcon/modeling_falcon.py | 5 +- .../models/gemma/modeling_gemma.py | 5 +- .../transformers/models/gpt2/modeling_gpt2.py | 10 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 5 +- .../models/gpt_neox/modeling_gpt_neox.py | 5 +- .../transformers/models/gptj/modeling_gptj.py | 5 +- .../models/llama/modeling_llama.py | 5 +- .../models/mamba/modeling_mamba.py | 5 +- .../models/mistral/modeling_mistral.py | 5 +- .../models/mixtral/modeling_mixtral.py | 5 +- .../transformers/models/mpt/modeling_mpt.py | 6 +- .../transformers/models/opt/modeling_opt.py | 5 +- .../models/persimmon/modeling_persimmon.py | 5 +- .../transformers/models/phi/modeling_phi.py | 5 +- .../models/qwen2/modeling_qwen2.py | 5 +- .../models/stablelm/modeling_stablelm.py | 5 +- .../models/starcoder2/modeling_starcoder2.py | 5 +- .../transformers/models/t5/modeling_t5.py | 5 +- .../tests/generation/test_utils.py | 2 +- 24 files changed, 206 insertions(+), 50 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index d333986679..0016d87138 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -140,7 +140,10 @@ def get_final_stopping_criteria(x): if isinstance(x, bool): return x elif torch.is_tensor(x): - return all(x) + if x.dim() > 0: + return all(x) + else: + return x else: raise TypeError(f"The stopping criteria should be either a boolean or a torch.tensor but got {type(x)}.") @@ -297,6 +300,7 @@ def _expand_dict_for_generation(dict_to_expand): key != "token_idx" and key != "decoder_input_ids" and key != "cache_position" + and key != "inputs_embeds_offset" and dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor) ): @@ -914,14 +918,44 @@ def generate( # only pad if bucket_size < -1. 
If we are bucketing (bucket_size > 0), then that is taken care in greedy_search() if not is_greedy_or_beam_and_bucket: # token_idx is the current index in the generation process, it is incremented each time a new token is generated - token_idx = inputs_tensor.shape[-1] - model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device) - model_kwargs["token_idx_cpu"] = token_idx + token_idx = inputs_tensor.shape[1] if generation_config.max_new_tokens is None: generation_config.max_new_tokens = generation_config.max_length - token_idx - inputs_tensor = torch.nn.functional.pad( - inputs_tensor, (0, generation_config.max_new_tokens), value=generation_config.pad_token_id - ) + if "inputs_embeds" in model_kwargs: + if "input_ids" in model_kwargs: + inputs_embeds_offset = ( + model_kwargs["input_ids"].shape[1] - model_kwargs["inputs_embeds"].shape[1] + ) + else: + inputs_embeds_offset = -model_kwargs["inputs_embeds"].shape[1] + + model_kwargs["inputs_embeds_offset"] = torch.tensor( + inputs_embeds_offset, device=inputs_tensor.device + ) + model_kwargs["inputs_embeds"] = torch.nn.functional.pad( + model_kwargs["inputs_embeds"], + (0, 0, 0, generation_config.max_new_tokens), + value=generation_config.pad_token_id, + ) + + if model_input_name == "inputs_embeds": + inputs_tensor = torch.nn.functional.pad( + inputs_tensor, + (0, 0, 0, generation_config.max_new_tokens), + value=generation_config.pad_token_id, + ) + model_kwargs["input_ids"] = torch.nn.functional.pad( + model_kwargs["input_ids"], + (0, generation_config.max_new_tokens), + value=generation_config.pad_token_id, + ) + else: + inputs_tensor = torch.nn.functional.pad( + inputs_tensor, (0, generation_config.max_new_tokens), value=generation_config.pad_token_id + ) + model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device) + model_kwargs["token_idx_cpu"] = token_idx + for other_inputs in ["attention_mask", "token_type_ids"]: if model_kwargs.get(other_inputs) is not None: model_kwargs[other_inputs] = torch.nn.functional.pad( @@ -1612,8 +1646,6 @@ def _contrastive_search( # keep track of which sequences are already finished batch_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] if not ignore_eos: unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) @@ -1640,7 +1672,10 @@ def _contrastive_search( top_k_ids = None if token_idx is not None: # Update cur_len in case of static shapes - cur_len = token_idx.item() + if "inputs_embeds_offset" in model_kwargs: + cur_len = (token_idx + model_kwargs["inputs_embeds_offset"]).item() + else: + cur_len = token_idx.item() time_to_first_token_done = False model_kwargs["pad_done"] = False @@ -1756,8 +1791,13 @@ def _contrastive_search( pad_amount = input_ids.shape[-1] - top_k_ids.shape[-1] top_k_ids = torch.nn.functional.pad(top_k_ids, (0, pad_amount), value=pad_token_id) + idx = ( + token_idx + model_kwargs["inputs_embeds_offset"] - 1 + if "inputs_embeds_offset" in model_kwargs + else token_idx - 1 + ) top_k_probs, top_k_prob_ids = torch.topk(next_probs, dim=-1, k=top_k) - top_k_ids[:, :, token_idx - 1] = top_k_prob_ids + top_k_ids[:, :, idx] = top_k_prob_ids else: top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k) @@ -1883,8 +1923,14 @@ def _contrastive_search( # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores # (model confidence minus degeneration penalty); (6) decoder hidden_states top_k_indices = 
torch.arange(len(top_k_ids), device=input_ids.device) + if token_idx is not None: - next_tokens = top_k_ids[top_k_indices, selected_idx, token_idx - 1] + idx = ( + token_idx + model_kwargs["inputs_embeds_offset"] - 1 + if "inputs_embeds_offset" in model_kwargs + else token_idx - 1 + ) + next_tokens = top_k_ids[top_k_indices, selected_idx, idx] else: next_tokens = top_k_ids[top_k_indices, selected_idx] next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k)) @@ -1976,9 +2022,12 @@ def _contrastive_search( # update generated ids, model inputs, and length for next step if token_idx is not None: # Use token_idx-1 since token index is incremented twice in first iteration - input_ids.index_copy_( - 1, token_idx - 1, next_tokens.unsqueeze(-1) if next_tokens.dim() == 1 else next_tokens + idx = ( + token_idx + model_kwargs["inputs_embeds_offset"] - 1 + if "inputs_embeds_offset" in model_kwargs + else token_idx - 1 ) + input_ids.index_copy_(1, idx, next_tokens.unsqueeze(-1) if next_tokens.dim() == 1 else next_tokens) else: input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) @@ -2217,7 +2266,10 @@ def _sample( token_idx = model_kwargs.get("token_idx", None) if token_idx is not None: # Update cur_len in case of static shapes - cur_len = token_idx.item() + if "inputs_embeds_offset" in model_kwargs: + cur_len = (token_idx + model_kwargs["inputs_embeds_offset"]).item() + else: + cur_len = token_idx.item() time_to_first_token_done = False model_kwargs["pad_done"] = False @@ -2319,9 +2371,12 @@ def _sample( next_tokens = next_tokens.to(input_ids.dtype) if token_idx is not None: - input_ids.index_copy_( - 1, token_idx, next_tokens.unsqueeze(-1) if next_tokens.dim() == 1 else next_tokens + idx = ( + token_idx + model_kwargs["inputs_embeds_offset"] + if "inputs_embeds_offset" in model_kwargs + else token_idx ) + input_ids.index_copy_(1, idx, next_tokens.unsqueeze(-1) if next_tokens.dim() == 1 else next_tokens) else: input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) @@ -2514,7 +2569,11 @@ def _beam_search( token_idx = model_kwargs.get("token_idx", None) if token_idx is not None: # Update cur_len in case of static shapes - cur_len = token_idx.item() + if "inputs_embeds_offset" in model_kwargs: + cur_len = (token_idx + model_kwargs["inputs_embeds_offset"]).item() + else: + cur_len = token_idx.item() + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) if num_beams * batch_size != batch_beam_size: @@ -2729,7 +2788,12 @@ def expand_if_needed(tensor, new_size, value, dim=-1): ) # (batch_size * num_beams, vocab_size) if token_idx is not None: - next_token_scores_processed = logits_processor(input_ids[:, :token_idx], next_token_scores) + idx = ( + token_idx + model_kwargs["inputs_embeds_offset"] + if "inputs_embeds_offset" in model_kwargs + else token_idx + ) + next_token_scores_processed = logits_processor(input_ids[:, :idx], next_token_scores) else: next_token_scores_processed = logits_processor(input_ids, next_token_scores) if do_sample: @@ -2836,8 +2900,13 @@ def expand_if_needed(tensor, new_size, value, dim=-1): if token_idx is not None: input_ids = torch.index_select(input_ids, 0, beam_idx) + idx = ( + token_idx + model_kwargs["inputs_embeds_offset"] + if "inputs_embeds_offset" in model_kwargs + else token_idx + ) input_ids.index_copy_( - 1, token_idx, beam_next_tokens.unsqueeze(-1) if beam_next_tokens.dim() == 1 else beam_next_tokens + 1, idx, beam_next_tokens.unsqueeze(-1) if beam_next_tokens.dim() == 1 else beam_next_tokens ) else: 
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) @@ -3091,12 +3160,15 @@ def _constrained_beam_search( num_beams = constrained_beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] + token_idx = model_kwargs.get("token_idx", None) if token_idx is not None: # Update cur_len in case of static shapes - cur_len = token_idx.item() + if "inputs_embeds_offset" in model_kwargs: + cur_len = (token_idx + model_kwargs["inputs_embeds_offset"]).item() + else: + cur_len = token_idx.item() + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) if num_beams * batch_size != batch_beam_size: @@ -3129,7 +3201,11 @@ def _constrained_beam_search( this_peer_finished = False - decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder + if token_idx is not None: + # Update cur_len in case of static shapes + decoder_prompt_len = cur_len + else: + decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder hb_profer = HabanaProfile( warmup=profiling_warmup_steps, active=profiling_steps, record_shapes=profiling_record_shapes @@ -3234,8 +3310,14 @@ def _constrained_beam_search( if token_idx is not None: input_ids = input_ids[beam_idx, :] + idx = ( + token_idx + model_kwargs["inputs_embeds_offset"] + if "inputs_embeds_offset" in model_kwargs + else token_idx + ) + input_ids.index_copy_( - 1, token_idx, beam_next_tokens.unsqueeze(-1) if beam_next_tokens.dim() == 1 else beam_next_tokens + 1, idx, beam_next_tokens.unsqueeze(-1) if beam_next_tokens.dim() == 1 else beam_next_tokens ) else: input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) @@ -3437,7 +3519,10 @@ def _assisted_decoding( if token_idx is not None: # Update cur_len in case of static shapes - cur_len = token_idx.item() + if "inputs_embeds_offset" in model_kwargs: + cur_len = (token_idx + model_kwargs["inputs_embeds_offset"]).item() + else: + cur_len = token_idx.item() else: cur_len = input_ids.shape[-1] diff --git a/optimum/habana/transformers/models/blip/modeling_blip_text.py b/optimum/habana/transformers/models/blip/modeling_blip_text.py index 23d4ee3f3c..3b603d2efd 100644 --- a/optimum/habana/transformers/models/blip/modeling_blip_text.py +++ b/optimum/habana/transformers/models/blip/modeling_blip_text.py @@ -505,7 +505,10 @@ def gaudi_BlipTextLMHead_prepare_inputs_for_generation( # cut decoder_input_ids if past_key_values is used if past_key_values is not None: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in model_kwargs: + idx = idx + model_kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: past_length = past_key_values[0][0].shape[2] diff --git a/optimum/habana/transformers/models/bloom/modeling_bloom.py b/optimum/habana/transformers/models/bloom/modeling_bloom.py index df99463c15..b9f9f035f5 100644 --- a/optimum/habana/transformers/models/bloom/modeling_bloom.py +++ b/optimum/habana/transformers/models/bloom/modeling_bloom.py @@ -489,7 +489,10 @@ def prepare_inputs_for_generation( if token_idx is None: input_ids = input_ids[:, -1].unsqueeze(-1) else: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) # the cache may be in 
the stardard format (e.g. in contrastive search), convert to bloom's format if needed if past_key_values[0][0].shape[0] == input_ids.shape[0]: diff --git a/optimum/habana/transformers/models/codegen/modeling_codegen.py b/optimum/habana/transformers/models/codegen/modeling_codegen.py index 536cb5d423..0f9d0256bc 100644 --- a/optimum/habana/transformers/models/codegen/modeling_codegen.py +++ b/optimum/habana/transformers/models/codegen/modeling_codegen.py @@ -320,7 +320,11 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) + if token_type_ids is not None: token_type_ids = torch.index_select(token_type_ids, 1, token_idx - 1) else: diff --git a/optimum/habana/transformers/models/decilm/modeling_decilm.py b/optimum/habana/transformers/models/decilm/modeling_decilm.py index 562033f2cb..e11a9732dc 100644 --- a/optimum/habana/transformers/models/decilm/modeling_decilm.py +++ b/optimum/habana/transformers/models/decilm/modeling_decilm.py @@ -373,7 +373,10 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: if isinstance(past_key_values, Cache): past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index a7a0c0e920..57b16335ba 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -983,7 +983,10 @@ def prepare_inputs_for_generation( bucket_internal = kwargs.get("bucket_internal") if past_key_values is not None: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: past_length = past_key_values[0][0].shape[2] diff --git a/optimum/habana/transformers/models/gemma/modeling_gemma.py b/optimum/habana/transformers/models/gemma/modeling_gemma.py index 6c537dfa31..73f34387e2 100644 --- a/optimum/habana/transformers/models/gemma/modeling_gemma.py +++ b/optimum/habana/transformers/models/gemma/modeling_gemma.py @@ -419,7 +419,10 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, cache_position] else: # past_length += token_idx - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation diff --git a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py index b2ec2c0229..afd2689633 100644 --- a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py +++ b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py @@ -476,7 +476,10 @@ def 
prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: past_length = past_key_values[0][0].shape[2] @@ -609,7 +612,10 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: past_length = past_key_values[0][0].shape[2] diff --git a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 7d2a065593..692b8b618f 100644 --- a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -558,7 +558,10 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) if token_type_ids is not None: token_type_ids = torch.index_select(token_type_ids, 1, token_idx - 1) else: diff --git a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py index aa6423d2b1..b3c58fbec9 100644 --- a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -379,7 +379,10 @@ def prepare_inputs_for_generation( # cut decoder_input_ids if past is used if past_key_values is not None: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: past_length = past_key_values[0][0].shape[2] diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py index 4793766f6e..334c828153 100644 --- a/optimum/habana/transformers/models/gptj/modeling_gptj.py +++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py @@ -563,7 +563,10 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: past_length = past_key_values[0][0].shape[2] diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 1abbfab12d..014f0e3fae 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1346,7 +1346,10 @@ def prepare_inputs_for_generation( bucket_internal = kwargs.get("bucket_internal") if past_key_values is not None: if token_idx is not None: - input_ids = 
torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] diff --git a/optimum/habana/transformers/models/mamba/modeling_mamba.py b/optimum/habana/transformers/models/mamba/modeling_mamba.py index ea7c112c7d..c577ff1a41 100644 --- a/optimum/habana/transformers/models/mamba/modeling_mamba.py +++ b/optimum/habana/transformers/models/mamba/modeling_mamba.py @@ -61,7 +61,10 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation( # the length of `cache_params.conv_states`, which is `config.conv_kernel` cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) else: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, torch.arange(token_idx_cpu, device=input_ids.device)) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 7d95e548ce..4be3a882a2 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -810,7 +810,10 @@ def prepare_inputs_for_generation( ): # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] else: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) elif reuse_cache and token_idx is not None: # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass input_ids = input_ids[:, :token_idx] diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 43dfc7e48a..b76d3e4e2d 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -841,7 +841,10 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values is not None: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] diff --git a/optimum/habana/transformers/models/mpt/modeling_mpt.py b/optimum/habana/transformers/models/mpt/modeling_mpt.py index 7cefc4e37f..96535671da 100755 --- a/optimum/habana/transformers/models/mpt/modeling_mpt.py +++ b/optimum/habana/transformers/models/mpt/modeling_mpt.py @@ -354,7 +354,11 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] else: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) + # Converting back to tuples as it should be, 
so there's no type mismatch when calling graph past_key_values = tuple([tuple(kv) for kv in past_key_values]) elif bucket_internal and token_idx is not None: diff --git a/optimum/habana/transformers/models/opt/modeling_opt.py b/optimum/habana/transformers/models/opt/modeling_opt.py index 9f113453e9..bab0b210fb 100644 --- a/optimum/habana/transformers/models/opt/modeling_opt.py +++ b/optimum/habana/transformers/models/opt/modeling_opt.py @@ -507,7 +507,10 @@ def prepare_inputs_for_generation( ): if past_key_values is not None: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: past_length = past_key_values[0][0].shape[2] diff --git a/optimum/habana/transformers/models/persimmon/modeling_persimmon.py b/optimum/habana/transformers/models/persimmon/modeling_persimmon.py index 4c7b24b988..41df7ee507 100644 --- a/optimum/habana/transformers/models/persimmon/modeling_persimmon.py +++ b/optimum/habana/transformers/models/persimmon/modeling_persimmon.py @@ -425,7 +425,10 @@ def prepare_inputs_for_generation( ): # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] else: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation diff --git a/optimum/habana/transformers/models/phi/modeling_phi.py b/optimum/habana/transformers/models/phi/modeling_phi.py index 1e21735add..977421d811 100644 --- a/optimum/habana/transformers/models/phi/modeling_phi.py +++ b/optimum/habana/transformers/models/phi/modeling_phi.py @@ -627,7 +627,10 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values is not None: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index 0c8970dd88..47c4de9f1f 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -859,7 +859,10 @@ def prepare_inputs_for_generation( bucket_internal = kwargs.get("bucket_internal") if past_key_values is not None: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] diff --git a/optimum/habana/transformers/models/stablelm/modeling_stablelm.py b/optimum/habana/transformers/models/stablelm/modeling_stablelm.py index 08becc263a..0d5455e2fe 100644 --- a/optimum/habana/transformers/models/stablelm/modeling_stablelm.py +++ b/optimum/habana/transformers/models/stablelm/modeling_stablelm.py @@ -454,7 +454,10 @@ def 
prepare_inputs_for_generation( ): # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] else: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation diff --git a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py index 36d5379e4f..3d11529d5c 100644 --- a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py +++ b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py @@ -827,7 +827,10 @@ def prepare_inputs_for_generation( reuse_cache = kwargs.get("reuse_cache") if past_key_values is not None: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] diff --git a/optimum/habana/transformers/models/t5/modeling_t5.py b/optimum/habana/transformers/models/t5/modeling_t5.py index 17b0e49a97..bae355d989 100644 --- a/optimum/habana/transformers/models/t5/modeling_t5.py +++ b/optimum/habana/transformers/models/t5/modeling_t5.py @@ -609,7 +609,10 @@ def gaudi_T5ForConditionalGeneration_prepare_inputs_for_generation( # cut decoder_input_ids if past_key_values is used if past_key_values is not None: if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + idx = token_idx - 1 + if "inputs_embeds_offset" in kwargs: + idx = idx + kwargs["inputs_embeds_offset"] + input_ids = torch.index_select(input_ids, 1, idx) else: past_length = past_key_values[0][0].shape[2] diff --git a/tests/transformers/tests/generation/test_utils.py b/tests/transformers/tests/generation/test_utils.py index b052809227..512935e9dd 100644 --- a/tests/transformers/tests/generation/test_utils.py +++ b/tests/transformers/tests/generation/test_utils.py @@ -2185,7 +2185,7 @@ def test_generate_from_inputs_embeds_decoder_only(self): ) self.assertListEqual( outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(), - outputs_from_embeds_wo_ids[:, 1:].tolist(), + outputs_from_embeds_wo_ids.tolist(), ) def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
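A minimal usage sketch of the behavior exercised by the updated test_generate_from_inputs_embeds_decoder_only: generating from inputs_embeds with a decoder-only model. It assumes a stock transformers install and the public "gpt2" checkpoint (both assumptions for illustration, not part of this patch); on Gaudi, the adapted generate() above additionally pads inputs_embeds to a static shape and records inputs_embeds_offset so that token_idx-based indexing stays correct.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Assumption: the public "gpt2" checkpoint; any decoder-only model follows the same path.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()

    prompt = "The Habana Gaudi accelerator"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Look up the prompt embeddings explicitly and generate from them instead of input_ids.
    inputs_embeds = model.get_input_embeddings()(input_ids)

    with torch.no_grad():
        # When only inputs_embeds is passed to a decoder-only model, generate() returns
        # just the newly generated token ids, since there are no prompt ids to prepend.
        new_tokens = model.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=16,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))

This return convention is why the modified test compares outputs_from_embeds_wo_ids without slicing off a leading token: the embeddings-only path yields new tokens only, and the patched indexing keeps it aligned with the input_ids-plus-embeddings path.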