diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 43988780a3739a..5c05328d0f2d3f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -692,11 +692,11 @@ def _update_model_kwargs_for_generation( if model_kwargs.get("use_cache", True): model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens else: - previous_positions = model_kwargs.pop("cache_position") + past_positions = model_kwargs.pop("cache_position") new_positions = torch.arange( - previous_positions[-1] + 1, previous_positions[-1] + num_new_tokens + 1, device=previous_positions.device, dtype=previous_positions.dtype, - ) - model_kwargs["cache_position"] = torch.cat((previous_positions, new_positions)) + past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype + ).to(past_positions.device) + model_kwargs["cache_position"] = torch.cat((past_positions, new_positions)) return model_kwargs def _reorder_cache(self, past_key_values, beam_idx):