From d0ff984e762020a982a8bd64ffff17da0c85a7db Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 11 Jul 2024 12:37:32 +0000
Subject: [PATCH] fix cacheless case

---
 src/transformers/generation/utils.py            | 4 ++--
 src/transformers/models/jamba/modeling_jamba.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index d47169e8d9b417..77e67e7b2b0b17 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -693,8 +693,8 @@ def _update_model_kwargs_for_generation(
             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
         else:
             new_positions = torch.arange(
-                model_kwargs["cache_position"][-1],
-                model_kwargs["cache_position"][-1] + num_new_tokens,
+                model_kwargs["cache_position"][-1] + 1,
+                model_kwargs["cache_position"][-1] + num_new_tokens + 1,
                 device=model_kwargs["cache_position"].device,
                 dtype=model_kwargs["cache_position"].dtype,
             )
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 5682c53aea19a3..768e8e01607588 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -1553,7 +1553,7 @@ def prepare_inputs_for_generation(
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        if past_key_values is not None:
+        if not empty_past_kv:
             if inputs_embeds is not None:  # Exception 1
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
@@ -1571,7 +1571,7 @@ def prepare_inputs_for_generation(
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
+        if inputs_embeds is not None and empty_past_kv:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases