From 02f741735c025af4b64d45ce80df7073a24f0676 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 12 Jul 2024 10:09:06 +0000
Subject: [PATCH] make fixup

---
 src/transformers/generation/utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 43988780a3739a..5c05328d0f2d3f 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -692,11 +692,11 @@ def _update_model_kwargs_for_generation(
         if model_kwargs.get("use_cache", True):
             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
         else:
-            previous_positions = model_kwargs.pop("cache_position")
+            past_positions = model_kwargs.pop("cache_position")
             new_positions = torch.arange(
-                previous_positions[-1] + 1, previous_positions[-1] + num_new_tokens + 1, device=previous_positions.device, dtype=previous_positions.dtype,
-            )
-            model_kwargs["cache_position"] = torch.cat((previous_positions, new_positions))
+                past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
+            ).to(past_positions.device)
+            model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
         return model_kwargs

     def _reorder_cache(self, past_key_values, beam_idx):
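
For context, here is a minimal, standalone sketch of the tensor logic the hunk above converges on: the new positions are built with `torch.arange` and then moved to the target device with `.to(...)` rather than allocated there via the `device=` argument. The helper name `extend_cache_position` and the example values are illustrative only and are not part of the transformers API.

```python
import torch


def extend_cache_position(cache_position: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
    """Append num_new_tokens consecutive positions to a 1-D cache_position tensor.

    Hypothetical helper, not a transformers function; only the tensor
    arithmetic mirrors the patched code.
    """
    past_positions = cache_position
    # Build the new positions on the default device, then move them to the
    # device of the existing positions, mirroring the patched `.to(...)` call.
    new_positions = torch.arange(
        past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
    ).to(past_positions.device)
    return torch.cat((past_positions, new_positions))


# Positions 0..4 are already cached and two new tokens were generated:
positions = torch.arange(5)
print(extend_cache_position(positions, 2))  # tensor([0, 1, 2, 3, 4, 5, 6])
```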