From 8223ec4df0b440be3c2f3dac41be7209aaa3b9aa Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 12 Jul 2024 10:52:19 +0100 Subject: [PATCH] Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 77e67e7b2b0b17..43988780a3739a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -692,13 +692,11 @@ def _update_model_kwargs_for_generation( if model_kwargs.get("use_cache", True): model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens else: + previous_positions = model_kwargs.pop("cache_position") new_positions = torch.arange( - model_kwargs["cache_position"][-1] + 1, - model_kwargs["cache_position"][-1] + num_new_tokens + 1, - device=model_kwargs["cache_position"].device, - dtype=model_kwargs["cache_position"].dtype, + previous_positions[-1] + 1, previous_positions[-1] + num_new_tokens + 1, device=previous_positions.device, dtype=previous_positions.dtype, ) - model_kwargs["cache_position"] = torch.cat((model_kwargs["cache_position"], new_positions)) + model_kwargs["cache_position"] = torch.cat((previous_positions, new_positions)) return model_kwargs def _reorder_cache(self, past_key_values, beam_idx):