diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 77e67e7b2b0b17..43988780a3739a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -692,13 +692,11 @@ def _update_model_kwargs_for_generation( if model_kwargs.get("use_cache", True): model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens else: + previous_positions = model_kwargs.pop("cache_position") new_positions = torch.arange( - model_kwargs["cache_position"][-1] + 1, - model_kwargs["cache_position"][-1] + num_new_tokens + 1, - device=model_kwargs["cache_position"].device, - dtype=model_kwargs["cache_position"].dtype, + previous_positions[-1] + 1, previous_positions[-1] + num_new_tokens + 1, device=previous_positions.device, dtype=previous_positions.dtype, ) - model_kwargs["cache_position"] = torch.cat((model_kwargs["cache_position"], new_positions)) + model_kwargs["cache_position"] = torch.cat((previous_positions, new_positions)) return model_kwargs def _reorder_cache(self, past_key_values, beam_idx):