Skip to content

Commit

Permalink
Update src/transformers/generation/utils.py
Browse files Browse the repository at this point in the history
Co-authored-by: Arthur <[email protected]>
  • Loading branch information
gante and ArthurZucker committed Jul 12, 2024
1 parent d0ff984 commit 8223ec4
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,13 +692,11 @@ def _update_model_kwargs_for_generation(
if model_kwargs.get("use_cache", True):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
else:
previous_positions = model_kwargs.pop("cache_position")
new_positions = torch.arange(
model_kwargs["cache_position"][-1] + 1,
model_kwargs["cache_position"][-1] + num_new_tokens + 1,
device=model_kwargs["cache_position"].device,
dtype=model_kwargs["cache_position"].dtype,
previous_positions[-1] + 1, previous_positions[-1] + num_new_tokens + 1, device=previous_positions.device, dtype=previous_positions.dtype,
)
model_kwargs["cache_position"] = torch.cat((model_kwargs["cache_position"], new_positions))
model_kwargs["cache_position"] = torch.cat((previous_positions, new_positions))
return model_kwargs

def _reorder_cache(self, past_key_values, beam_idx):
Expand Down

0 comments on commit 8223ec4

Please sign in to comment.