Fix issues in add and is_done for BeamHypotheses
VsonicV committed Dec 3, 2023
1 parent 2c658b5 commit cd8ec99
Showing 2 changed files with 27 additions and 21 deletions.
46 changes: 26 additions & 20 deletions src/transformers/generation/beam_search.py
@@ -279,15 +279,14 @@ def process(
                     else:
                         beam_index = None

-                    # skip the corner case where the very first generated token is eos_token
-                    if decoder_prompt_len == input_ids.shape[-1]:
-                        continue
+                    # add up to the length which the next_scores is calculated on
+                    generated_len = input_ids[batch_beam_idx].shape[-1] + 1 - decoder_prompt_len

                     self._beam_hyps[batch_group_idx].add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
                         beam_indices=beam_index,
-                        decoder_prompt_len=decoder_prompt_len,
+                        generated_len=generated_len,
                     )
                 else:
                     # add next predicted token since it is not eos_token
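A minimal sketch of the off-by-one this hunk handles, with made-up values for illustration: at this point in `process`, `next_score` already accounts for the eos token that ends the hypothesis, but that token has not yet been appended to `input_ids`, hence the `+ 1`.

    # hypothetical values, not part of the diff
    decoder_prompt_len = 5     # prompt tokens, excluded from the length penalty
    cur_seq_len = 9            # input_ids[batch_beam_idx].shape[-1]: prompt + 4 generated tokens
    generated_len = cur_seq_len + 1 - decoder_prompt_len
    assert generated_len == 5  # 4 generated tokens + the pending eos token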
@@ -308,7 +307,7 @@ def process(

             # Check if we are done so that we can save a pad step if all(done)
             self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
-                next_scores[batch_idx].max().item(), cur_len
+                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
             )

         return UserDict(
@@ -348,7 +347,8 @@ def finalize(
                 final_score = final_beam_scores[batch_beam_idx].item()
                 final_tokens = input_ids[batch_beam_idx]
                 beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
-                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len)
+                generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)

         # select the best hypotheses
         sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
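In `finalize`, by contrast, the surviving hypotheses already contain every generated token (including eos if one was produced), so the corresponding sketch needs no `+ 1` (again with made-up values):

    decoder_prompt_len = 5
    final_seq_len = 12         # final_tokens.shape[-1]: prompt + 7 generated tokens
    generated_len = final_seq_len - decoder_prompt_len
    assert generated_len == 7  # only the generated tokens are length-penalized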
@@ -617,16 +617,14 @@ def process(
                     else:
                         beam_index = None

-                    # skip the corner case where the only constraint token is
-                    # eos_token and the very first generated token is eos_token
-                    if decoder_prompt_len == input_ids.shape[-1]:
-                        continue
+                    # add up to the length which the next_scores is calculated on
+                    generated_len = input_ids[batch_beam_idx].shape[-1] + 1 - decoder_prompt_len

                     beam_hyp.add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
                         beam_indices=beam_index,
-                        decoder_prompt_len=decoder_prompt_len,
+                        generated_len=generated_len,
                     )
                 else:
                     # add next predicted token since it is not eos_token
@@ -660,7 +658,7 @@ def process(

             # Check if we are done so that we can save a pad step if all(done)
             self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
-                next_scores[batch_idx].max().item(), cur_len
+                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
             )

         return UserDict(
@@ -846,9 +844,8 @@ def finalize(
                     completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
                     if completes_constraint:
                         beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
-                        beam_hyp.add(
-                            final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len
-                        )
+                        generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                        beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
                         ids_collect.append(beam_id)

             # due to overly complex constraints or other factors, sometimes we can't guarantee a successful
@@ -859,7 +856,8 @@ def finalize(
                     batch_beam_idx = batch_idx * self.num_beams + beam_id
                     final_score = final_beam_scores[batch_beam_idx].item()
                     final_tokens = input_ids[batch_beam_idx]
-                    beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len)
+                    generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                    beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
                     if len(ids_collect) >= self.num_beam_hyps_to_keep:
                         break

@@ -956,12 +954,16 @@ def add(
         hyp: torch.LongTensor,
         sum_logprobs: float,
         beam_indices: Optional[torch.LongTensor] = None,
-        decoder_prompt_len: Optional[int] = 0,
+        generated_len: Optional[int] = None,
     ):
         """
         Add a new hypothesis to the list.
         """
-        score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
+        if generated_len is not None:
+            score = sum_logprobs / (generated_len**self.length_penalty)
+        else:
+            raise ValueError("`generated_len` has to be defined for beam score calculation")

         if len(self) < self.num_beams or score > self.worst_score:
             self.beams.append((score, hyp, beam_indices))
             if len(self) > self.num_beams:
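To see what the new `generated_len` argument changes in the score, here is a hedged sketch with made-up numbers; `sum_logprobs` and `length_penalty` play the same roles as in `BeamHypotheses.add` above. Before the fix, the `process` path normalized by the hypothesis length without the pending eos token, i.e. by one token too few:

    sum_logprobs = -6.0        # cumulative log-probability of the hypothesis (eos included)
    length_penalty = 1.0
    hyp_len = 9                # prompt (5) + 4 generated tokens in input_ids
    decoder_prompt_len = 5

    old_score = sum_logprobs / ((hyp_len - decoder_prompt_len) ** length_penalty)  # -6.0 / 4 = -1.5
    generated_len = hyp_len + 1 - decoder_prompt_len                               # 5, eos counted
    new_score = sum_logprobs / (generated_len ** length_penalty)                   # -6.0 / 5 = -1.2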
@@ -971,7 +973,7 @@ def add(
             else:
                 self.worst_score = min(score, self.worst_score)

-    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
+    def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: int) -> bool:
         """
         If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
         one in the heap, then we are done with this sentence.
@@ -996,7 +998,11 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
         # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
         # its max this way
         if self.length_penalty > 0.0:
-            highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
+            if self.max_length <= decoder_prompt_len:
+                raise ValueError("max_length is not larger than decoder prompt length")
+            highest_attainable_score = (
+                best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
+            )
         # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
         else:
             highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
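A hedged worked example of the tightened bound (made-up numbers): with `length_penalty > 0`, the most optimistic score a running beam can still reach divides `best_sum_logprobs` by the largest possible length, and after this fix that length excludes the decoder prompt:

    best_sum_logprobs = -4.0
    length_penalty = 1.0
    max_length = 20
    decoder_prompt_len = 12

    old_bound = best_sum_logprobs / max_length ** length_penalty                         # -0.2
    new_bound = best_sum_logprobs / (max_length - decoder_prompt_len) ** length_penalty  # -0.5
    # The old bound was too optimistic for decoder-only models with long prompts:
    # worst_score >= old_bound held less often, so is_done stayed False and
    # beam search kept running even when no beam could beat the current worst score.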
2 changes: 1 addition & 1 deletion tests/generation/test_framework_agnostic.py
@@ -634,7 +634,7 @@ def test_eos_token_id_int_and_list_beam_search(self):
             "num_beams": 3,
         }
         if is_pt:
-            expectation = 20
+            expectation = 13
         else:
             # TODO (joao): fix me
             expectation = 13
