From c9f6e5e35156e068b227dd9b15521767f6afd4d2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 1 Apr 2024 12:45:25 +0100 Subject: [PATCH] Generate: move misplaced test (#29902) --- tests/generation/test_utils.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 99f6e84a3036e0..5c73e92a77a8a1 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1467,17 +1467,6 @@ def test_past_key_values_format(self): past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) ) - def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self): - article = "Today a dragon flew over Paris." - model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - inputs_embeds = model.get_input_embeddings()(input_ids) - - model.generate(inputs_embeds=inputs_embeds, max_length=20, bos_token_id=None) - with self.assertRaises(ValueError): - model.generate(max_length=20, bos_token_id=None) - def test_generate_from_inputs_embeds_decoder_only(self): # When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids` # if fails, you should probably update the `prepare_inputs_for_generation` function @@ -2817,3 +2806,16 @@ def test_return_unprocessed_logit_scores(self): self.assertTrue(y_prob > 0.001 and n_prob > 0.001) self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0) + + def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self): + article = "Today a dragon flew over Paris." + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + inputs_embeds = model.get_input_embeddings()(input_ids) + + model.generate(inputs_embeds=inputs_embeds, max_length=20, bos_token_id=None) + + # bos_token_id is required when no input ids nor inputs_embeds is passed + with self.assertRaises(ValueError): + model.generate(max_length=20, bos_token_id=None)