Skip to content

Commit

Permalink
Generate: fix candidate device placement (#28493)
Browse files (browse the repository at this point in the history)
* fix candidate device

* this line shouldn't have been in
[page residue: "Loading branch information" — not part of the commit message]
gante authored Jan 13, 2024
1 parent e304f97 commit bc72b4e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated to each candidate.
"""
input_ids = input_ids.to(self.assistant_model.device)

# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
Expand Down
7 changes: 3 additions & 4 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4591,11 +4591,10 @@ def assisted_decoding(
cur_len = input_ids.shape[-1]

# 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(
input_ids.to(candidate_generator.assistant_model.device)
)
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
candidate_input_ids = candidate_input_ids.to(self.device)
candidate_logits = candidate_logits.to(self.device)
if candidate_logits is not None:
candidate_logits = candidate_logits.to(self.device)

candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
last_assistant_token_is_eos = (
Expand Down

0 comments on commit bc72b4e

Please sign in to comment.