From b61f566c5d4808aa95474142eea8fd04886e0853 Mon Sep 17 00:00:00 2001 From: HanHui Date: Wed, 10 Jan 2024 18:46:49 +0800 Subject: [PATCH] [BUG] BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch (#28201) fix(generation/logits_process.py): BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch Co-authored-by: chenhanhui --- src/transformers/generation/logits_process.py | 1 + tests/generation/test_logits_process.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index dea6f44c3ad3cb..2b1b9f5a50b6ef 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -2138,6 +2138,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id] do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p + do_early_stop = torch.any(do_early_stop, dim=1, keepdim=True) scores = torch.where(do_early_stop, early_stop_scores, scores) return scores diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index b1b3602c927dba..95150a9c33cd36 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -824,3 +824,19 @@ def test_early_stop_processor(self): [float("-inf"), float("-inf"), scores[0][0], float("-inf")], ] self.assertListEqual(actual_scores.tolist(), expected_scores_list) + + def test_early_stop_processor_multi_eos(self): + input_ids = None + eos_token_id = [2, 3] + min_eos_p = 0.1 ## some small float + + scores = self._get_uniform_logits(2, 4) + scores[0][eos_token_id] = -6 ## less than log(min_eos_p) + + esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) + actual_scores = esp(input_ids, scores) + expected_scores_list = [ + scores[0].tolist(), + [float("-inf"), float("-inf"), scores[0][0], scores[0][0]], + ] + self.assertListEqual(actual_scores.tolist(), expected_scores_list)