
Fix beam score calculation issue for tensorflow
VsonicV committed Dec 3, 2023
1 parent cd8ec99 commit c537a80
Showing 1 changed file with 36 additions and 16 deletions.
src/transformers/generation/tf_utils.py (36 additions, 16 deletions)
@@ -2267,12 +2267,14 @@ def unflatten_beam_dim(tensor, num_beams, batch_axis=0):
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None

# 3. init tensors to use for "xla-compileable" generate function
-batch_size, num_beams, cur_len = shape_list(input_ids)
+batch_size, num_beams, decoder_prompt_len = shape_list(input_ids)
+# cur_len represents the length of generated tokens so far
+cur_len = 0

# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
-input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
-    pad_token_id or 0
-)
+input_ids_padding = tf.ones(
+    (batch_size, num_beams, max_length - cur_len - decoder_prompt_len), dtype=tf.int32
+) * (pad_token_id or 0)
running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1)
sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * (pad_token_id or 0)
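As a minimal sketch (not part of the diff) of the new length bookkeeping, assuming a toy prompt of shape (batch_size=2, num_beams=3, decoder_prompt_len=4) and max_length=10; `shape_list` is approximated here by the static tensor shape:

import tensorflow as tf

input_ids = tf.zeros((2, 3, 4), dtype=tf.int32)  # (batch, beams, prompt)
batch_size, num_beams, decoder_prompt_len = input_ids.shape
cur_len = 0  # counts generated tokens only, not the prompt
max_length = 10
input_ids_padding = tf.ones(
    (batch_size, num_beams, max_length - cur_len - decoder_prompt_len), dtype=tf.int32
)
running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1)
assert running_sequences.shape[-1] == max_length  # prompt plus room to generate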

@@ -2286,8 +2288,8 @@ def unflatten_beam_dim(tensor, num_beams, batch_axis=0):
scores = tf.ones((batch_size, num_beams)) * -1.0e9

# per batch beam indices
-running_beam_indices = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * -1
-beam_indices = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * -1
+running_beam_indices = tf.ones((batch_size, num_beams, max_length - decoder_prompt_len), dtype=tf.int32) * -1
+beam_indices = tf.ones((batch_size, num_beams, max_length - decoder_prompt_len), dtype=tf.int32) * -1

# flatten beam dim
if "encoder_outputs" in model_kwargs:
@@ -2308,14 +2310,15 @@ def beam_search_cond_fn(
scores,
beam_indices,
is_sent_finished,
+decoder_prompt_len,
model_kwargs,
):
"""
Beam Search termination condition function -- halts the generation loop if any of these conditions becomes
False
"""
# 1. is less than max length?
-not_max_length_yet = cur_len < max_length
+not_max_length_yet = (cur_len + decoder_prompt_len) < max_length

# 2. can the new beams still improve?
# early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
@@ -2324,7 +2327,7 @@ def beam_search_cond_fn(
# early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
# length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
if early_stopping == "never" and length_penalty > 0.0:
-best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+best_running_score = running_scores[:, :1] / ((max_length - decoder_prompt_len) ** length_penalty)
else:
best_running_score = running_scores[:, :1] / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
worst_finished_score = tf.where(
@@ -2346,6 +2349,7 @@ def beam_search_body_fn(
scores,
beam_indices,
is_sent_finished,
+decoder_prompt_len,
model_kwargs,
):
"""
@@ -2354,9 +2358,9 @@ def beam_search_body_fn(
"""
# 1. Forward current tokens
if model_kwargs.get("past_key_values") is None or needs_full_input:
-input_ids = running_sequences[:, :, :cur_len]
+input_ids = running_sequences[:, :, : cur_len + decoder_prompt_len]
else:
-input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1)
+input_ids = tf.expand_dims(running_sequences[:, :, cur_len + decoder_prompt_len - 1], -1)
model_inputs = self.prepare_inputs_for_generation(
flatten_beam_dim(input_ids), use_cache=use_cache, **model_kwargs
)
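A sketch of the two forward-input branches above, with assumed toy shapes: without a cache the model is fed the whole prefix (prompt plus generated tokens so far), while with `past_key_values` it only sees the newest token:

import tensorflow as tf

running_sequences = tf.zeros((2, 3, 10), dtype=tf.int32)  # (batch, beams, max_length)
cur_len, decoder_prompt_len = 2, 4
full_prefix = running_sequences[:, :, : cur_len + decoder_prompt_len]  # shape (2, 3, 6)
last_token = tf.expand_dims(
    running_sequences[:, :, cur_len + decoder_prompt_len - 1], -1
)  # shape (2, 3, 1)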
@@ -2427,7 +2431,12 @@ def beam_search_body_fn(
indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep])
indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size])
update_indices = tf.stack(
-    [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
+    [
+        indices_batch,
+        indices_beam,
+        tf.broadcast_to(cur_len + decoder_prompt_len, [batch_size * beams_to_keep]),
+    ],
+    axis=-1,
)
topk_sequences = tf.tensor_scatter_nd_update(
tensor=topk_running_sequences,
@@ -2439,6 +2448,9 @@ def beam_search_body_fn(
batch_modified_indices = topk_current_beam_indices + tf.broadcast_to(
tf.expand_dims(tf.range(batch_size) * num_beams, axis=1), topk_current_beam_indices.shape
)
+update_indices = tf.stack(
+    [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
+)
topk_beam_indices = tf.tensor_scatter_nd_update(
tensor=topk_running_beam_indices,
indices=update_indices,
@@ -2450,12 +2462,13 @@ def beam_search_body_fn(
# To prevent these just finished sequences from being added to the current sequences
# set of active beam search sequences, set their log probs to a very large negative value.
if eos_token_id is None:
-eos_in_next_token = tf.zeros(topk_sequences[:, :, cur_len].shape, dtype=tf.bool)
+eos_in_next_token = tf.zeros(topk_sequences[:, :, cur_len + decoder_prompt_len].shape, dtype=tf.bool)
else:
eos_in_next_token = tf.math.reduce_any(
tf.equal(
tf.broadcast_to(
-    topk_sequences[:, :, cur_len], [len(eos_token_id)] + topk_sequences[:, :, cur_len].shape
+    topk_sequences[:, :, cur_len + decoder_prompt_len],
+    [len(eos_token_id)] + topk_sequences[:, :, cur_len + decoder_prompt_len].shape,
),
tf.expand_dims(tf.expand_dims(eos_token_id, -1), -1),
),
@@ -2483,7 +2496,7 @@ def beam_search_body_fn(
# - add length penalty
# - make sure no scores can be added anymore if beam is full
# - make sure still running sequences cannot be chosen as finalized beam
-topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
+topk_log_probs = topk_log_probs / (tf.cast(cur_len + 1, dtype=tf.float32) ** length_penalty)
beams_in_batch_are_full = tf.broadcast_to(
tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
) & (early_stopping is True)
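A sketch of the corrected length-penalty normalization with made-up numbers; the "before" line approximates the old divisor, which counted the prompt, expressed under the new variable convention:

import tensorflow as tf

topk_log_probs = tf.constant([[-4.2, -5.0]])  # summed log-probs of two candidate beams
cur_len, decoder_prompt_len, length_penalty = 5, 4, 1.0
before = topk_log_probs / (tf.cast(cur_len + decoder_prompt_len, tf.float32) ** length_penalty)
after = topk_log_probs / (tf.cast(cur_len + 1, tf.float32) ** length_penalty)
# the fix divides by the number of generated tokens (including the token just
# proposed), so the prompt length no longer dilutes the penalty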
@@ -2546,6 +2559,7 @@ def beam_search_body_fn(
next_scores,
next_beam_indices,
next_is_sent_finished,
+decoder_prompt_len,
next_model_kwargs,
)

@@ -2560,6 +2574,7 @@ def beam_search_body_fn(
scores,
beam_indices,
is_sent_finished,
+decoder_prompt_len,
model_kwargs,
) = beam_search_body_fn(
cur_len,
@@ -2570,12 +2585,13 @@ def beam_search_body_fn(
scores,
beam_indices,
is_sent_finished,
+decoder_prompt_len,
model_kwargs,
)

# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
# NOT yield EOS token though)
-maximum_iterations = max_length - cur_len
+maximum_iterations = max_length - cur_len - decoder_prompt_len
(
cur_len,
running_sequences,
Expand All @@ -2585,6 +2601,7 @@ def beam_search_body_fn(
scores,
beam_indices,
is_sent_finished,
+decoder_prompt_len,
_,
) = tf.while_loop(
beam_search_cond_fn,
@@ -2598,11 +2615,14 @@ def beam_search_body_fn(
scores,
beam_indices,
is_sent_finished,
+decoder_prompt_len,
model_kwargs,
),
maximum_iterations=maximum_iterations,
)

print("decoder_prompt_len: {}".format(decoder_prompt_len))

# 6. prepare outputs
# Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
# running sequences for that batch item.
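A minimal sketch of the loop-variable pattern used in the `tf.while_loop` above: `cond` and `body` must share one signature, so even an unchanging value like `decoder_prompt_len` has to be threaded through the loop state (the numbers here are assumptions):

import tensorflow as tf

def cond_fn(cur_len, decoder_prompt_len):
    return cur_len + decoder_prompt_len < 10  # assumed max_length of 10

def body_fn(cur_len, decoder_prompt_len):
    return cur_len + 1, decoder_prompt_len  # constant passed through unchanged

cur_len, decoder_prompt_len = tf.while_loop(
    cond_fn, body_fn, (tf.constant(0), tf.constant(4))
)
# cur_len == 6: six tokens generated on top of a four-token prompt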
@@ -2621,7 +2641,7 @@ def beam_search_body_fn(

if not use_xla:
# Cut for backward compatibility
-sequences = sequences[:, :cur_len]
+sequences = sequences[:, : cur_len + decoder_prompt_len]
beam_indices = beam_indices[:, :cur_len]

if return_dict_in_generate:
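A sketch of the final backward-compatibility slicing with assumed shapes: `sequences` spans prompt plus generated tokens, while the beam-index buffer covers generated positions only, hence the two different slice lengths:

import tensorflow as tf

max_length, decoder_prompt_len, cur_len = 10, 4, 3
sequences = tf.zeros((2, max_length), dtype=tf.int32)
beam_indices = tf.zeros((2, max_length - decoder_prompt_len), dtype=tf.int32)
sequences = sequences[:, : cur_len + decoder_prompt_len]  # (2, 7): prompt + generated
beam_indices = beam_indices[:, :cur_len]                  # (2, 3): generated only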
