diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index fdaeff14d78867..6b71671e14c852 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -133,9 +133,12 @@ def _pad_to_max_length( padding="longest", bos_token_tensor=None, cut_off_length=None, + return_token_timestamps=False, + force_unique_generate_call=False, ): max_total_length = 0 sequences = [] + token_timestamps_list = [] if padding_side not in ["right", "left"]: raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}") @@ -145,31 +148,74 @@ def _pad_to_max_length( elif padding == "max_length" and cut_off_length is None: raise ValueError("`cut_off_length` must be specified when `padding='max_length'`") + if force_unique_generate_call: + sequences_list = [] + timestamps_list = [] + for segments in current_segments: + result = segments[0]["result"] + sequences_list.append(result if isinstance(result, torch.Tensor) else result["sequences"]) + if return_token_timestamps: + timestamps_list.append(result["token_timestamps"]) + + sequences = torch.stack(sequences_list, dim=0) + if return_token_timestamps: + token_timestamps = torch.stack(timestamps_list, dim=0) + return sequences, token_timestamps + return sequences + for current_segment_list in current_segments: if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0: sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1) + if return_token_timestamps: + token_timestamps = torch.cat( + [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list], + dim=-1, + ) if cut_off_length is not None: sequence = sequence[-cut_off_length:] + if return_token_timestamps: + token_timestamps = token_timestamps[-cut_off_length:] if bos_token_tensor is not None: sequence = torch.cat([bos_token_tensor, sequence]) - + if return_token_timestamps: + token_timestamps = torch.cat( + [torch.ones_like(bos_token_tensor, device=device) * 0.0, token_timestamps] + ) sequences.append(sequence) + if return_token_timestamps: + token_timestamps_list.append(token_timestamps) max_total_length = max(max_total_length, len(sequences[-1])) elif bos_token_tensor is not None: sequences.append(bos_token_tensor) + if return_token_timestamps: + token_timestamps_list.append(torch.ones_like(bos_token_tensor, device=device) * 0.0) else: sequences.append(torch.tensor([], device=device)) + if return_token_timestamps: + token_timestamps_list.append(torch.tensor([], device=device)) max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length for i in range(len(current_segments)): pad_length = max_total_length - len(sequences[i]) pad = (0, pad_length) if padding_side == "right" else (pad_length, 0) + sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id) + if return_token_timestamps: + token_timestamps_list[i] = F.pad( + token_timestamps_list[i], + pad=pad, + value=token_timestamps_list[i][-1] if len(token_timestamps_list[i]) > 0 else 0.0, + ) sequences = torch.stack(sequences, dim=0) - return sequences + + if return_token_timestamps: + token_timestamps = torch.stack(token_timestamps_list, dim=0) + return sequences, token_timestamps + else: + return sequences class WhisperGenerationMixin(GenerationMixin): @@ -312,6 +358,7 @@ def generate( return_token_timestamps: Optional[bool] = None, return_segments: bool = False, return_dict_in_generate: Optional[bool] = None, + force_unique_generate_call: Optional[bool] = None, **kwargs, ): """ @@ -432,27 +479,39 @@ def generate( Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when `return_segments` is set True. In this case the generation outputs of each segment is added to each segment. + force_unique_generate_call (`bool`, *optional*): + Whether to force a unique call to the underlying GenerationMixin's generate method. This is useful for assisted decoding and testing purposes to ensure + that only one call to generate is made and therefore decoder input token ids and eos token ids are returned. kwargs (`Dict[str, Any]`, *optional*): Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. - Return: - [`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` - or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`. + [`~utils.ModelOutput`] or `Dict[str, Any]` or `torch.LongTensor`: - If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned. + A: + - [`~utils.ModelOutput`] when `return_dict_in_generate=True` and (`return_timestamps=False` or `force_unique_generate_call=True`), including the decoder input ids and end of sequence id. + - `Dict[str, Any]` when (`return_dict_in_generate=True` and `return_timestamps=True`) or `return_segments=True` or `return_token_timestamps=True`. + - `torch.LongTensor` in all other cases, excluding the decoder input ids and end of sequence id. - else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are: + The possible [`~utils.ModelOutput`] types are: + - [`~utils.GenerateEncoderDecoderOutput`] + - [`~utils.GenerateBeamEncoderDecoderOutput`] - - [`~generation.GenerateEncoderDecoderOutput`], - - [`~generation.GenerateBeamEncoderDecoderOutput`] + `segments` is a list of lists (one list per batch element) of `segment`. + A `segment` is a dictionary with keys `start`, `end`, `tokens`, `idxs`, and `result`. + - `start`: the start timestamp of the segment. + - `end`: the end timestamp of the segment. + - `tokens`: the tokens of the segment, excluding the decoder input ids and end of sequence id. + - `idxs`: the start (included) and end (excluded) indices of the `tokens` of the segment in the underlying call to GenerationMixin's `generate` (present in `result`). + - `result`: the result of the underlying call to GenerationMixin's `generate`. - else only the generated output sequence ids are returned. + When `return_timestamps=True`, `return_dict_in_generate=True` applies to each call of the underlying GenerationMixin's `generate`, with outputs stored in `result` of each `segment`. Example: - - *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate. + - *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate. It is necessary to set `return_timestamps=True`. + Indeed, long-form transcription uses a sequential algorithm based on timestamps predictions, with heuristics like compression ratio threshold, log probability threshold and temperature fallback. This algorithm is described in the [the Whisper original paper](https://cdn.openai.com/papers/whisper.pdf), section *3.8. Long-form Transcription*. ```python >>> import torch @@ -483,7 +542,9 @@ def generate( " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile." ``` - - *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate. + - *Shortform transcription*: If passed mel input features are <= 30 seconds, there are two possibilities: + - `return_timestamps=False`: the whole audio will be transcribed with a single call to GenerationMixin's generate. + - `return_timestamps=True`: the audio will be transcribed using the same logic as long-form transcription. ```python >>> import torch @@ -570,11 +631,21 @@ def generate( # 3. Retrieve logits processors device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device begin_index = init_tokens.shape[1] + num_beams = kwargs.get( + "num_beams", + generation_config.num_beams + if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None + else 1, + ) + if "assistant_model" in kwargs: + # speculative decoding: the model should be able to return eos token + generation_config.begin_suppress_tokens = None + logits_processor = self._retrieve_logit_processors( generation_config=generation_config, logits_processor=logits_processor, begin_index=begin_index, # begin index is index of first generated decoder token - num_beams=kwargs.get("num_beams", 1), + num_beams=num_beams, device=device, ) @@ -618,6 +689,19 @@ def generate( batch_size=cur_bsz, generation_config=generation_config, ) + # 5bis speculative decoding: ensure the assistant model does only one call to generate and therefore returns decoder input token ids and eos token id + # we set a flag in the generation config to force the model to make only one call to generate and return the decoder input token ids and eos token id + if "assistant_model" in kwargs: + assistant_model = kwargs["assistant_model"] + assistant_model.generation_config.force_unique_generate_call = True + + if force_unique_generate_call is None: + if hasattr(generation_config, "force_unique_generate_call"): + force_unique_generate_call = generation_config.force_unique_generate_call + elif hasattr(self.generation_config, "force_unique_generate_call"): + force_unique_generate_call = self.generation_config.force_unique_generate_call + else: + force_unique_generate_call = False # 6 Transcribe audio until we reach the end of all input audios while (seek < max_frames).any(): @@ -729,14 +813,15 @@ def generate( prev_idx=prev_i, idx=i, return_token_timestamps=return_token_timestamps, + decoder_input_ids=decoder_input_ids, ) + seek[prev_i] += segment_offset + current_segments[prev_i] += segments - if is_shortform: - seek[prev_i] += max_frames[i] - else: - seek[prev_i] += segment_offset + if force_unique_generate_call: + break # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output @@ -746,51 +831,62 @@ def generate( else current_segments ) - sequences = _pad_to_max_length( - final_segments, generation_config.pad_token_id, device=self.device, padding_side="right" - ) - - # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. - if return_segments: - return {"sequences": sequences, "segments": final_segments} - - if is_shortform: - # add eos token: - if generation_config.max_new_tokens is None and generation_config.max_length is None: - eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id) - sequences = torch.cat([sequences, eos_tokens], dim=-1) + # if return_dict_in_generate=True and we forced a unique call to generate or return_timestamps=False, meaning we are sure only one call to generate has been made, + # -> we can return a ModelOutput + # otherwise, return_dict_in_generate is applied in the 'result' of each segment in final_segments + if ( + return_dict_in_generate + and generation_config.return_dict_in_generate + and (force_unique_generate_call or not return_timestamps) + ): + # only one call to generate_with_fallback, we can return a ModelOutput + outputs = self._stack_split_outputs(seek_outputs, model_output_type, self.device, kwargs) + if num_return_sequences > 1: + if hasattr(outputs, "encoder_attentions") and outputs.encoder_attentions is not None: + outputs.encoder_attentions = tuple( + outputs.encoder_attentions[i][::num_return_sequences] + for i in range(len(outputs.encoder_attentions)) + ) + if hasattr(outputs, "encoder_hidden_states") and outputs.encoder_hidden_states is not None: + outputs.encoder_hidden_states = tuple( + outputs.encoder_hidden_states[i][::num_return_sequences] + for i in range(len(outputs.encoder_hidden_states)) + ) + return outputs - if return_token_timestamps: - outputs = {} - outputs["sequences"] = sequences - outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0) - else: - outputs = sequences + padded_outputs = _pad_to_max_length( + current_segments=final_segments, + pad_token_id=generation_config.pad_token_id, + device=self.device, + padding_side="right", + return_token_timestamps=return_token_timestamps, + force_unique_generate_call=force_unique_generate_call, + ) - if return_dict_in_generate and generation_config.return_dict_in_generate: - dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs) + if return_dict_in_generate and generation_config.return_dict_in_generate: + logger.warning_once( + "You have passed `return_dict_in_generate=True` and `return_timestamps=True`, this automatically sets `return_segments=True` to access the resuls of the underlying calls to GenerationMixin's generate in the returned `segments`." + ) + return_segments = True + elif not return_segments and not return_token_timestamps: + return padded_outputs - if num_return_sequences > 1: - if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None: - dict_outputs.encoder_attentions = tuple( - dict_outputs.encoder_attentions[i][::num_return_sequences] - for i in range(len(dict_outputs.encoder_attentions)) - ) - if ( - hasattr(dict_outputs, "encoder_hidden_states") - and dict_outputs.encoder_hidden_states is not None - ): - dict_outputs.encoder_hidden_states = tuple( - dict_outputs.encoder_hidden_states[i][::num_return_sequences] - for i in range(len(dict_outputs.encoder_hidden_states)) - ) - if return_token_timestamps: - dict_outputs["token_timestamps"] = outputs["token_timestamps"] - return dict_outputs + if return_token_timestamps: + sequences, token_timestamps = padded_outputs + outputs = { + "sequences": sequences, + "token_timestamps": token_timestamps, + } + else: + sequences = padded_outputs + outputs = { + "sequences": sequences, + } - return outputs + if return_segments: + outputs["segments"] = final_segments - return sequences + return outputs def generate_with_fallback( self, @@ -886,22 +982,14 @@ def generate_with_fallback( new_decoder_attention_mask = [] for i, seek_sequence in enumerate(seek_sequences): - # make sure we cut a predicted EOS token if we are not finished with the generation yet - prev_i = batch_idx_map[fallback_index_map[i]] - is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] - - # remove eos token id - if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: - seek_sequence = seek_sequence[:-1] - if return_token_timestamps and not is_shortform: - seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1] - - # remove all padding tokens + # remove all padding tokens, except for the eos token if seek_sequence[-1] == generation_config.pad_token_id: num_paddings = (seek_sequence == generation_config.pad_token_id).sum() - seek_sequence = seek_sequence[:-num_paddings] - if return_token_timestamps and not is_shortform: - seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings] + if generation_config.pad_token_id == generation_config.eos_token_id: + # we do not remove the eos token id since it is needed for avg logprob calculation in _need_fallback + num_paddings -= 1 + if num_paddings != 0: + seek_sequence = seek_sequence[:-num_paddings] # check which sequences in batch need fallback & which should be skipped needs_fallback[i], should_skip[i] = self._need_fallback( @@ -914,6 +1002,10 @@ def generate_with_fallback( temperature, ) + # remove eos token + if seek_sequence[-1] == generation_config.eos_token_id: + seek_sequence = seek_sequence[:-1] + seek_sequence_list[fallback_index_map[i]] = seek_sequence seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] is_low_temperature = temperature is None or temperature < 0.5 @@ -956,14 +1048,19 @@ def _prepare_segments(prompt_ids, batch_size, generation_config): return current_segments def _postprocess_outputs( - self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, is_shortform + self, + seek_outputs, + decoder_input_ids, + return_token_timestamps, + generation_config, + is_shortform, ): # remove all previously passed decoder input ids - start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0) + # should happen only if it is the first generated segment + start_idx = decoder_input_ids.shape[-1] if isinstance(seek_outputs, torch.Tensor): - seek_outputs = seek_outputs[:, start_idx:] - return seek_outputs, seek_outputs + return seek_outputs[:, start_idx:], seek_outputs if return_token_timestamps and hasattr(generation_config, "alignment_heads"): num_frames = getattr(generation_config, "num_frames", None) @@ -973,9 +1070,6 @@ def _postprocess_outputs( num_frames=num_frames, num_input_ids=decoder_input_ids.shape[-1], ) - seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:] - - seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:] def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None): if beam_indices is not None and key == "scores": @@ -1011,7 +1105,7 @@ def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None return values[batch_idx].cpu() - sequence_tokens = seek_outputs["sequences"] + sequence_tokens = seek_outputs["sequences"][:, start_idx:] seek_outputs = [ { k: split_by_batch_index(v, k, i, is_shortform, beam_indices=seek_outputs.get("beam_indices")) @@ -1026,7 +1120,7 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs): # Stack back seek_outputs tensors after splitting them with the split_by_batch_index method outputs = {} for key in seek_outputs[0].keys(): - if key in ["sequences", "beam_indices"]: + if key in ["sequences", "beam_indices", "token_timestamps"]: outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device) elif key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]: outputs[key] = tuple( @@ -1057,6 +1151,10 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs): else: outputs[key] = None + token_timestamps = outputs.get("token_timestamps", None) + if token_timestamps is not None: + model_output_type = dict + return model_output_type(**outputs) def _need_fallback( @@ -1083,7 +1181,9 @@ def _need_fallback( else: scores = seek_outputs[index]["scores"] logprobs = self._retrieve_avg_logprobs( - scores, seek_sequence, generation_config.eos_token_id, temperature + scores, + seek_sequence, + temperature, ) if logprobs < generation_config.logprob_threshold: @@ -1179,13 +1279,6 @@ def _maybe_warn_unused_inputs( if no_speech_threshold is not None: logger.warning(warning_prefix.format(f"no_speech_threshold is set to {no_speech_threshold}")) - # when passing temperature as a list it cannot just be ignored => throw error in this case - if isinstance(temperature, (list, tuple)): - raise ValueError( - f"Audio input consists of only {total_input_frames}. Short-form transcription is activated." - f"temperature cannot be set to {temperature} which can only be used for temperature fallback for long-form generation. Make sure to set `temperature` to a float value or `None` for short-form generation." - ) - @staticmethod def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config): if return_dict_in_generate is None: @@ -1768,7 +1861,7 @@ def _retrieve_compression_ratio(tokens, vocab_size): return compression_ratio @staticmethod - def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature): + def _retrieve_avg_logprobs(scores, tokens, temperature): rescale_temperature = temperature if temperature > 0.0 else 1 scores = torch.stack(scores).to(tokens.device) @@ -1780,10 +1873,10 @@ def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature): logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype) # retrieve logprob of selected tokens and sum - sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0])) - length = (tokens != eos_token_id).sum(-1) if eos_token_id is not None else tokens.shape[0] + # don't remove the eos token logprob! it counts in avg_logprob calculation in the original implementation + sum_logprobs = sum(logprobs[i][tokens[i]] for i in range(logprobs.shape[0])) - avg_logprobs = sum_logprobs / (length + 1) + avg_logprobs = sum_logprobs / len(tokens) return avg_logprobs @staticmethod @@ -1799,6 +1892,7 @@ def _retrieve_segment( prev_idx, idx, return_token_timestamps, + decoder_input_ids, ): # find the predicted "end of segment" predictions of Whisper # "end of segment" predictions occur whenever Whisper predicts a timestamp token @@ -1807,6 +1901,7 @@ def _retrieve_segment( timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] timestamp_segment_indices.add_(1) token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else [] + idx_offset = decoder_input_ids.shape[-1] device = seek_sequence.device # If whisper predicted a "end of segment" via a timestep token, let's go ever each @@ -1838,12 +1933,13 @@ def _retrieve_segment( + end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision, "tokens": sliced_tokens, + "idxs": (idx_offset + last_slice, idx_offset + current_slice), "result": seek_outputs[idx], } ) if return_token_timestamps: segments[-1]["token_timestamps"] = ( - token_timestamps[last_slice:current_slice] + time_offset[prev_idx] + token_timestamps[idx_offset + last_slice : idx_offset + current_slice] + time_offset[prev_idx] ) last_slice = current_slice @@ -1871,11 +1967,14 @@ def _retrieve_segment( "start": time_offset[prev_idx], "end": time_offset[prev_idx] + last_timestamp_pos * time_precision, "tokens": seek_sequence, + "idxs": (idx_offset, idx_offset + len(seek_sequence)), "result": seek_outputs[idx], } ] if return_token_timestamps: - segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx] + segments[-1]["token_timestamps"] = ( + token_timestamps[idx_offset : idx_offset + len(seek_sequence)] + time_offset[prev_idx] + ) segment_offset = seek_num_frames[prev_idx] return segments, segment_offset diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index 73303e374c8484..504b6174fc52ad 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -17,14 +17,22 @@ from __future__ import annotations import inspect +import os import tempfile import traceback import unittest import numpy as np -from transformers import WhisperConfig, WhisperFeatureExtractor, WhisperProcessor -from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, run_test_in_subprocess, slow +from transformers import GenerationConfig, WhisperConfig, WhisperFeatureExtractor, WhisperProcessor +from transformers.testing_utils import ( + is_tf_available, + require_read_token, + require_tf, + require_tokenizers, + run_test_in_subprocess, + slow, +) from transformers.utils import cached_property from transformers.utils.import_utils import is_datasets_available @@ -749,7 +757,9 @@ def _test_large_generation(in_queue, out_queue, timeout): input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features generated_ids = model.generate( - input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe" + input_features, + do_sample=False, + max_length=20, ) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] @@ -772,13 +782,29 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout): processor = WhisperProcessor.from_pretrained("openai/whisper-large") model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large") - ds = load_dataset("legacy-datasets/common_voice", "ja", split="test", streaming=True, trust_remote_code=True) + # update generation config + generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2") + + token = os.getenv("HF_HUB_READ_TOKEN", True) + ds = load_dataset( + "mozilla-foundation/common_voice_6_1", + "ja", + split="test", + streaming=True, + trust_remote_code=True, + token=token, + ) ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000)) input_speech = next(iter(ds))["audio"]["array"] input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features generated_ids = model.generate( - input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe" + input_features, + do_sample=False, + max_length=20, + language="<|ja|>", + task="transcribe", + generation_config=generation_config, ) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] @@ -786,7 +812,12 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout): unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT) generated_ids = model.generate( - input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe" + input_features, + do_sample=False, + max_length=20, + language="<|en|>", + task="transcribe", + generation_config=generation_config, ) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] @@ -794,7 +825,12 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout): unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT) generated_ids = model.generate( - input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate" + input_features, + do_sample=False, + max_length=20, + language="<|ja|>", + task="translate", + generation_config=generation_config, ) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] @@ -825,10 +861,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout): # fmt: off EXPECTED_IDS = [ - [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281], - [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257], - [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256], - [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11] + [50258, 50259, 50359, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404], + [50258, 50259, 50359, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257], + [50258, 50259, 50359, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904], + [50258, 50259, 50359, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439] ] # fmt: on @@ -836,10 +872,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout): # fmt: off EXPECTED_TRANSCRIPT = [ - " Mr. Quilter is the apostle of the middle classes and we are glad to", + " Mr. Quilter is the apostle of the middle classes and we are glad", " Nor is Mr. Quilter's manner less interesting than his matter.", - " He tells us that at this festive season of the year, with Christmas and roast beef", - " He has grave doubts whether Sir Frederick Layton's work is really Greek after all," + " He tells us that at this festive season of the year, with Christmas and roast", + " He has grave doubts whether Sir Frederick Layton's work is really Greek after all" ] # fmt: on @@ -1009,6 +1045,7 @@ def test_large_generation(self): run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None) @slow + @require_read_token def test_large_generation_multilingual(self): run_test_in_subprocess(test_case=self, target_func=_test_large_generation_multilingual, inputs=None) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index faab43854cce11..2eff406a3b56fc 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -445,6 +445,11 @@ def setUp(self): self.config_tester = ConfigTester(self, config_class=WhisperConfig) self.maxDiff = 3000 + def prepare_config_and_inputs_for_generate(self, batch_size=2): + config, inputs_dict = super().prepare_config_and_inputs_for_generate(batch_size=batch_size) + inputs_dict["force_unique_generate_call"] = True + return config, inputs_dict + def test_config(self): self.config_tester.run_common_tests() @@ -1891,8 +1896,8 @@ def test_large_batched_generation_multilingual(self): "ja", split="test", streaming=True, - token=token, trust_remote_code=True, + token=token, ) ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000)) @@ -2144,11 +2149,16 @@ def test_small_longform_timestamps_generation(self): }, { "text": " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and", - "timestamp": (39.80, 45.36), + # "timestamp": (39.80, 45.36), + # above is the expected output on A100. + # on CI T4s, due to sligth difference in floating points operations, expected is below + "timestamp": (39.80, 45.38), }, { "text": " can discover in it but little of rocky Ithaca.", - "timestamp": (45.36, 49.0), + # "timestamp": (45.36, 49.0), + # see above + "timestamp": (45.38, 49.0), }, { "text": " Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles", @@ -2275,20 +2285,20 @@ def test_tiny_token_timestamp_generation(self): # fmt: off EXPECTED_OUTPUT = torch.tensor([ - [0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.8400, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9400, 26.9400, 26.9400, 26.9400], - [0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400], + [0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200], + [0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000], [0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800], - [0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 29.6400, 29.6600, 29.6600, 29.6600, 29.6600] + [0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200] ]) # fmt: on self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT)) @slow - def test_large_token_timestamp_generation(self): + def test_small_token_timestamp_generation(self): set_seed(0) - processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") - model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") + processor = WhisperProcessor.from_pretrained("openai/whisper-small") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small") model.to(torch_device) input_speech = self._load_datasamples(4) @@ -2305,10 +2315,10 @@ def test_large_token_timestamp_generation(self): # fmt: off EXPECTED_OUTPUT = torch.tensor([ - [0.0000, 0.0000, 0.6200, 0.7400, 0.8600, 1.0000, 1.0400, 1.3000, 1.4400, 1.7800, 2.1800, 2.2800, 2.5000, 2.9200, 3.0000, 3.3800, 3.5000, 3.6000, 3.8400, 4.1000, 4.4000, 4.6800, 5.1400, 5.3600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200], - [0.0000, 0.0000, 0.6000, 0.9200, 1.2200, 1.3400, 1.4200, 1.5400, 1.5800, 1.7400, 2.0600, 2.3800, 3.0400, 3.3800, 3.6400, 4.1200, 4.3600, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800], - [0.0000, 0.0000, 0.5400, 0.8200, 1.1600, 1.4600, 1.7400, 1.8800, 2.3400, 2.7400, 3.1400, 3.2200, 3.5400, 4.2800, 4.5600, 4.8200, 5.0600, 5.3200, 5.6600, 5.9600, 6.1400, 6.4000, 6.8400, 7.8800, 8.0200, 8.3600, 8.7000, 9.0200, 9.3200, 9.5000, 9.8400, 10.3000, 10.6600, 11.0800, 11.3600, 11.4600, 11.8000], - [0.0000, 0.0000, 0.5600, 0.7600, 1.0600, 1.4000, 1.8800, 2.2600, 2.6200, 2.8000, 2.9600, 3.0000, 3.2000, 3.4400, 3.6800, 4.0000, 4.6000, 5.0000, 5.3200, 5.4800, 6.0600, 6.0600, 6.1000, 6.3200, 6.7400, 7.0000, 7.2200, 7.4000, 7.7600, 8.0600, 8.5600, 8.8600, 8.9400, 9.1000, 9.3400, 9.8800, 9.8800] + [0.0000, 0.0000, 0.7400, 0.8000, 0.9800, 1.0200, 1.1400, 1.4000, 1.5200, 1.9200, 2.2600, 2.3800, 2.5400, 2.8600, 3.2600, 3.3400, 3.4400, 3.6000, 3.6800, 3.9200, 4.2000, 4.4800, 4.7800, 5.2600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200], + [0.0000, 0.0000, 0.7600, 1.0000, 1.3000, 1.3800, 1.5200, 1.5800, 1.7000, 1.8400, 2.1000, 2.5000, 3.1400, 3.4400, 3.7400, 4.1800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800], + [0.0000, 0.0000, 0.6600, 0.9000, 1.2200, 1.5200, 1.7600, 2.0200, 2.4000, 2.9200, 3.1800, 3.3200, 3.6200, 4.1000, 4.3600, 4.7800, 5.1200, 5.3400, 5.7200, 6.0600, 6.2000, 6.2000, 6.2000, 6.5000, 6.9000, 7.6400, 8.0000, 8.2400, 8.5200, 8.7400, 9.0800, 9.4000, 9.5400, 9.9400, 10.4200, 10.7600, 11.1200, 11.4400, 11.5800, 11.8600, 12.4600], + [0.0000, 0.0000, 0.6600, 0.8600, 1.1400, 1.5000, 1.9600, 2.3600, 2.6400, 2.9800, 3.1200, 3.2400, 3.4800, 3.7800, 4.1400, 4.6400, 5.0800, 5.4400, 6.2200, 6.2200, 6.2200, 6.4000, 6.8400, 7.1200, 7.2600, 7.4800, 7.8200, 8.1400, 8.7000, 9.0200, 9.0200, 9.2000, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800] ]) # fmt: on @@ -3331,6 +3341,7 @@ def test_tiny_static_generation_long_form(self): # only permit 4 compilations: 2 prefill steps and 2 decoding steps (1 for each of conditioned/not conditioned) torch._dynamo.config.cache_size_limit = 4 + torch._dynamo.reset() processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")