Skip to content

Commit

Permalink
[BUG] BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor…
Browse files Browse the repository at this point in the history
… size mismatch (huggingface#28201)

fix(generation/logits_process.py): BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch

Co-authored-by: chenhanhui <[email protected]>
  • Loading branch information
2 people authored and wgifford committed Jan 21, 2024
1 parent 35b5caf commit b61f566
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -2138,6 +2138,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]

do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
do_early_stop = torch.any(do_early_stop, dim=1, keepdim=True)
scores = torch.where(do_early_stop, early_stop_scores, scores)

return scores
16 changes: 16 additions & 0 deletions tests/generation/test_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,3 +824,19 @@ def test_early_stop_processor(self):
[float("-inf"), float("-inf"), scores[0][0], float("-inf")],
]
self.assertListEqual(actual_scores.tolist(), expected_scores_list)

def test_early_stop_processor_multi_eos(self):
input_ids = None
eos_token_id = [2, 3]
min_eos_p = 0.1 ## some small float

scores = self._get_uniform_logits(2, 4)
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)

esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
actual_scores = esp(input_ids, scores)
expected_scores_list = [
scores[0].tolist(),
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
]
self.assertListEqual(actual_scores.tolist(), expected_scores_list)

0 comments on commit b61f566

Please sign in to comment.