fix assisted decoding (#31401)
* fix assisted decoding

* check None

* fix typo

* fix _prepare_special_tokens

* fix style

* fix lint

* add tests for assisted decoding

* fix style

* fix tests check
jiqing-feng authored Jul 3, 2024
1 parent f91c16d commit 7f91f16
Showing 2 changed files with 54 additions and 1 deletion.
src/transformers/generation/utils.py: 4 additions & 1 deletion
@@ -1493,8 +1493,11 @@ def _tensor_or_none(token_kwargs, token_self, device=None):
                 device = self.device
 
             token = token_kwargs if token_kwargs is not None else token_self
-            if token is None or isinstance(token, torch.Tensor):
+            if token is None:
                 return token
+            elif isinstance(token, torch.Tensor):
+                return token.to(device)
+
             return torch.tensor(token, device=device, dtype=torch.long)
 
         bos_token_id = _tensor_or_none(
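Why the change matters: the old branch returned any existing torch.Tensor unchanged, so a special-token tensor supplied by the caller could remain on a different device from the one generation runs on; the new elif moves it with token.to(device). Below is a minimal standalone sketch of the fixed behavior (a hypothetical simplification: the nested helper's self.device fallback is replaced by an explicit device argument):

import torch

def tensor_or_none(token, device):
    # Simplified illustration of _tensor_or_none after the fix.
    if token is None:
        return token
    elif isinstance(token, torch.Tensor):
        # The fix: an existing tensor is now moved to the target device
        # instead of being returned on whatever device it happened to be on.
        return token.to(device)
    # Plain ints or lists are wrapped into a long tensor on the target device.
    return torch.tensor(token, device=device, dtype=torch.long)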
tests/generation/test_utils.py: 50 additions & 0 deletions
@@ -30,7 +30,9 @@
     require_auto_gptq,
     require_quanto,
     require_torch,
+    require_torch_gpu,
     require_torch_multi_accelerator,
+    require_torch_multi_gpu,
     slow,
     torch_device,
 )
@@ -3097,6 +3099,54 @@ def test_return_unprocessed_logit_scores(self):
         self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
         self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
 
+    @slow
+    @require_torch_multi_gpu
+    def test_assisted_decoding_in_different_gpu(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0")
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
+            "cuda:1"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        model.config.pad_token_id = tokenizer.eos_token_id
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+        input_length = input_ids.shape[-1]
+
+        out = model.generate(
+            input_ids,
+            assistant_model=assistant,
+            max_new_tokens=20,
+        )
+        self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
+
+    @slow
+    @require_torch_gpu
+    def test_assisted_decoding_in_gpu_cpu(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
+            "cpu"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        model.config.pad_token_id = tokenizer.eos_token_id
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+        input_length = input_ids.shape[-1]
+
+        out = model.generate(
+            input_ids,
+            assistant_model=assistant,
+            max_new_tokens=20,
+        )
+        self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
+
 
 @require_torch
 class TokenHealingTestCase(unittest.TestCase):
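For reference, a minimal usage sketch of assisted decoding with the main and assistant models on different devices, mirroring the new tests above (checkpoint name and pad-token setup taken from the tests; device strings assume one CUDA device is available):

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"
model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")
assistant = AutoModelForCausalLM.from_pretrained(checkpoint).to("cpu")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.config.pad_token_id = tokenizer.eos_token_id
assistant.config.pad_token_id = tokenizer.eos_token_id

input_ids = tokenizer(["Hello world"], return_tensors="pt").input_ids.to("cuda")
# With the fix, the special tokens prepared in _prepare_special_tokens end up
# on the main model's device even though the assistant lives on the CPU.
out = model.generate(input_ids, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True))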
