diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index c45fffb984b113..0d6addb5631bec 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -16,7 +16,7 @@ import math import warnings import zlib -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple, Union import numpy as np import torch @@ -174,6 +174,8 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads]) weights = weights.permute([1, 0, 2, 3]) + weight_length = None + if "beam_indices" in generate_outputs: # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths # since the beam search strategy chooses the most probable sequences at the end of the search. @@ -195,7 +197,9 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec dim=2, ) - timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32) + # make sure timestamps are as long as weights + input_length = weight_length or cross_attentions[0].shape[2] + timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)[:, : input_length + 1] batch_size = timestamps.shape[0] if num_frames is not None: @@ -260,6 +264,7 @@ def generate( language: Optional[str] = None, is_multilingual: Optional[bool] = None, prompt_ids: Optional[torch.Tensor] = None, + prompt_condition_type: Optional[str] = None, # first-segment, all-segments condition_on_prev_tokens: Optional[bool] = None, temperature: Optional[Union[float, Tuple[float, ...]]] = None, compression_ratio_threshold: Optional[float] = None, @@ -333,6 +338,9 @@ def generate( provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value. + prompt_condition_type (`str`, *optional*): + Only relevant for long-form transcription. Condition type of `prompt_ids`. 'first-segment' means only the first segment is conditioned on `prompt_ids`. 'all-segments' means each segment is conditioned on `prompt_ids`. Make sure to enable `condition_on_prev_tokens` for 'all-segments'. + Defaults to 'first-segment'. For short-term transcription only 'first-segment' is possible. condition_on_prev_tokens (`bool`, *optional*): Only relevant for long-form transcription. Whether to condition each segment on the previous segment. As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve @@ -474,7 +482,7 @@ def generate( # 2. 
set global generate variables input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] num_segment_frames = input_stride * self.config.max_source_positions - total_input_frames = self._retrieve_total_input_frames( + batch_size, total_input_frames = self._retrieve_total_input_frames( input_features=input_features, input_stride=input_stride, kwargs=kwargs ) is_shortform = total_input_frames <= num_segment_frames @@ -505,15 +513,6 @@ def generate( self._set_language_and_task( language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config ) - # pass self.config for backward compatibility - self._set_forced_decoder_ids( - task=task, - language=language, - prompt_ids=prompt_ids, - generation_config=generation_config, - config=self.config, - kwargs=kwargs, - ) self._set_token_ids(generation_config=generation_config, config=self.config, kwargs=kwargs) self._set_num_frames( return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs ) @@ -525,12 +524,31 @@ def generate( no_speech_threshold=no_speech_threshold, condition_on_prev_tokens=condition_on_prev_tokens, ) + self._set_prompt_condition_type( + generation_config=generation_config, + prompt_condition_type=prompt_condition_type, + ) + + # pass self.config for backward compatibility + init_tokens = self._retrieve_init_tokens( + input_features, + generation_config=generation_config, + config=self.config, + num_segment_frames=num_segment_frames, + kwargs=kwargs, + ) + # TODO(Sanchit) - passing `decoder_input_ids` is deprecated. One should use `prompt_ids` instead + # This function should be removed in v4.39 + self._check_decoder_input_ids( + prompt_ids=prompt_ids, init_tokens=init_tokens, is_shortform=is_shortform, kwargs=kwargs + ) - # 4. Retrieve logits processors + # 3. Retrieve logits processors + begin_index = len(init_tokens) logits_processor = self._retrieve_logit_processors( generation_config=generation_config, logits_processor=logits_processor, - no_speech_threshold=no_speech_threshold, + begin_index=begin_index, # begin index is index of first generated decoder token is_shortform=is_shortform, num_beams=kwargs.get("num_beams", 1), ) @@ -540,6 +558,27 @@ def generate( if temperature is not None: kwargs["temperature"] = temperature + decoder_input_ids = kwargs.pop("decoder_input_ids", None) + if decoder_input_ids is None: + one_tensor = torch.ones((batch_size, 1), device=self.device, dtype=torch.long) + decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) + + if prompt_ids is not None: + decoder_input_ids = torch.cat( + [prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1 + ) + + if kwargs.get("max_new_tokens", 0) + decoder_input_ids.shape[-1] > self.config.max_target_positions: + max_new_tokens = kwargs.get("max_new_tokens", 0) + raise ValueError( + f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` " + f"is {max_new_tokens}. Thus, the combined length of " + f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the " + f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. " + "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " + f"so that their combined length is less than {self.config.max_target_positions}."
+ ) + outputs = super().generate( input_features, generation_config=generation_config, @@ -547,6 +586,7 @@ def generate( stopping_criteria=stopping_criteria, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, + decoder_input_ids=decoder_input_ids, **kwargs, ) @@ -573,11 +613,15 @@ def generate( max_frames, seek = self._retrieve_max_frames_and_seek( batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames ) - init_tokens = self._retrieve_init_tokens_from_forced_decoder_ids(generation_config=generation_config) # 6.2 Preppare running variables, list for generation cur_bsz = batch_size - current_segments = [[] for _ in range(batch_size)] + current_segments = self._prepare_segments( + prompt_ids=prompt_ids, + batch_size=batch_size, + generation_config=generation_config, + ) + batch_idx_map = list(range(batch_size)) do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)] @@ -617,6 +661,7 @@ def generate( current_segments=current_segments, batch_idx_map=batch_idx_map, do_condition_on_prev_tokens=do_condition_on_prev_tokens, + prompt_ids=prompt_ids, generation_config=generation_config, config=self.config, device=segment_input.device, @@ -682,11 +727,16 @@ def generate( # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output - sequences = _pad_to_max_length(current_segments, generation_config.pad_token_id, padding="right") + final_segments = ( + [x[1:] for x in current_segments] + if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment") + else current_segments + ) + sequences = _pad_to_max_length(final_segments, generation_config.pad_token_id, padding="right") # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. 
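With these changes, `prompt_ids` can steer long-form transcription, and `prompt_condition_type` decides whether the prompt conditions only the first segment (the default, with the prompt pseudo-segment stripped from the returned sequences as handled just above) or every segment. A minimal usage sketch along the lines of the new long-form prompt test, reusing the tiny English checkpoint and dummy LibriSpeech dataset that the tests in this diff rely on:

```python
import numpy as np
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

# concatenate the dummy LibriSpeech clips into one >30s waveform, as the long-form tests do
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
long_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32)

input_features = processor(
    long_audio, return_tensors="pt", truncation=False, padding="longest"
).input_features

# bias the transcription towards the prompt, here for every 30s window
prompt_ids = processor.get_prompt_ids("Mr. Kilter, Ruggedo.", return_tensors="pt")
sequences = model.generate(
    input_features,
    prompt_ids=prompt_ids,
    prompt_condition_type="all-segments",  # or "first-segment" (the default)
    condition_on_prev_tokens=True,         # required when using "all-segments"
    return_timestamps=True,                # long-form generation predicts timestamp tokens
)
print(processor.batch_decode(sequences, skip_special_tokens=True)[0])
```

With `"all-segments"`, `condition_on_prev_tokens=True` is required, since every window is then conditioned on the prompt in addition to the previously generated segment tokens.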
if return_segments: - return {"sequences": sequences, "segments": current_segments} + return {"sequences": sequences, "segments": final_segments} return sequences @@ -721,7 +771,8 @@ def generate_with_fallback( for fallback_idx, temperature in enumerate(temperatures): generation_config.do_sample = temperature is not None and temperature > 0.0 - generation_config.temperature = temperature + + generation_config.temperature = temperature if generation_config.do_sample else 1.0 generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1 seek_outputs = super().generate( @@ -736,13 +787,13 @@ def generate_with_fallback( ) # post-process sequence tokens and outputs to be in list form - sequence_tokens, seek_outputs = self._postprocess_outputs( - seek_outputs, return_token_timestamps, generation_config + seek_sequences, seek_outputs = self._postprocess_outputs( + seek_outputs=seek_outputs, + decoder_input_ids=decoder_input_ids, + return_token_timestamps=return_token_timestamps, + generation_config=generation_config, ) - # remove all previously passed decoder input ids - seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1] :] - # 6.7 Extract cut sequences from every sequence and check if fallback should be applied # Loop over each decoded audio individually as each decoding can be of a different length new_fallback_index_map = [] @@ -777,8 +828,9 @@ def generate_with_fallback( seek_sequence_list[fallback_index_map[i]] = seek_sequence seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] + is_low_temperature = temperature is None or temperature < 0.5 do_condition_on_prev_tokens[fallback_index_map[i]] = ( - generation_config.condition_on_prev_tokens and temperature is not None and temperature < 0.5 + generation_config.condition_on_prev_tokens and is_low_temperature ) if needs_fallback[i]: @@ -804,30 +856,44 @@ def generate_with_fallback( return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens - def _postprocess_outputs(self, seek_outputs, return_token_timestamps, generation_config): + @staticmethod + def _prepare_segments(prompt_ids, batch_size, generation_config): + if prompt_ids is not None and generation_config.prompt_condition_type == "first-segment": + prev_sot_token_id = getattr(generation_config, "prev_sot_token_id", None) + prompt_ids = prompt_ids[1:] if prompt_ids[0] == prev_sot_token_id else prompt_ids + current_segments = [[{"tokens": prompt_ids}] for _ in range(batch_size)] + else: + current_segments = [[] for _ in range(batch_size)] + + return current_segments + + def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config): + # remove all previously passed decoder input ids + if isinstance(seek_outputs, torch.Tensor): + seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1] :] + return seek_outputs, seek_outputs + if return_token_timestamps and hasattr(generation_config, "alignment_heads"): num_frames = getattr(generation_config, "num_frames", None) seek_outputs["token_timestamps"] = self._extract_token_timestamps( seek_outputs, generation_config.alignment_heads, num_frames=num_frames ) - if generation_config.return_dict_in_generate: + seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :] - def split_by_batch_index(values, key, batch_idx): - if key == "scores": - return [v[batch_idx].cpu() for v in values] - if key == "past_key_values": - # we don't save `past_key_values` as this is too costly - return None - return 
values[batch_idx].cpu() + def split_by_batch_index(values, key, batch_idx): + if key == "scores": + return [v[batch_idx].cpu() for v in values] + if key == "past_key_values": + # we don't save `past_key_values` as this is too costly + return None + return values[batch_idx].cpu() - sequence_tokens = seek_outputs["sequences"] - seek_outputs = [ - {k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} - for i in range(sequence_tokens.shape[0]) - ] - else: - sequence_tokens = seek_outputs + sequence_tokens = seek_outputs["sequences"] + seek_outputs = [ + {k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} + for i in range(sequence_tokens.shape[0]) + ] return sequence_tokens, seek_outputs @@ -884,7 +950,7 @@ def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_id @staticmethod def _retrieve_total_input_frames(input_features, input_stride, kwargs): if input_features is not None: - return input_features.shape[-1] + return input_features.shape[0], input_features.shape[-1] if "encoder_outputs" in kwargs: encoder_outputs_shape = ( @@ -892,7 +958,7 @@ def _retrieve_total_input_frames(input_features, input_stride, kwargs): if isinstance(kwargs["encoder_outputs"], BaseModelOutput) else kwargs["encoder_outputs"].shape ) - return encoder_outputs_shape[1] * input_stride + return encoder_outputs_shape[0], encoder_outputs_shape[1] * input_stride raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.") @@ -950,34 +1016,24 @@ def _set_return_outputs( @staticmethod def _set_return_timestamps(return_timestamps, is_shortform, generation_config): - if return_timestamps is True: - if not hasattr(generation_config, "no_timestamps_token_id"): - raise ValueError( - "You are trying to return timestamps, but the generation config is not properly set. " - "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " - "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" - ) - generation_config.return_timestamps = True - elif not is_shortform: + if not is_shortform: if return_timestamps is False: raise ValueError( "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " "requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features." ) - if not hasattr(generation_config, "no_timestamps_token_id"): - raise ValueError( - "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " - "requires the generation config to have `no_timestamps_token_id` correctly. " - "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " - "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" - "or make sure to pass no more than 3000 mel input features." 
- ) - logger.info("Setting `return_timestamps=True` for long-form generation.") - generation_config.return_timestamps = True - else: - generation_config.return_timestamps = False + return_timestamps = True + + if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"): + raise ValueError( + "You are trying to return timestamps, but the generation config is not properly set. " + "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " + "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" + ) + + generation_config.return_timestamps = return_timestamps @staticmethod def _set_language_and_task(language, task, is_multilingual, generation_config): @@ -1016,94 +1072,221 @@ def _set_language_and_task(language, task, is_multilingual, generation_config): ) generation_config.task = task - @staticmethod - def _set_forced_decoder_ids(task, language, prompt_ids, generation_config, config, kwargs): - forced_decoder_ids = None - # Legacy code for backward compatibility - if hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None: - forced_decoder_ids = config.forced_decoder_ids + def _retrieve_init_tokens(self, input_features, generation_config, config, num_segment_frames, kwargs): + def replace_or_add(lst: List[int], num: int, itr: Iterator[int]): + """short function to replace num with a itr in lst""" + found = any(i in lst for i in itr) + if found: + lst = [num if i in itr else i for i in lst] + else: + lst.append(num) + return lst + + task = getattr(generation_config, "task", None) + language = getattr(generation_config, "language", None) + + if kwargs.get("forced_decoder_ids", None) is not None: + forced_decoder_ids = kwargs["forced_decoder_ids"] elif hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None: forced_decoder_ids = generation_config.forced_decoder_ids + + if language is None and task is None and forced_decoder_ids[0][1] is None: + logger.warning_once( + "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English." + "This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`." + ) + elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None: + forced_decoder_ids = config.forced_decoder_ids else: - forced_decoder_ids = kwargs.pop("forced_decoder_ids", None) - - if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None): - forced_decoder_ids = [] - if hasattr(generation_config, "language"): - if generation_config.language in generation_config.lang_to_id.keys(): - language_token = generation_config.language - elif generation_config.language in TO_LANGUAGE_CODE.keys(): - language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>" - elif generation_config.language in TO_LANGUAGE_CODE.values(): - language_token = f"<|{generation_config.language}|>" - else: - is_language_code = len(generation_config.language) == 2 - raise ValueError( - f"Unsupported language: {generation_config.language}. Language should be one of:" - f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." 
- ) - if language_token not in generation_config.lang_to_id: - raise ValueError( - f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`." - "(You should just add it to the generation config)" - ) - forced_decoder_ids.append((1, generation_config.lang_to_id[language_token])) + forced_decoder_ids = None + + if forced_decoder_ids is not None and task is not None: + logger.info( + f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}." + ) + forced_decoder_ids = None + elif forced_decoder_ids is not None and language is not None: + logger.info( + f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}." + ) + forced_decoder_ids = None + + init_tokens = [generation_config.decoder_start_token_id] + if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: + i = 1 + while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i: + init_tokens += [forced_decoder_ids[0][1]] + forced_decoder_ids = forced_decoder_ids[1:] + i += 1 + + # TODO(Sanchit): Let's make sure we don't allow incorrectly / weirdly formatted `forced_decoder_ids` after transformers v4.39 + if len(forced_decoder_ids) > 0: + warnings.warn( + f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}. `forced_decoder_ids` will be passed as a logit processor, but note that this functionality has been deprecated and will throw an error in v4.39.", + FutureWarning, + ) + + # TODO(Sanchit): set generation_config.forced_decoder_ids to None for v4.39 + generation_config.forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None + + is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None) + if language is not None: + if language in generation_config.lang_to_id.keys(): + language_token = language + elif language in TO_LANGUAGE_CODE.keys(): + language_token = f"<|{TO_LANGUAGE_CODE[language]}|>" + elif language in TO_LANGUAGE_CODE.values(): + language_token = f"<|{language}|>" else: - forced_decoder_ids.append((1, None)) # automatically detect the language - - if hasattr(generation_config, "task"): - if generation_config.task in TASK_IDS: - forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) - else: - raise ValueError( - f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`" - ) - elif hasattr(generation_config, "task_to_id"): - forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe - if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps: - idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 - forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) - - if forced_decoder_ids is not None: - generation_config.forced_decoder_ids = forced_decoder_ids - - if prompt_ids is not None: - if kwargs.get("decoder_start_token_id") is not None: + is_language_code = len(language) == 2 raise ValueError( - "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten." 
+ f"Unsupported language: {language}. Language should be one of:" + f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." ) - prompt_ids = prompt_ids.tolist() - decoder_start_token_id, *text_prompt_ids = prompt_ids - # Slicing the text prompt ids in a manner consistent with the OpenAI implementation - # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599) - text_prompt_ids = text_prompt_ids[-config.max_target_positions // 2 - 1 :] - # Set the decoder_start_token_id to <|startofprev|> - kwargs.update({"decoder_start_token_id": decoder_start_token_id}) - - # If the user passes `max_new_tokens`, increase its number to account for the prompt - if kwargs.get("max_new_tokens", None) is not None: - kwargs["max_new_tokens"] += len(text_prompt_ids) - if kwargs["max_new_tokens"] >= config.max_target_positions: - raise ValueError( - f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` " - f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced " - f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the " - f"`max_target_positions` of the Whisper model: {config.max_target_positions}. " - "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " - f"so that their combined length is less that {config.max_target_positions}." - ) - - # Reformat the forced_decoder_ids to incorporate the prompt - non_prompt_forced_decoder_ids = ( - kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids + if language_token not in generation_config.lang_to_id: + raise ValueError( + f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`." + "(You should just add it to the generation config)" + ) + + lang_id = generation_config.lang_to_id[language_token] + + # if language is defined it'll overwrite language ids that might have already been defined via the generation_config + replace_or_add(init_tokens, lang_id, generation_config.lang_to_id.values()) + elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined: + # language is not defined or intentially set to `None` to trigger language detection + lang_ids = self.detect_language( + input_features=input_features, + encoder_outputs=kwargs.get("encoder_outputs", None), + generation_config=generation_config, + num_segment_frames=num_segment_frames, + ) + + if torch.unique(lang_ids).shape[0] > 1: + raise ValueError( + "Multiple languages detected when trying to predict the most likely target language for transcription. It is currently not supported to transcribe to different languages in a single batch. Please make sure to either force a single language by passing `language='...'` or make sure all input audio is of the same language." 
+ ) + + lang_id = lang_ids[0].item() + + # append or replace lang_id to init_tokens + if len(init_tokens) > 1: + init_tokens[1] = lang_id + else: + init_tokens.append(lang_id) + + if task is not None: + if task in TASK_IDS: + init_tokens.append(generation_config.task_to_id[generation_config.task]) + task_id = generation_config.task_to_id[generation_config.task] + + # if task is defined it'll overwrite task ids that might have already been defined via the generation_config + replace_or_add(init_tokens, task_id, generation_config.task_to_id.values()) + else: + raise ValueError(f"The `{task}`task is not supported. The task should be one of `{TASK_IDS}`") + elif language is not None and hasattr(generation_config, "task_to_id"): + # if language is defined, but no task id is in `init_tokens`, default to transcribe + if not any(i in init_tokens for i in generation_config.task_to_id.values()): + init_tokens.append(generation_config.task_to_id["transcribe"]) + + if ( + not generation_config.return_timestamps + and hasattr(generation_config, "no_timestamps_token_id") + and init_tokens[-1] != generation_config.no_timestamps_token_id + ): + init_tokens.append(generation_config.no_timestamps_token_id) + elif generation_config.return_timestamps and init_tokens[-1] == generation_config.no_timestamps_token_id: + logger.info( + "<|notimestamps|> prompt token is removed from generation_config since `return_timestamps` is set to `'True'`." + ) + init_tokens = init_tokens[:-1] + + # let's make sure we don't pass `None` tokens as prompt tokens + init_tokens = [t for t in init_tokens if t is not None] + + return init_tokens + + def detect_language( + self, + input_features: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None, + generation_config: Optional[GenerationConfig] = None, + num_segment_frames: int = 3000, + ) -> torch.Tensor: + """ + Detects language from log-mel input features or encoder_outputs + + Parameters: + input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*): + Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a + tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. 
Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + num_segment_frames (`int`, defaults to 3000): + The number of log-mel frames the model expects + + Return: + A `torch.LongTensor` representing the detected language ids. + """ + if input_features is None and encoder_outputs is None: + raise ValueError("You have to specify either `input_features` or `encoder_outputs`") + elif input_features is not None and encoder_outputs is not None: + raise ValueError("Make sure to specify only one of `input_features` or `encoder_outputs` - not both!") + elif input_features is not None: + inputs = {"input_features": input_features[:, :, :num_segment_frames]} + batch_size = input_features.shape[0] + elif encoder_outputs is not None: + inputs = {"encoder_outputs": encoder_outputs} + batch_size = ( + encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0] + ) + + generation_config = generation_config or self.generation_config + decoder_input_ids = ( + torch.ones((batch_size, 1), device=self.device, dtype=torch.long) + * generation_config.decoder_start_token_id + ) + + with torch.no_grad(): + logits = self(**inputs, decoder_input_ids=decoder_input_ids).logits[:, -1] + + non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool) + non_lang_mask[list(generation_config.lang_to_id.values())] = False + + logits[:, non_lang_mask] = -np.inf + + lang_ids = logits.argmax(-1) + + return lang_ids + + @staticmethod + def _check_decoder_input_ids(prompt_ids, init_tokens, is_shortform, kwargs): + decoder_input_ids = kwargs.get("decoder_input_ids", None) + if prompt_ids is not None and decoder_input_ids is not None: + raise ValueError( + f"Cannot pass both `prompt_ids`: {prompt_ids} and `decoder_input_ids`: {decoder_input_ids}. Passing `decoder_input_ids` is deprecated, consider not passing it." + ) + elif decoder_input_ids is not None and not is_shortform: + raise ValueError( + f"Cannot pass both `decoder_input_ids`: {decoder_input_ids} for long-form generation. Consider passing `prompt_ids` instead." + ) + elif decoder_input_ids is not None and is_shortform: + warnings.warn( + f"You have provided `decoder_input_ids` which will overwrite the `init_tokens` {init_tokens}. This might lead to unexpected behavior. Passing `decoder_input_ids` is deprecated and will be removed in v4.39. Consider passing `prompt_ids` instead.", + FutureWarning, ) - forced_decoder_ids = [ - *text_prompt_ids, - generation_config.decoder_start_token_id, - *[token for _, token in non_prompt_forced_decoder_ids], - ] - forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)] - generation_config.forced_decoder_ids = forced_decoder_ids @staticmethod def _set_token_ids(generation_config, config, kwargs): @@ -1162,6 +1345,25 @@ def _set_thresholds_and_condition( else getattr(generation_config, "condition_on_prev_tokens", None) ) + @staticmethod + def _set_prompt_condition_type(generation_config, prompt_condition_type): + allowed_cond_types = ["first-segment", "all-segments"] + + # default to "first-segment" + prompt_condition_type = prompt_condition_type or allowed_cond_types[0] + + if prompt_condition_type not in allowed_cond_types: + raise ValueError( + f"`prompt_condition_type={prompt_condition_type} does not exist.
Make sure to set `prompt_condition_type` to one of {', '.join(allowed_cond_types)}" + ) + + if generation_config.condition_on_prev_tokens is not True and prompt_condition_type == "all-segments": + raise ValueError( + "Make sure to set `condition_on_prev_tokens=True` when setting `prompt_condition_type='all-segments'`." + ) + + generation_config.prompt_condition_type = prompt_condition_type + @staticmethod def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config): condition_on_prev_tokens = ( @@ -1175,7 +1377,7 @@ def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config): def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames): if batch_size > 1 and attention_mask is None: raise ValueError( - "When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` " + "When doing batched long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` " ) elif batch_size > 1: max_frames = attention_mask.sum(-1).cpu().to(torch.long) @@ -1186,37 +1388,7 @@ def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames return max_frames, seek - @staticmethod - def _retrieve_init_tokens_from_forced_decoder_ids(generation_config): - init_tokens = [generation_config.decoder_start_token_id] - forced_decoder_ids = generation_config.forced_decoder_ids - if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: - i = 1 - while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i: - init_tokens += [forced_decoder_ids[0][1]] - forced_decoder_ids = forced_decoder_ids[1:] - i += 1 - - forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None - generation_config.forced_decoder_ids = forced_decoder_ids - - return init_tokens - - def _retrieve_logit_processors( - self, generation_config, logits_processor, no_speech_threshold, is_shortform, num_beams - ): - forced_decoder_ids = generation_config.forced_decoder_ids - if generation_config.return_timestamps is True: - last_forced_decoder_ids = forced_decoder_ids[-1][-1] if forced_decoder_ids is not None else None - if last_forced_decoder_ids == generation_config.no_timestamps_token_id: - # remove no_timestamp to be forcefully generated if we want to return timestamps - # this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly - forced_decoder_ids = forced_decoder_ids[:-1] if len(forced_decoder_ids) > 1 else None - # Make sure that if list is empty we set it to None - generation_config.forced_decoder_ids = forced_decoder_ids - - begin_index = len(forced_decoder_ids) + 1 if forced_decoder_ids is not None else 1 - + def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams): if generation_config.return_timestamps is True: timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index) logits_processor = ( @@ -1243,7 +1415,7 @@ def _retrieve_logit_processors( ) generation_config.begin_suppress_tokens = None - if no_speech_threshold is not None and not is_shortform: + if generation_config.no_speech_threshold is not None and not is_shortform: no_speech_detector = WhisperNoSpeechDetection( no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index, @@ -1256,11 +1428,12 @@ def 
_retrieve_logit_processors( if is_shortform and generation_config.forced_decoder_ids is not None: forced_tokens_proc = ForceTokensLogitsProcessor(generation_config.forced_decoder_ids) - # TODO(Patrick): It's important that the `forced_tokens_proc` processor is appended after + # It's important that the `forced_tokens_proc` processor is appended after # the suppress_tokens processor or else it might happen that all token logits are suppressed to -inf # which would lead to unexpected behavior # The better approach here is to NOT make use of the `forced_tokens_proc` for Whisper and instead # initialize all of them as `decoder_input_ids`. + # TODO(Sanchit): Make sure to deprecate this in v4.39 as there will be no `forced_decoder_ids` anymore. logits_processor = ( [forced_tokens_proc] if logits_processor is None else logits_processor + [forced_tokens_proc] ) @@ -1310,6 +1483,7 @@ def _prepare_decoder_input_ids( current_segments, batch_idx_map, do_condition_on_prev_tokens, + prompt_ids, generation_config, config, device, @@ -1328,19 +1502,27 @@ def _prepare_decoder_input_ids( if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0: # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map] - prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or prev_start_of_text - bos_token_tensor = prev_start_of_text * one_tensor[0] + if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments": + prev_ids = prompt_ids + else: + prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None + prev_tokens = _pad_to_max_length( active_segments, generation_config.pad_token_id, padding="left", - bos_token_tensor=bos_token_tensor, + bos_token_tensor=prev_ids, cut_off_length=cut_off_length, ) decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) kwargs["decoder_attention_mask"] = decoder_input_ids != generation_config.pad_token_id + elif prompt_ids is not None: + prev_tokens = prompt_ids[None].repeat(decoder_input_ids.shape[0], 1) + decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) + # make sure `"decoder_attention_mask"` is not passed to forward + kwargs.pop("decoder_attention_mask", None) else: # make sure `"decoder_attention_mask"` is not passed to forward kwargs.pop("decoder_attention_mask", None) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 505d2e991033d8..1f92f1523dbbde 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -25,6 +25,7 @@ import numpy as np import pytest +from huggingface_hub import hf_hub_download import transformers from transformers import WhisperConfig @@ -38,7 +39,7 @@ slow, torch_device, ) -from transformers.utils import cached_property, is_flax_available, is_torch_available +from transformers.utils import cached_property, is_flax_available, is_torch_available, is_torchaudio_available from transformers.utils.import_utils import is_datasets_available from ...generation.test_utils import GenerationTesterMixin @@ -142,6 +143,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores +if is_torchaudio_available(): + import torchaudio + + if is_flax_available(): import jax.numpy as jnp @@ -1258,7 +1263,7 @@ def 
test_generate_with_prompt_ids_and_task_and_language(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() model = WhisperForConditionalGeneration(config).eval().to(torch_device) input_features = input_dict["input_features"] - prompt_ids = np.arange(5) + prompt_ids = torch.arange(5).to(torch_device) language = "<|de|>" task = "translate" lang_id = 6 @@ -1281,7 +1286,7 @@ def test_generate_with_prompt_ids_and_forced_decoder_ids(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() model = WhisperForConditionalGeneration(config).eval().to(torch_device) input_features = input_dict["input_features"] - prompt_ids = np.asarray(range(5)) + prompt_ids = torch.arange(5).to(torch_device) forced_decoder_ids = [(1, 6), (2, 7), (3, 8)] output = model.generate( @@ -1298,27 +1303,67 @@ def test_generate_with_prompt_ids_and_forced_decoder_ids(self): def test_generate_with_prompt_ids_max_length(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.max_target_positions = 5 + config.max_target_positions = 7 model = WhisperForConditionalGeneration(config).eval().to(torch_device) input_features = input_dict["input_features"] - prompt_ids = np.asarray(range(4)) - sliced_prompt_ids = prompt_ids[1:] - sliced_prompt_ids = sliced_prompt_ids[-config.max_target_positions // 2 - 1 :] - max_new_tokens = 5 + decoder_input_ids = torch.arange(5).to(torch_device) + prompt_ids = decoder_input_ids[:4] + max_new_tokens = 8 with self.assertRaisesRegex( ValueError, - f"The length of the sliced `prompt_ids` is {len(sliced_prompt_ids)}, and the `max_new_tokens` " - f"{max_new_tokens}. Thus, the combined length of the sliced `prompt_ids` and `max_new_tokens` is: " - f"{len(sliced_prompt_ids) + max_new_tokens}. This exceeds the `max_target_positions` of the Whisper model: " - f"{config.max_target_positions}. You should either reduce the length of your prompt, or reduce the " - f"value of `max_new_tokens`, so that their combined length is less that {config.max_target_positions}.", + f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` " + f"is {max_new_tokens}. Thus, the combined length of " + f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the " + f"`max_target_positions` of the Whisper model: {config.max_target_positions}. 
" + "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " + f"so that their combined length is less than {config.max_target_positions}.", ): model.generate(input_features, max_new_tokens=max_new_tokens, prompt_ids=prompt_ids) model.generate(input_features, max_new_tokens=1, prompt_ids=prompt_ids) + def test_generate_longform_with_prompt_ids(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = WhisperForConditionalGeneration(config).eval().to(torch_device) + + prompt_ids = torch.arange(5).to(torch_device) + model.generation_config.no_timestamps_token_id = 11 + model.generation_config.pad_token_id = 10 + + # make sure prompt token ids [0-9] can't be generated + model.generation_config.suppress_tokens = list(range(10)) + + input_features = input_dict["input_features"] + + language = "<|de|>" + lang_id = 6 + + input_features = input_features.repeat(1, 1, 50) + attention_mask = torch.ones_like(input_features, dtype=torch.long)[:, 0] + + for prompt_type in ["first-segment", "all-segments"]: + for task_id, task in enumerate(["translate", "transcribe"]): + task_id = 7 + task_id + + model.generation_config.__setattr__("lang_to_id", {language: lang_id}) + model.generation_config.__setattr__("task_to_id", {task: task_id}) + + output = model.generate( + input_features, + attention_mask=attention_mask, + prompt_condition_type=prompt_type, + max_new_tokens=5, + task=task, + language=language, + prompt_ids=prompt_ids, + condition_on_prev_tokens=True, + ) + for row in output.tolist(): + # make sure no token below 10 is in generated output => this means for long-form prompt ids should NOT be returned + assert not any(i in row for i in model.generation_config.suppress_tokens) + def _check_longform_generate_single_batch(self, condition_on_prev_tokens): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -1919,7 +1964,8 @@ def test_tiny_token_timestamp_batch_generation(self): num_return_sequences=num_return_sequences, ) - self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape) + # task id and lang id prompts should not have timestamp tokens + self.assertEqual(generate_outputs.sequences.shape[-1] - 2, generate_outputs.token_timestamps.shape[-1]) self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples) @@ -1967,13 +2013,110 @@ def test_generate_with_prompt_ids(self): input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device) output_without_prompt = model.generate(input_features) - prompt_ids = processor.get_prompt_ids("Leighton") + prompt_ids = processor.get_prompt_ids("Leighton", return_tensors="pt").to(torch_device) output_with_prompt = model.generate(input_features, prompt_ids=prompt_ids) expected_without_prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>" expected_with_prompt = "<|startofprev|> Leighton<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Leighton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>" - self.assertEqual(processor.decode(output_without_prompt[0]), expected_without_prompt) - self.assertEqual(processor.decode(output_with_prompt[0]), expected_with_prompt) + + output_without_prompt = 
processor.decode(output_without_prompt[0]) + output_with_prompt = processor.decode(output_with_prompt[0]) + + self.assertEqual(output_without_prompt, expected_without_prompt) + self.assertEqual(output_with_prompt, expected_with_prompt) + + @slow + def test_language_detection(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + model.to(torch_device) + input_speech = self._load_datasamples(4)[-1:] + input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device) + + lang_id = model.detect_language(input_features)[0].item() + + ids_to_lang = {v: k for k, v in model.generation_config.lang_to_id.items()} + + assert ids_to_lang[lang_id] == "<|en|>" + + audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset") + + raw_audio, sr = torchaudio.load(audio) + input_speech = torchaudio.transforms.Resample(sr, 16_000)(raw_audio).numpy() + + input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device) + + lang_id = model.detect_language(input_features)[0].item() + + assert ids_to_lang[lang_id] == "<|hi|>" + + @slow + def test_default_multilingual_transcription_short_form(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + model.to(torch_device) + + audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset") + + raw_audio, sr = torchaudio.load(audio) + input_speech = torchaudio.transforms.Resample(sr, 16_000)(raw_audio).numpy() + + input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device) + + # model.generation_config.forced_decoder_ids defaults to [1, null] for lang_token + sequences = model.generate(input_features) + + transcription = processor.batch_decode(sequences, skip_special_tokens=False)[0] + + assert ( + transcription + == "<|startoftranscript|><|hi|><|transcribe|><|notimestamps|> Mirchi mein ki tene vibinda prajatiya hai<|endoftext|>" + ) + + # set forced_decoder_ids to English + model.generation_config.forced_decoder_ids[0][-1] = 50259 + + sequences = model.generate(input_features) + transcription = processor.batch_decode(sequences, skip_special_tokens=False)[0] + + assert ( + transcription + == "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> MIRCHI MET, which is the name of the Bible.<|endoftext|>" + ) + + @slow + def test_default_multilingual_transcription_long_form(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2") + model.to(torch_device) + + audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset") + + raw_audio, sr = torchaudio.load(audio) + input_speech = torchaudio.transforms.Resample(sr, 16_000)(raw_audio) + + input_speech = input_speech.repeat(1, 10).numpy() + input_features = processor( + input_speech, return_tensors="pt", padding="longest", truncation=False + ).input_features.to(torch_device) + + # model.generation_config.forced_decoder_ids defaults to [1, null] for lang_token + sequences = model.generate(input_features) + + transcription = processor.batch_decode(sequences)[0] + + assert transcription == " मिर्ची में कितने विबिन्द प्रजातियां हैं? मिर्ची में कितने विबिन्द प्रजातियां हैं?" 
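The `detect_language` method exercised by `test_language_detection` above is a public helper: it runs a single decoder step from `decoder_start_token_id`, masks every non-language token, and returns one language id per batch entry, which can be inspected before calling `generate`. A small sketch using the same tiny checkpoint and dummy dataset as the surrounding tests:

```python
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor(
    ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt"
).input_features

# one detected language id per batch entry, restricted to generation_config.lang_to_id
lang_ids = model.detect_language(input_features)
id_to_lang = {v: k for k, v in model.generation_config.lang_to_id.items()}
print(id_to_lang[lang_ids[0].item()])  # e.g. "<|en|>" for this English sample
```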
+ + # set forced_decoder_ids to English + model.generation_config.forced_decoder_ids[0][-1] = 50259 + + sequences = model.generate(input_features) + transcription = processor.batch_decode(sequences)[0] + + assert ( + transcription + == " How many different species are there in the chilli? How many different species are there in the chili?" + ) @slow def test_generate_with_prompt_ids_and_forced_decoder_ids(self): @@ -1986,7 +2129,7 @@ def test_generate_with_prompt_ids_and_forced_decoder_ids(self): language = "de" expected_tokens = [f"<|{task}|>", f"<|{language}|>"] prompt = "test prompt" - prompt_ids = processor.get_prompt_ids(prompt) + prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt").to(torch_device) output = model.generate(input_features, task=task, language=language, prompt_ids=prompt_ids) text = processor.decode(output[0]) @@ -2002,7 +2145,7 @@ def test_generate_with_prompt_ids_and_no_non_prompt_forced_decoder_ids(self): input_speech = self._load_datasamples(1) input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device) prompt = "test prompt" - prompt_ids = processor.get_prompt_ids(prompt) + prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt").to(torch_device) model.generation_config.forced_decoder_ids = None model.config.forced_decoder_ids = None @@ -2033,7 +2176,9 @@ def test_speculative_decoding_distil(self): dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") sample = dataset[0]["audio"] - input_features = processor(sample["array"], return_tensors="pt").input_features.to("cuda").to(torch.float16) + input_features = ( + processor(sample["array"], return_tensors="pt").input_features.to(torch_device).to(torch.float16) + ) # warm up assisted decoding _ = model.generate(input_features, assistant_model=assistant_model) @@ -2081,7 +2226,9 @@ def test_speculative_decoding_non_distil(self): dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") sample = dataset[0]["audio"] - input_features = processor(sample["array"], return_tensors="pt").input_features.to("cuda").to(torch.float16) + input_features = ( + processor(sample["array"], return_tensors="pt").input_features.to(torch_device).to(torch.float16) + ) # warm up assisted decoding _ = model.generate(input_features, assistant_model=assistant_model) @@ -2116,7 +2263,7 @@ def test_whisper_longform_single_batch(self): processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - model = model.to("cuda") + model = model.to(torch_device) ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean") one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32) @@ -2124,7 +2271,7 @@ def test_whisper_longform_single_batch(self): input_features = processor(one_audio, return_tensors="pt", truncation=False, padding="longest")[ "input_features" ] - input_features = input_features.to(device="cuda") + input_features = input_features.to(device=torch_device) result = model.generate(input_features, return_timestamps=True) decoded = processor.batch_decode(result, skip_special_tokens=True) @@ -2145,15 +2292,65 @@ def test_whisper_longform_single_batch(self): assert is_increasing + @slow + def test_whisper_longform_prompt_ids(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + model = 
WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + model = model.to(torch_device) + + prompt = "Mr. Kilter, Ruggedo." # let's force Mr. Quilter -> Mr. Kilter + prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt").to(torch_device) + + ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean") + one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32) + + first_text = ds["validation"][0]["text"].lower() + last_text = ds["validation"][-1]["text"].lower() + + input_features = processor(one_audio, return_tensors="pt", truncation=False, padding="longest")[ + "input_features" + ] + input_features = input_features.to(device=torch_device) + + result = model.generate( + input_features, + prompt_ids=prompt_ids, + return_timestamps=True, + prompt_condition_type="first-segment", + condition_on_prev_tokens=True, + ) + decoded_first_segment = processor.batch_decode(result, skip_special_tokens=True) + + result = model.generate( + input_features, + prompt_ids=prompt_ids, + return_timestamps=True, + prompt_condition_type="all-segments", + condition_on_prev_tokens=True, + ) + decoded_all_segments = processor.batch_decode(result, skip_special_tokens=True) + + # show that first segment has quilter and last segment has ruggedo + assert "quilter" in first_text + assert "ruggedo" in last_text + + # condition on first segment correctly changes to kilter in first segment, but does not transcribe "ruggedo" correctly + assert "kilter" in decoded_first_segment[0][: len(first_text)].lower() + assert "ruggedo" not in decoded_first_segment[0][-len(last_text) :].lower() + + # condition on all-segment correctly changes to kilter in first segment and correctly transcribes "ruggedo" + assert "kilter" in decoded_all_segments[0][: len(first_text)].lower() + assert "ruggedo" in decoded_all_segments[0][-len(last_text) :].lower() + @slow def test_whisper_longform_single_batch_prev_cond(self): # fmt: off - EXPECTED_TEXT = [""" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grieved doubts whether Sir Frederick Layton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-gards and atom paintings, and Mason's exquisite itals are as national as a jingo poem. Mr. Birk at Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. When Mr. John Collier gives his sitter a cheerful slap in the back, before he says like a shampooer and a Turkish bath, next man it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. He tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. 
Ruskin for not recognizing that a picture should denote the frailty of man, and remarks was pleasing courtesy in felicitous grace that many faces are feeling. Unfortunately his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tupper of painting. By Harry Quilter M. A. A man said to the universe, Sir, I exist. Sweat covered Breon's body trickling into the tight-lowing cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retroveilities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the twenties needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenties, he must have drawn his gun because the intruder said quickly, but that away you're being a fool. But there was silence then, and still wondering, Breon was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing. Just thrust and parry and victory to the stronger. Your man who entered the twenties had his own training tricks. They were appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Breon's death was in some ways easier than defeat. Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue. Breon sensed it and knew the fifth point was his. Then the powerful twist that's rested aside, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, accooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. 
The king has flooded disgrace, and your friends are asking for you. I begged Ruggido long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now, inquired Shaggy. In the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked Betsy thoughtfully. I don't believe Anne knew any magic, or she'd have worked it before. I do not know, confessed Shaggy. True, agreed Calico. Calico went to the big gong, and pounded on it, just as we're good to be used to do. But no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong, and then sat in the throne, wearing Regidos discarded Ruby Crown, and holding in his hand to scepter, which Regidos had so often thrown at his head."""] + EXPECTED_TEXT = [""" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grieved doubts whether Sir Frederick Layton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-gards and atom paintings, and Mason's exquisite itals are as national as a jingo poem. Mr. Birk at Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. When Mr. John Collier gives his sitter a cheerful slap in the back, before he says like a shampooer and a Turkish bath, next man it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. He tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man, and remarks was pleasing courtesy in felicitous grace that many faces are feeling. Unfortunately his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tupper of painting. By Harry Quilter M. A. A man said to the universe, Sir, I exist. Sweat covered Breon's body trickling into the tight-lowing cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retroveilities not worth thinking about. 
His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the twenties needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenties, he must have drawn his gun because the intruder said quickly, but that away you're being a fool. But there was silence then, and still wondering, Breon was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing. Just thrust and parry and victory to the stronger. Your man who entered the twenties had his own training tricks. They were appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Breon's death was in some ways easier than defeat. Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue. Breon sensed it and knew the fifth point was his. Then the powerful twist that's rested aside, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, accooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggido long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now, inquired Shaggy. In the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. 
Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked Betsy thoughtfully. I don't believe Anne knew any magic, or she'd have worked it before. I do not know, confessed Shaggy. True, agreed Calico. Calico went to the big gong and pounded on it, just as we're good to be used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing Regidos discarded Ruby crown, and holding in his hand to scepter which Regidos had so often thrown at his head."""] # fmt: on processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - model = model.to("cuda") + model = model.to(torch_device) ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean") one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32) @@ -2161,7 +2358,7 @@ def test_whisper_longform_single_batch_prev_cond(self): input_features = processor(one_audio, return_tensors="pt", truncation=False, padding="longest")[ "input_features" ] - input_features = input_features.to(device="cuda") + input_features = input_features.to(device=torch_device) gen_kwargs = { "return_timestamps": True, @@ -2189,7 +2386,7 @@ def test_whisper_longform_multi_batch(self): processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - model = model.to("cuda") + model = model.to(torch_device) ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean") one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32) @@ -2202,7 +2399,7 @@ def test_whisper_longform_multi_batch(self): decoded_single = [] for audio in audios: inputs = processor(audio, return_tensors="pt", truncation=False) - inputs = inputs.to(device="cuda") + inputs = inputs.to(device=torch_device) result = model.generate(**inputs, return_timestamps=True) decoded_single.append(processor.batch_decode(result, skip_special_tokens=True)) @@ -2210,7 +2407,7 @@ def test_whisper_longform_multi_batch(self): inputs = processor( audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True ) - inputs = inputs.to(device="cuda") + inputs = inputs.to(device=torch_device) result = model.generate(**inputs, return_timestamps=True) decoded_all = processor.batch_decode(result, skip_special_tokens=True) @@ -2230,15 +2427,15 @@ def test_whisper_longform_multi_batch(self): @slow def test_whisper_longform_multi_batch_prev_cond(self): # fmt: off - EXPECTED_TEXT_1 = [" Mr. Quilters manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca. The Nils, pictures are sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampooer and a Turkish bath. 
Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. On the general principles of art, Mr. Quilters writes with equal lucidity. Painting he tells us is of a different quality to mathematics and finish in art is adding more effect. As for etchings, there are of two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostorer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin, for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and solicitous grace that many phases of feeling only, unfortunately, his own work never does get good. Mr. Quilters has missed his chance, for he has failed even to make himself the tougher of painting. My hair equal to M.A. A man said to the universe, Sir, I exist. Sweat covered Breon's body, trickling into the tight-wing cloth that was the only garment he wore. The cut on his chest still dripping blood. The ache of his overstrain dyes. Even the soaring arena around him with thousands of spectators, retrievalidies not worth thinking about. His instant panic was followed by a small sharp blow, high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights and the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. To 20s, he must have drawn his gun because the intruder said quickly, but that away, you're being a fool. Out, the resoundance then, and still wondering, Brienne was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. Our red-haired mountain of a man, with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma as if the two were inexplicably linked into one. This strengthened enables someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne softly spoke the other hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled the maze at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. 
Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our role. Brienne sensed it and knew the fifth point was his. Then the powerful twist that's right to the side, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stoutchanges as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace in your friends, they're asking for you. I begged Ruggano a long ago to send him away, but he would not do so. I also offered to help you run into escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard since shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our nooms, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico, whereas my brother now inquired shaggy in the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all our dominions replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked to Bedsey thoughtfully. I don't believe Anne knew any magic or she'd have worked before. I do not know, confessed shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as Ruggano used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing Ruggano's discarded ruby crown. And holding in his hand the scepter which Ruggano had so often thrown at his head."] + EXPECTED_TEXT_1 = [" Mr. Quilters manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca. The Nils, pictures are sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. On the general principles of art, Mr. Quilters writes with equal lucidity. Painting he tells us is of a different quality to mathematics and finish in art is adding more effect. As for etchings, there are of two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostorer. 
Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin, for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and solicitous grace that many phases of feeling only, unfortunately, his own work never does get good. Mr. Quilters has missed his chance, for he has failed even to make himself the tougher of painting. My hair equal to M.A. A man said to the universe, Sir, I exist. Sweat covered Breon's body, trickling into the tight-wing cloth that was the only garment he wore. The cut on his chest still dripping blood. The ache of his overstrain dyes. Even the soaring arena around him with thousands of spectators, retrievalidies not worth thinking about. His instant panic was followed by a small sharp blow, high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights and the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance. And brand is the one I must see. Now stand aside. To 20s, he must have drawn his gun because the intruder said quickly. But that away, he'd be no fool. Out, the resoundance then, and still wondering, Brienne was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man, with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing. Just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma as if the two were inexplicably linked into one. This strength that enables someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne softly spoke the other hypnotic phrases that triggered the process. In the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled the maze at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our role. Brienne sensed it and knew the fifth point was his. Then the powerful twist that's right to the side, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. 
And with this, he stepped forward and burst the stoutchanges as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace in your friends, they're asking for you. I begged Ruggano a long ago to send him away, but he would not do so. I also offered to help you run into escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard since shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our nooms, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico, whereas my brother now inquired shaggy in the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all our dominions replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked to Bedsey thoughtfully. I don't believe Anne knew any magic or she'd have worked before. I do not know, confessed shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as Ruggano used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing Ruggano's discarded ruby crown. And holding in his hand the scepter which Ruggano had so often thrown at his head."] EXPECTED_TEXT_2 = [" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and can discover in it but little of rocky Ithaca. Lennials, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker"] - EXPECTED_TEXT_3 = [" gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating in its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of rocky ithaka. Lennils, pictures, are a sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. Under general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics and finish in art is adding more effect. As for etchings, thereof two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostoror. 
Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin, for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and falseness graced that many phases of feeling, only unfortunately his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tougher of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon's body trickling into the tight-wing cloth that was the only garment you wore. The cut on his chest still dripping blood. The ache of his overstrained eyes. Even the soaring arena around him with thousands of spectators were trivealed, not worth thinking about. His instant panic was followed by a small sharp, blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong, measured rate. He was in reverie sliding out on the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The 20s, he must have drawn his gun because the intruder said quickly, but that away, he'll be in the fool. Out, there is silence then, and still wondering, Brienne was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man, with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne softly spoke the autohydrotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled up the maze at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our ol' Brienne sensed it and knew the fifth point was his. Then the powerful twist that's right to decide, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. 
And with this, he stepped forward and burst the stoutchains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace in your friends, they're asking for you. I begged Ruggano a long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard since shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions as well as our nooms, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Whereas my brother now, in Quaragejjegi, in the metal forest. Where is that? The metal forest is in the great Dome to Cavern, the largest and all our dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny remarked by the bad sea thoughtfully. I don't believe Anne knew any magic or she'd have worked it before. I do not know, confessed shaggy. True, a great Calico. Calico went to the big gong and pounded on it, just as we're good or used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing reggos, discarded ruby crown, and holding in his hand to scepter which reggos had so often thrown at his head."] - EXPECTED_TEXT_4 = [" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and can discover in it but little of rocky Ithaca. Lennils, pictures, are a sort of upguards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says, like a shampooer in a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. On the general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, thereof two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostorer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin, for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and solicitous grace that many phases of feeling only, unfortunately, his own work never does, get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tougher of painting. By Harry Quilter, M.A. A man said to the universe, Sir, I exist. Sweat covered Breon's body, trickling into the tight-wing cloth that was the only garment you wore. 
The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators were trivialities not worth thinking about. His instant panic was followed by a small sharp blow, high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights and the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. To 20s, he must have drawn his gun because the intruder said quickly, but that away, you're being a fool. Out, there is silence then, and still wondering, Brienne was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. I've read here at Mountain of a Man, with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing. Just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma, as if the two were inexplicably linked into one. Just strengthed and enabled someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne softly spoke the autohydrotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled up the maze at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our ol' Brienne sensed it and knew the fifth point was his. Then the powerful twist that's right to the side, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. She has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stoutchanges as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace and your friends are asking for you. I begged Ruggano a long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard since shaggy. He doesn't work at all. 
In fact, there's nothing he can do in these dominions as well as our nooms, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now, in Quaragejji, in the metal forest? Where is that? The metal forest is in the great Dome to Cavern, the largest and all our dominions replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked a bit, see you thoughtfully. I don't believe Anne knew any magic or she'd have worked it before. I do not know, confessed shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as we're good we used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing reggos, discarded ruby crown and holding it his hand to scepter which reggo had so often thrown at his head."] + EXPECTED_TEXT_3 = [" gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating in its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of rocky ithaka. Lennils, pictures, are a sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. Under general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics and finish in art is adding more effect. As for etchings, thereof two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostoror. Near the fire, any ornaments spread brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and solicitous grace that many faces are feeling, only unfortunately his own work never does get good. Mr. Quilter has missed his chance. For he has failed even to make himself the tougher of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat covered Brienne's body trickling into the tight-wing cloth that was the only garment you wore. The cut on his chest still dripping blood. The ache of his overstrained eyes. Even the soaring arena around him with thousands of spectators, retrievalidies not worth thinking about. His instant panic was followed by a small sharp blow, high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered his muscles into complete relaxation. Only his heart and lungs worked on at a strong measured rate. He was in reverie, sliding out on the borders of consciousness. 
The contestants in the 20s needed undisturbed rest. Therefore, knights and the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The 20s, he must have drawn his gun because the intruder said quickly, but that away here being a fool. Out, there is silence then, and still wondering, Brienne was once more asleep. 10 seconds, he asked the handler who was needing his aching muscles. I've read here at Mountain of a Man with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma as if the two were anextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne's softly spoke the odd hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled up the maze at the sudden fury of the attack, then smiled. He said it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our ol' Brienne sensed it and knew the fifth point was his. Then the powerful twist that's right to decide, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace in your friends, they're asking for you. I begged Brienne to long ago to send him away, but he would not do so. I also offered to help you brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions as well as our nooms, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico, whereas my brother now inquired Shaggy in the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all our dominions replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked to bed see you thoughtfully. I don't believe Anne knew any magic or she'd have worked it before. I do not know, confessed Shaggy. True, agreed Calico. 
Calico went to the big gone and pounded on it, just as we're good or used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gone and then sat in the throne, wearing reggos, discarded ruby crown, and holding in his hand to scepter which reggos hand so often thrown at his head."] + EXPECTED_TEXT_4 = [" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and can discover in it but little of rocky Ithaca. Lennils, pictures, are a sort of upguards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says, like a shampooer in a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. On the general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, thereof two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostorer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin, for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and solicitous grace that many phases of feeling only, unfortunately, his own work never does, get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tougher of painting. By Harry Quilter, M.A. A man said to the universe, Sir, I exist. Sweat covered Breon's body, trickling into the tight-wing cloth that was the only garment you wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators were trivialities not worth thinking about. His instant panic was followed by a small sharp blow, high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights and the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance. And brand is the one I must see. Now stand aside. To 20s, he must have drawn his gun because the intruder said quickly, but that away, he could be no fool. 
Out, there was silence then, and still wondering, Brienne was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. I've read here at Mountain of a Man, with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing. Just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne softly spoke the other hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled the maze at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from Irohog. Brienne sensed it and knew the fifth point was his. Then the powerful twist that's for us to decide, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stoutchanges as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace in your friends, they are asking for you. I begged Ruggano a long ago to send him away, but he would not do so. I also offered to help you brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard since shaggy. He doesn't work at all. In fact, there is nothing he can do in these dominions, as well as our nooms, whose numbers are so great that it worries us to keep them all busy. And exactly we've turned Calico, where is my brother now in Quaragejji, in the metal forest? Where is that? The metal forest is in the great donned cavern, the largest and all our dominions replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked to Bedzeeth thoughtfully. I don't believe Anne knew any magic or she'd have worked before. I do not know, confessed shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as we're good to have used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing reggos, discarded ruby crown. 
And holding in his hand to scepter which reggos had so often thrown at his head."] # fmt: on processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") - model = model.to("cuda") + model = model.to(torch_device) ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean") one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32) @@ -2260,7 +2457,7 @@ def test_whisper_longform_multi_batch_prev_cond(self): decoded_single = [] for audio in audios: inputs = processor(audio, return_tensors="pt", truncation=False) - inputs = inputs.to(device="cuda") + inputs = inputs.to(device=torch_device) result = model.generate(**inputs, **gen_kwargs) decoded_single.append(processor.batch_decode(result, skip_special_tokens=True)) @@ -2288,7 +2485,7 @@ def test_whisper_longform_multi_batch_hard(self): processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - model = model.to("cuda") + model = model.to(torch_device) ds = load_dataset("distil-whisper/meanwhile", "default")["test"] ds = ds.cast_column("audio", Audio(sampling_rate=16000)) @@ -2301,7 +2498,7 @@ def test_whisper_longform_multi_batch_hard(self): decoded_single = [] for audio in audios: inputs = processor(audio, return_tensors="pt", truncation=False, sampling_rate=16_000) - inputs = inputs.to(device="cuda") + inputs = inputs.to(device=torch_device) result = model.generate(**inputs, return_timestamps=True) decoded_single += processor.batch_decode(result, skip_special_tokens=True) @@ -2309,7 +2506,7 @@ def test_whisper_longform_multi_batch_hard(self): inputs = processor( audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True ) - inputs = inputs.to(device="cuda") + inputs = inputs.to(device=torch_device) result = model.generate(**inputs, return_timestamps=True) decoded_all = processor.batch_decode(result, skip_special_tokens=True) @@ -2328,14 +2525,14 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self): " Folks, you watched this show, you know I spend most of my time right over there, carefully sorting through the days, big stories, and selecting only the most subtle, and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Ickel Greg Waferandi, who carefully died them in a pallet of bright, zesty shades, and adorn them in the finest most topical inlay work, using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddle stitching, and line it with bees, wax, coated linen, and finally attach a mallet hammered strap, perled hardware, and close-shet to create for you the one of a kind hope, kutur, earn-may is burkin bag that is my monologue, but sometimes, sometimes, sometimes. Sometimes, sometimes I wake up in the last car of an abandoned roller coaster at Kony Island, where I'm hiding from the triads, I have some engine lubricants out of a safe way bag and staggered down the shore to tear the sail off a beach sooner than I ripped the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel Lovelyfokes, and use it to stitch the sail into a loose pouch like rock sack, and I stole a bag of a garbage truck to the junkyard, where I picked through to the debris for only the broken toys that make me the saddest, until I have loaded for you. 
The hobo fugitives bug out Bindle of news that is my segment.", " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's big stories right over there. meticulously selecting the most topical chakra affirming scented candles, using Feng Shui, to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue, but sometimes just sometimes, I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself and use fry oil, wrap my hands and some old duct tape I stole from a broken car window, pound a six pack of blueberry hard-seller and a second pill, as I stole from a park damsel, and it's then arm wrestle a raccoon in the back alley vision quest of news that is my segment.", " You know, folks, I spend most of my time right over there. Mining the days, biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels, then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press, black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards and a face plate, and finally using fluted strips of white alloyed molding I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes, folks. Sometimes, just sometimes, I come to my senses fully naked on the deck of a pirate, beceived, melee, container ship that picked me up floating on the detainees. Then after I sunstroke in juice, realization of the crew of this ship plans to sell me and exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe in a pool chain that accepting my new role as captain and declaring myself king of the wind arc seas. I grab a dirty muck bucket covered in barnacles and a dornet with the teeth of the vanquished to create the softening wet pirate crown of news that is my segment. I'm going to use the white paper to create the softened white paper to create the softened white paper to create the softened white pirate crown of news that is my segment. Meanwhile.", - " Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's newsiest, most topical flower eggs, milk and butter. And straining into a fine batter to make delicate and informative comedy pancakes, then I glaze them in the juice and zest of the most relevant midnight valencio oranges. And doubts at all, and I find delimane de voyage cognac, before from bang and basting them tables, I deserve you the James Beard Award worthy creeps to ZET. That is my nightly monologue, but sometimes sometimes folks I wake up in the baggage hole of Greyhound bus, it's being hoisted by the scrapyard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps, busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strained pair of sweatpants and as ovenmets to extract and serve the demented transients pound cake of news that is my segment. Me wild!", - " Folks, if you watch the show and I hope you do, I spend a lot of time right over there. 
Tirelessly studying the lineage of the day's most important thoroughbred stories and whole-stiner headlines, working with the best trainers money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen that is my nightly monologue. But sometimes sometimes folks I break into an unincorporated veterinary genetics lab. And grab whatever test tubes I can find and then under a grow light I got from it a discarded chia pet. I mixed the pill for DNA of a horse and whatever was in a tube labeled Keith Cole and extra. Sloering the concoction with caffeine pills and a microwave bread bowl, I screamed sing a prayer to Janice initiator of human life and God of transformation as a half horse, half man freak, seasons to life before me. And the hideous collection of loose animal parts and corrupted men tissue that is my segment. Meanwhile.", + " Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's newsiest, most topical flower eggs, milk and butter. And straining into a fine batter to make delicate and informative comedy pancakes, then I glaze them in the juice and zest of the most relevant midnight valencio oranges. And doubts at all, and I find delimane de voyage cognac, before from bang and basting them tables, I deserve you the James Beard Award worthy creeps to ZET. That is my nightly monologue, but sometimes sometimes folks I wake up in the baggage hole of Greyhound bus, it's being hoisted by the scrapyard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps, busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strained pair of sweatpants and as ovenmets to extract and serve the demented transients pound cake of news that is my segment.", + " Folks, if you watch the show and I hope you do, I spend a lot of time right over there. Tirelessly studying the lineage of the day's most important thoroughbred stories and whole-stiner headlines, working with the best trainers money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen that is my nightly monologue. But sometimes sometimes folks I break into an unincorporated veterinary genetics lab. And grab whatever test tubes I can find and then under a grow light I got from it a discarded chia pet. I mixed the pill for DNA of a horse and whatever was in a tube labeled Keith Cole and extra. Sloering the concoction with caffeine pills and a microwave bread bowl, I screamed sing a prayer to Janice initiator of human life and God of transformation as a half horse, half man freak, seasons to life before me. 
And the hideous collection of loose animal parts and corrupted men tissue that is my segment.",
        ]
        # fmt: on

        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
-        model = model.to("cuda")
+        model = model.to(torch_device)

        ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
        ds = ds.cast_column("audio", Audio(sampling_rate=16000))
@@ -2348,7 +2545,7 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self):
        inputs = processor(
            audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True
        )
-        inputs = inputs.to(device="cuda")
+        inputs = inputs.to(device=torch_device)

        gen_kwargs = {
            "return_timestamps": True,
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index 7b6a9f30c55ac9..f3a51a4b77961a 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -1451,6 +1451,7 @@ def test_slow_unfinished_sequence(self):
        # Original model wasn't trained with timestamps and has incorrect generation config
        pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")

+        # the audio is 4 seconds long
        audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")

        out = pipe(
@@ -1460,11 +1461,8 @@ def test_slow_unfinished_sequence(self):
        self.assertEqual(
            out,
            {
-                "chunks": [
-                    {"text": "", "timestamp": (18.94, 0.02)},
-                    {"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)},
-                ],
                "text": "मिर्ची में कितने विभिन्न प्रजातियां हैं",
+                "chunks": [{"timestamp": (0.58, None), "text": "मिर्ची में कितने विभिन्न प्रजातियां हैं"}],
            },
        )
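
For reference, the long-form prompting behaviour exercised by the new `test_whisper_longform_prompt_ids` test above can be sketched as follows. This is a minimal, illustrative example rather than part of the change: the checkpoint name and the placeholder silent waveform are assumptions, while `prompt_ids`, `prompt_condition_type`, and `condition_on_prev_tokens` are the arguments introduced or exercised in this diff.

# Minimal sketch (assumed setup): prompt-conditioned long-form transcription with Whisper.
import numpy as np

from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

# placeholder audio: one minute of silence at 16 kHz stands in for a real >30s recording
audio = np.zeros(16_000 * 60, dtype=np.float32)

# pass the full waveform without truncation so the sequential long-form path is taken
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt", truncation=False, padding="longest")

# bias the transcription towards custom spellings / proper nouns
prompt_ids = processor.get_prompt_ids("Mr. Kilter, Ruggedo.", return_tensors="pt")

generated = model.generate(
    inputs["input_features"],
    prompt_ids=prompt_ids,
    return_timestamps=True,
    # "first-segment": only the first 30s chunk sees the prompt;
    # "all-segments": every chunk is conditioned on it and requires condition_on_prev_tokens=True
    prompt_condition_type="all-segments",
    condition_on_prev_tokens=True,
)
print(processor.batch_decode(generated, skip_special_tokens=True)[0])

As the two assertions in the test suggest, "first-segment" only nudges the opening chunk, whereas "all-segments" keeps the prompt in context for every chunk of a long recording.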