Skip to content

Commit

Permalink
fix cacheless case
Browse files Browse the repository at this point in the history
  • Loading branch information
gante committed Jul 12, 2024
1 parent 50c8260 commit d0ff984
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,8 @@ def _update_model_kwargs_for_generation(
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
else:
new_positions = torch.arange(
model_kwargs["cache_position"][-1],
model_kwargs["cache_position"][-1] + num_new_tokens,
model_kwargs["cache_position"][-1] + 1,
model_kwargs["cache_position"][-1] + num_new_tokens + 1,
device=model_kwargs["cache_position"].device,
dtype=model_kwargs["cache_position"].dtype,
)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/jamba/modeling_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,7 +1553,7 @@ def prepare_inputs_for_generation(
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if not empty_past_kv:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
Expand All @@ -1571,7 +1571,7 @@ def prepare_inputs_for_generation(
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and empty_past_kv:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
Expand Down

0 comments on commit d0ff984

Please sign in to comment.