diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 25ec7be1b57e9a..f99ae64fb81363 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1493,8 +1493,11 @@ def _tensor_or_none(token_kwargs, token_self, device=None):
                 device = self.device
 
             token = token_kwargs if token_kwargs is not None else token_self
-            if token is None or isinstance(token, torch.Tensor):
+            if token is None:
                 return token
+            elif isinstance(token, torch.Tensor):
+                return token.to(device)
+
             return torch.tensor(token, device=device, dtype=torch.long)
 
         bos_token_id = _tensor_or_none(
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index b9e962a6a18c71..8fa41fbdbe2b07 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -30,7 +30,9 @@
     require_auto_gptq,
     require_quanto,
     require_torch,
+    require_torch_gpu,
     require_torch_multi_accelerator,
+    require_torch_multi_gpu,
     slow,
     torch_device,
 )
@@ -3097,6 +3099,54 @@ def test_return_unprocessed_logit_scores(self):
         self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
         self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
 
+    @slow
+    @require_torch_multi_gpu
+    def test_assisted_decoding_in_different_gpu(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0")
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
+            "cuda:1"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        model.config.pad_token_id = tokenizer.eos_token_id
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+        input_length = input_ids.shape[-1]
+
+        out = model.generate(
+            input_ids,
+            assistant_model=assistant,
+            max_new_tokens=20,
+        )
+        self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
+
+    @slow
+    @require_torch_gpu
+    def test_assisted_decoding_in_gpu_cpu(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
+            "cpu"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        model.config.pad_token_id = tokenizer.eos_token_id
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+        input_length = input_ids.shape[-1]
+
+        out = model.generate(
+            input_ids,
+            assistant_model=assistant,
+            max_new_tokens=20,
+        )
+        self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
+
 
 @require_torch
 class TokenHealingTestCase(unittest.TestCase):