diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py
index 91812155c5719b..8012c3c1bbfcfb 100644
--- a/src/transformers/models/whisper/generation_whisper.py
+++ b/src/transformers/models/whisper/generation_whisper.py
@@ -1000,13 +1000,15 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
         # Stack back seek_outputs tensors after splitting them with the split_by_batch_index method
         outputs = {}
         for key in seek_outputs[0].keys():
-            if key == "sequences":
+            if key in ["sequences", "beam_indices"]:
                 outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
-            if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
+            elif key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
                 outputs[key] = tuple(
                     torch.stack([v[key][i] for v in seek_outputs]).to(device) for i in range(len(seek_outputs[0][key]))
                 )
-            if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
+            elif key == "sequences_scores":
+                outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
+            elif key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
                 outputs[key] = tuple(
                     tuple(
                         torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
@@ -1014,7 +1016,7 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
                     )
                     for i in range(len(seek_outputs[0][key]))
                 )
-            if key == "past_key_values":
+            elif key == "past_key_values":
                 past_key_value_type = kwargs.get("past_key_values")
                 if seek_outputs[0][key] is not None:
                     outputs[key] = tuple(
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index ebc9ce5ec358eb..f7ac2bc12eb252 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -529,6 +529,25 @@ def test_inputs_embeds(self):
             with torch.no_grad():
                 model(**inputs)[0]
 
+    def test_beam_search_output(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs()
+        model = WhisperForConditionalGeneration(config).to(torch_device).eval()
+
+        input_features = input_dict["input_features"]
+
+        # Perform beam search
+        output = model.generate(
+            input_features, num_beams=3, num_return_sequences=3, return_dict_in_generate=True, output_scores=True
+        )
+
+        # Check if beam_indices and sequences_scores are in the output
+        self.assertIn("beam_indices", output, "beam_indices not found in the output")
+        self.assertIn("sequences_scores", output, "sequences_scores not found in the output")
+
+        # Validate the shapes of the beam_indices and sequences_scores
+        self.assertEqual(output.beam_indices.shape[0], input_features.shape[0] * 3)
+        self.assertEqual(output.sequences_scores.shape[0], input_features.shape[0] * 3)
+
     # training is not supported yet
     @unittest.skip(reason="Training is not supported yet")
     def test_training(self):
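
For reviewers, a minimal usage sketch of the behavior this patch enables, modeled on the new test. The "openai/whisper-tiny" checkpoint, the random 80 x 3000 log-mel features, and the batch size are assumptions chosen only to exercise the API; they are not part of the patch.

    import torch
    from transformers import WhisperForConditionalGeneration

    # Assumption: any Whisper checkpoint works; whisper-tiny keeps the sketch cheap.
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").eval()

    # Assumption: random features stand in for real audio; whisper-tiny expects
    # 80 mel bins over 3000 frames (30 s of audio).
    input_features = torch.randn(1, 80, 3000)

    output = model.generate(
        input_features,
        num_beams=3,
        num_return_sequences=3,
        return_dict_in_generate=True,
        output_scores=True,
    )

    # Before this patch, _stack_split_outputs dropped these beam-search fields
    # when re-stacking the per-batch-index outputs; now they are preserved.
    # Per the test, the first dimension is batch_size * num_return_sequences.
    print(output.beam_indices.shape)
    print(output.sequences_scores.shape)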