diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py
index 59848c3c85905d..acad9d6baa0f12 100644
--- a/src/transformers/generation/tf_utils.py
+++ b/src/transformers/generation/tf_utils.py
@@ -2267,12 +2267,14 @@ def unflatten_beam_dim(tensor, num_beams, batch_axis=0):
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
 
         # 3. init tensors to use for "xla-compileable" generate function
-        batch_size, num_beams, cur_len = shape_list(input_ids)
+        batch_size, num_beams, decoder_prompt_len = shape_list(input_ids)
+        # cur_len represents the length of generated tokens so far
+        cur_len = 0
 
         # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
-        input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
-            pad_token_id or 0
-        )
+        input_ids_padding = tf.ones(
+            (batch_size, num_beams, max_length - cur_len - decoder_prompt_len), dtype=tf.int32
+        ) * (pad_token_id or 0)
         running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1)
         sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * (pad_token_id or 0)
 
@@ -2286,8 +2288,8 @@ def unflatten_beam_dim(tensor, num_beams, batch_axis=0):
         scores = tf.ones((batch_size, num_beams)) * -1.0e9
 
         # per batch beam indices
-        running_beam_indices = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * -1
-        beam_indices = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * -1
+        running_beam_indices = tf.ones((batch_size, num_beams, max_length - decoder_prompt_len), dtype=tf.int32) * -1
+        beam_indices = tf.ones((batch_size, num_beams, max_length - decoder_prompt_len), dtype=tf.int32) * -1
 
         # flatten beam dim
         if "encoder_outputs" in model_kwargs:
@@ -2308,6 +2310,7 @@ def beam_search_cond_fn(
             scores,
             beam_indices,
             is_sent_finished,
+            decoder_prompt_len,
             model_kwargs,
         ):
             """
@@ -2315,7 +2318,7 @@ def beam_search_cond_fn(
             False
             """
             # 1. is less than max length?
-            not_max_length_yet = cur_len < max_length
+            not_max_length_yet = (cur_len + decoder_prompt_len) < max_length
 
             # 2. can the new beams still improve?
             # early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
@@ -2324,7 +2327,7 @@ def beam_search_cond_fn(
             # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
             #   length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
             if early_stopping == "never" and length_penalty > 0.0:
-                best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+                best_running_score = running_scores[:, :1] / ((max_length - decoder_prompt_len) ** length_penalty)
             else:
                 best_running_score = running_scores[:, :1] / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
             worst_finished_score = tf.where(
@@ -2346,6 +2349,7 @@ def beam_search_body_fn(
             scores,
             beam_indices,
             is_sent_finished,
+            decoder_prompt_len,
             model_kwargs,
         ):
             """
@@ -2354,9 +2358,9 @@ def beam_search_body_fn(
             """
             # 1. Forward current tokens
             if model_kwargs.get("past_key_values") is None or needs_full_input:
-                input_ids = running_sequences[:, :, :cur_len]
+                input_ids = running_sequences[:, :, : cur_len + decoder_prompt_len]
             else:
-                input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1)
+                input_ids = tf.expand_dims(running_sequences[:, :, cur_len + decoder_prompt_len - 1], -1)
             model_inputs = self.prepare_inputs_for_generation(
                 flatten_beam_dim(input_ids), use_cache=use_cache, **model_kwargs
             )
@@ -2427,7 +2431,12 @@ def beam_search_body_fn(
             indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep])
             indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size])
             update_indices = tf.stack(
-                [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
+                [
+                    indices_batch,
+                    indices_beam,
+                    tf.broadcast_to(cur_len + decoder_prompt_len, [batch_size * beams_to_keep]),
+                ],
+                axis=-1,
             )
             topk_sequences = tf.tensor_scatter_nd_update(
                 tensor=topk_running_sequences,
@@ -2439,6 +2448,9 @@ def beam_search_body_fn(
             batch_modified_indices = topk_current_beam_indices + tf.broadcast_to(
                 tf.expand_dims(tf.range(batch_size) * num_beams, axis=1), topk_current_beam_indices.shape
             )
+            update_indices = tf.stack(
+                [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
+            )
             topk_beam_indices = tf.tensor_scatter_nd_update(
                 tensor=topk_running_beam_indices,
                 indices=update_indices,
@@ -2450,12 +2462,13 @@ def beam_search_body_fn(
             # To prevent these just finished sequences from being added to the current sequences
             # set of active beam search sequences, set their log probs to a very large negative value.
             if eos_token_id is None:
-                eos_in_next_token = tf.zeros(topk_sequences[:, :, cur_len].shape, dtype=tf.bool)
+                eos_in_next_token = tf.zeros(topk_sequences[:, :, cur_len + decoder_prompt_len].shape, dtype=tf.bool)
             else:
                 eos_in_next_token = tf.math.reduce_any(
                     tf.equal(
                         tf.broadcast_to(
-                            topk_sequences[:, :, cur_len], [len(eos_token_id)] + topk_sequences[:, :, cur_len].shape
+                            topk_sequences[:, :, cur_len + decoder_prompt_len],
+                            [len(eos_token_id)] + topk_sequences[:, :, cur_len + decoder_prompt_len].shape,
                         ),
                         tf.expand_dims(tf.expand_dims(eos_token_id, -1), -1),
                     ),
@@ -2483,7 +2496,7 @@ def beam_search_body_fn(
             #   - add length penalty
             #   - make sure no scores can be added anymore if beam is full
             #   - make sure still running sequences cannot be chosen as finalized beam
-            topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
+            topk_log_probs = topk_log_probs / (tf.cast(cur_len + 1, dtype=tf.float32) ** length_penalty)
             beams_in_batch_are_full = tf.broadcast_to(
                 tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
             ) & (early_stopping is True)
@@ -2546,6 +2559,7 @@ def beam_search_body_fn(
                 next_scores,
                 next_beam_indices,
                 next_is_sent_finished,
+                decoder_prompt_len,
                 next_model_kwargs,
             )
 
@@ -2560,6 +2574,7 @@ def beam_search_body_fn(
             scores,
             beam_indices,
             is_sent_finished,
+            decoder_prompt_len,
             model_kwargs,
         ) = beam_search_body_fn(
             cur_len,
@@ -2570,12 +2585,13 @@ def beam_search_body_fn(
             scores,
             beam_indices,
             is_sent_finished,
+            decoder_prompt_len,
             model_kwargs,
         )
 
         # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
         # NOT yield EOS token though)
-        maximum_iterations = max_length - cur_len
+        maximum_iterations = max_length - cur_len - decoder_prompt_len
         (
             cur_len,
             running_sequences,
@@ -2585,6 +2601,7 @@ def beam_search_body_fn(
             scores,
             beam_indices,
             is_sent_finished,
+            decoder_prompt_len,
             _,
         ) = tf.while_loop(
             beam_search_cond_fn,
@@ -2598,11 +2615,12 @@ def beam_search_body_fn(
                 scores,
                 beam_indices,
                 is_sent_finished,
+                decoder_prompt_len,
                 model_kwargs,
             ),
             maximum_iterations=maximum_iterations,
         )
 
         # 6. prepare outputs
         # Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
         # running sequences for that batch item.
@@ -2621,7 +2639,7 @@ def beam_search_body_fn(
 
         if not use_xla:
             # Cut for backward compatibility
-            sequences = sequences[:, :cur_len]
+            sequences = sequences[:, : cur_len + decoder_prompt_len]
             beam_indices = beam_indices[:, :cur_len]
 
         if return_dict_in_generate:
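The substance of the patch is bookkeeping: `cur_len` now counts only generated tokens while `decoder_prompt_len` holds the prompt length, so the length penalty and the max-length check stop charging the prompt against each hypothesis. A minimal standalone sketch of the scoring effect, with made-up numbers (not library code):

```python
# Standalone sketch, not library code: how splitting `cur_len` (generated
# tokens only) from `decoder_prompt_len` changes length-penalty
# normalization for a decoder-only model. All values below are made up.
decoder_prompt_len = 5   # tokens supplied by the user prompt
length_penalty = 1.0
sum_log_probs = -4.2     # cumulative log-prob of one beam hypothesis
generated = 3            # tokens generated so far, including the candidate

# Before the patch: the normalizer was the full sequence length, prompt
# included, so longer prompts diluted the penalty and skewed beam ranking.
score_old = sum_log_probs / ((decoder_prompt_len + generated) ** length_penalty)

# After the patch: only generated tokens are normalized. In the diff this
# appears as `tf.cast(cur_len + 1, ...)` because `cur_len` has not yet been
# incremented for the candidate token being scored.
score_new = sum_log_probs / (generated ** length_penalty)

print(f"old: {score_old:.3f}  new: {score_new:.3f}")  # old: -0.525  new: -1.400
```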
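The patch also maintains two different write offsets: `running_sequences` still contains the prompt, so new tokens are scattered at `cur_len + decoder_prompt_len`, while the beam-index buffers are allocated without prompt slots and written at plain `cur_len` (hence the rebuilt `update_indices` just before the beam-index scatter). A toy reproduction under assumed miniature shapes:

```python
import tensorflow as tf

# Toy reproduction of the two scatter offsets in the patch (shapes shrunk
# for readability; every tensor here is a stand-in, not library state).
batch_size, num_beams, decoder_prompt_len, max_length = 1, 2, 3, 8
cur_len = 0  # generated tokens so far

# `sequences` keeps the prompt, so it spans the full max_length...
sequences = tf.zeros((batch_size, num_beams, max_length), dtype=tf.int32)
# ...while the beam-indices buffer has no prompt slots.
beam_indices = tf.fill((batch_size, num_beams, max_length - decoder_prompt_len), -1)

indices_batch = tf.repeat(tf.range(batch_size), [num_beams])
indices_beam = tf.tile(tf.range(num_beams), [batch_size])

# Token writes are offset by the prompt length...
seq_indices = tf.stack(
    [indices_batch, indices_beam, tf.broadcast_to(cur_len + decoder_prompt_len, [batch_size * num_beams])],
    axis=-1,
)
# ...but beam-index writes are not.
beam_idx_indices = tf.stack(
    [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * num_beams])], axis=-1
)

sequences = tf.tensor_scatter_nd_update(sequences, seq_indices, tf.constant([42, 7]))
beam_indices = tf.tensor_scatter_nd_update(beam_indices, beam_idx_indices, tf.constant([0, 1]))
print(sequences.numpy())     # tokens land at position cur_len + decoder_prompt_len == 3
print(beam_indices.numpy())  # beam indices land at position cur_len == 0
```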
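As a quick sanity check, any decoder-only TF model generating with `num_beams > 1` exercises this code path end to end; the checkpoint and generation settings below are illustrative assumptions, not part of the patch:

```python
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="tf")
# num_beams > 1 routes generate() into the patched beam search; the prompt
# length (inputs.input_ids.shape[1]) is what is now tracked separately as
# `decoder_prompt_len`, so length_penalty applies to generated tokens only.
outputs = model.generate(inputs.input_ids, max_length=24, num_beams=4, length_penalty=1.0)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```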