From cd8ec99e97278b20664b3f58598dd74b9038f21c Mon Sep 17 00:00:00 2001
From: VsonicV
Date: Sun, 3 Dec 2023 04:18:54 +0000
Subject: [PATCH] Fix issues in add and is_done for BeamHypotheses

---
 src/transformers/generation/beam_search.py  | 46 ++++++++++++---------
 tests/generation/test_framework_agnostic.py |  2 +-
 2 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/src/transformers/generation/beam_search.py b/src/transformers/generation/beam_search.py
index a29d34306f830f..4648483eac7a06 100644
--- a/src/transformers/generation/beam_search.py
+++ b/src/transformers/generation/beam_search.py
@@ -279,15 +279,14 @@ def process(
                     else:
                         beam_index = None
 
-                    # skip the corner case where the very first generated token is eos_token
-                    if decoder_prompt_len == input_ids.shape[-1]:
-                        continue
+                    # add up to the length which the next_scores is calculated on
+                    generated_len = input_ids[batch_beam_idx].shape[-1] + 1 - decoder_prompt_len
 
                     self._beam_hyps[batch_group_idx].add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
                         beam_indices=beam_index,
-                        decoder_prompt_len=decoder_prompt_len,
+                        generated_len=generated_len,
                     )
                 else:
                     # add next predicted token since it is not eos_token
@@ -308,7 +307,7 @@ def process(
 
             # Check if we are done so that we can save a pad step if all(done)
             self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
-                next_scores[batch_idx].max().item(), cur_len
+                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
             )
 
         return UserDict(
@@ -348,7 +347,8 @@ def finalize(
                 final_score = final_beam_scores[batch_beam_idx].item()
                 final_tokens = input_ids[batch_beam_idx]
                 beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
-                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len)
+                generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
 
         # select the best hypotheses
         sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
@@ -617,16 +617,14 @@ def process(
                     else:
                         beam_index = None
 
-                    # skip the corner case where the only constraint token is
-                    # eos_token and the very first generated token is eos_token
-                    if decoder_prompt_len == input_ids.shape[-1]:
-                        continue
+                    # add up to the length which the next_scores is calculated on
+                    generated_len = input_ids[batch_beam_idx].shape[-1] + 1 - decoder_prompt_len
 
                     beam_hyp.add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
                         beam_indices=beam_index,
-                        decoder_prompt_len=decoder_prompt_len,
+                        generated_len=generated_len,
                     )
                 else:
                     # add next predicted token since it is not eos_token
@@ -660,7 +658,7 @@ def process(
 
             # Check if we are done so that we can save a pad step if all(done)
             self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
-                next_scores[batch_idx].max().item(), cur_len
+                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
             )
 
         return UserDict(
@@ -846,9 +844,8 @@ def finalize(
                 completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
                 if completes_constraint:
                     beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
-                    beam_hyp.add(
-                        final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len
-                    )
+                    generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                    beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
                     ids_collect.append(beam_id)
 
             # due to overly complex constraints or other factors, sometimes we can't gaurantee a successful
@@ -859,7 +856,8 @@ def finalize(
                         batch_beam_idx = batch_idx * self.num_beams + beam_id
                         final_score = final_beam_scores[batch_beam_idx].item()
                         final_tokens = input_ids[batch_beam_idx]
-                        beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len)
+                        generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                        beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
 
                     if len(ids_collect) >= self.num_beam_hyps_to_keep:
                         break
@@ -956,12 +954,16 @@ def add(
         hyp: torch.LongTensor,
         sum_logprobs: float,
         beam_indices: Optional[torch.LongTensor] = None,
-        decoder_prompt_len: Optional[int] = 0,
+        generated_len: Optional[int] = None,
     ):
         """
        Add a new hypothesis to the list.
         """
-        score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
+        if generated_len is not None:
+            score = sum_logprobs / (generated_len**self.length_penalty)
+        else:
+            raise ValueError("`generated_len` has to be defined for beam score calculation")
+
         if len(self) < self.num_beams or score > self.worst_score:
             self.beams.append((score, hyp, beam_indices))
             if len(self) > self.num_beams:
@@ -971,7 +973,7 @@ def add(
             else:
                 self.worst_score = min(score, self.worst_score)
 
-    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
+    def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: int) -> bool:
         """
         If there are enough hypotheses and that none of the hypotheses being generated can become better than the
         worst one in the heap, then we are done with this sentence.
@@ -996,7 +998,11 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
             # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
             # its max this way
             if self.length_penalty > 0.0:
-                highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
+                if self.max_length <= decoder_prompt_len:
+                    raise ValueError("max_length is not larger than decoder prompt length")
+                highest_attainable_score = (
+                    best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
+                )
             # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
             else:
                 highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
diff --git a/tests/generation/test_framework_agnostic.py b/tests/generation/test_framework_agnostic.py
index 8a269801640ebf..995a7572e6114e 100644
--- a/tests/generation/test_framework_agnostic.py
+++ b/tests/generation/test_framework_agnostic.py
@@ -634,7 +634,7 @@ def test_eos_token_id_int_and_list_beam_search(self):
             "num_beams": 3,
         }
         if is_pt:
-            expectation = 20
+            expectation = 13
         else:
             # TODO (joao): fix me
             expectation = 13
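
The core of the fix is switching beam-score normalization from the full sequence length (prompt included) to the generated length only, both in add and in the is_done stopping bound. Below is a minimal, self-contained sketch of that arithmetic for sanity-checking outside the full BeamSearchScorer machinery; TinyBeamHypotheses, its constructor arguments, and the worked numbers are illustrative stand-ins rather than the transformers API, and only the add/is_done math mirrors the patch.

from typing import List, Tuple


class TinyBeamHypotheses:
    """Illustrative stand-in for BeamHypotheses; only the add/is_done
    arithmetic mirrors the patched behavior."""

    def __init__(self, num_beams: int, length_penalty: float, max_length: int):
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.max_length = max_length
        self.beams: List[Tuple[float, List[int]]] = []  # (score, tokens)
        self.worst_score = 1e9

    def add(self, tokens: List[int], sum_logprobs: float, generated_len: int):
        # Patched: normalize by the number of generated tokens the score
        # actually covers, supplied by the caller. During process() that is
        # len(input_ids) + 1 - decoder_prompt_len, because next_score already
        # includes the eos token that is not yet part of input_ids.
        score = sum_logprobs / (generated_len**self.length_penalty)
        if len(self.beams) < self.num_beams or score > self.worst_score:
            self.beams.append((score, tokens))
            if len(self.beams) > self.num_beams:
                self.beams.sort(key=lambda x: x[0])
                del self.beams[0]  # drop the worst hypothesis
                self.worst_score = self.beams[0][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: int) -> bool:
        if len(self.beams) < self.num_beams:
            return False
        if self.length_penalty > 0.0:
            # Patched: bound the attainable score by the maximum *generated*
            # length, so a long decoder prompt no longer inflates the bound.
            if self.max_length <= decoder_prompt_len:
                raise ValueError("max_length is not larger than decoder prompt length")
            highest_attainable = best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
        else:
            highest_attainable = best_sum_logprobs / cur_len**self.length_penalty
        return self.worst_score >= highest_attainable


# Worked numbers (illustrative): 10-token prompt, max_length=20, length_penalty=1.0.
hyps = TinyBeamHypotheses(num_beams=3, length_penalty=1.0, max_length=20)
for logp in (-0.9, -1.0, -1.2):  # worst normalized score becomes -1.2 / 3 = -0.4
    hyps.add(tokens=[101, 7, 8], sum_logprobs=logp, generated_len=3)

# Patched bound: -5.0 / (20 - 10) = -0.5, so -0.4 >= -0.5 -> True (done).
# Pre-patch bound: -5.0 / 20 = -0.25, which would have kept searching.
print(hyps.is_done(best_sum_logprobs=-5.0, cur_len=13, decoder_prompt_len=10))

With the tighter bound, searches with long prompts can stop as soon as no remaining continuation can beat the worst kept hypothesis, which is why the PyTorch expectation in test_eos_token_id_int_and_list_beam_search drops from 20 to 13, matching the other frameworks.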