diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index ad5289f120ea19..e28615bb3c2621 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -169,6 +169,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length, vocabulary_size)` containing the logits associated to each candidate. """ + input_ids = input_ids.to(self.assistant_model.device) + # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length # (which implicitly contains the number of accepted candidates from the previous round) has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ef9e19c8b11057..f36f76a27a390a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4591,11 +4591,10 @@ def assisted_decoding( cur_len = input_ids.shape[-1] # 1. Fetch candidate sequences from a `CandidateGenerator` - candidate_input_ids, candidate_logits = candidate_generator.get_candidates( - input_ids.to(candidate_generator.assistant_model.device) - ) + candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) candidate_input_ids = candidate_input_ids.to(self.device) - candidate_logits = candidate_logits.to(self.device) + if candidate_logits is not None: + candidate_logits = candidate_logits.to(self.device) candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] last_assistant_token_is_eos = (