From 0ed6729bb130cb1d43fb2ede60b0c50f9ee14d68 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Thu, 2 Nov 2023 17:06:56 +0000 Subject: [PATCH] Enrich TTS pipeline parameters naming (#26473) * enrich TTS pipeline docstring for clearer forward_params use * change token lengths * update Pipeline parameters * correct docstring and make style * fix tests * make style * change music prompt Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * raise errors if generate_kwargs with forward-only models * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/pipelines/text_to_audio.py | 62 ++++++++++++++++--- .../pipelines/test_pipelines_text_to_audio.py | 55 ++++++++++++++++ 2 files changed, 110 insertions(+), 7 deletions(-) diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 299fa7ac014b01..58c21cc1216869 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -43,6 +43,29 @@ class TextToAudioPipeline(Pipeline): Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial) + + + You can specify parameters passed to the model by using [`TextToAudioPipeline.__call__.forward_params`] or + [`TextToAudioPipeline.__call__.generate_kwargs`]. 
+ + Example: + + ```python + >>> from transformers import pipeline + + >>> music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt") + + >>> # diversify the music generation by adding randomness with a high temperature and set a maximum music length + >>> generate_kwargs = { + ... "do_sample": True, + ... "temperature": 0.7, + ... "max_new_tokens": 35, + ... } + + >>> outputs = music_generator("Techno music with high melodic riffs", generate_kwargs=generate_kwargs) + ``` + + This pipeline can currently be loaded from [`pipeline`] using the following task identifiers: `"text-to-speech"` or `"text-to-audio"`. @@ -107,11 +130,26 @@ def preprocess(self, text, **kwargs): def _forward(self, model_inputs, **kwargs): # we expect some kwargs to be additional tensors which need to be on the right device kwargs = self._ensure_tensor_on_device(kwargs, device=self.device) + forward_params = kwargs["forward_params"] + generate_kwargs = kwargs["generate_kwargs"] if self.model.can_generate(): - output = self.model.generate(**model_inputs, **kwargs) + # we expect some kwargs to be additional tensors which need to be on the right device + generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device) + + # generate_kwargs get priority over forward_params + forward_params.update(generate_kwargs) + + output = self.model.generate(**model_inputs, **forward_params) else: - output = self.model(**model_inputs, **kwargs)[0] + if len(generate_kwargs): + raise ValueError( + f"""You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non empty. + For forward-only TTA models, please use `forward_params` instead of + `generate_kwargs`. 
For reference, here are the `generate_kwargs` used here: + {generate_kwargs.keys()}""" + ) + output = self.model(**model_inputs, **forward_params)[0] if self.vocoder is not None: # in that case, the output is a spectrogram that needs to be converted into a waveform @@ -126,8 +164,14 @@ def __call__(self, text_inputs: Union[str, List[str]], **forward_params): Args: text_inputs (`str` or `List[str]`): The text(s) to generate. - forward_params (*optional*): - Parameters passed to the model generation/forward method. + forward_params (`dict`, *optional*): + Parameters passed to the model generation/forward method. `forward_params` are always passed to the + underlying model. + generate_kwargs (`dict`, *optional*): + The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a + complete overview of generate, check the [following + guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). `generate_kwargs` are + only passed to the underlying model if the latter is a generative model. 
Return: A `dict` or a list of `dict`: The dictionaries have two keys: @@ -141,14 +185,18 @@ def _sanitize_parameters( self, preprocess_params=None, forward_params=None, + generate_kwargs=None, ): + params = { + "forward_params": forward_params if forward_params else {}, + "generate_kwargs": generate_kwargs if generate_kwargs else {}, + } + if preprocess_params is None: preprocess_params = {} - if forward_params is None: - forward_params = {} postprocess_params = {} - return preprocess_params, forward_params, postprocess_params + return preprocess_params, params, postprocess_params def postprocess(self, waveform): output_dict = {} diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index 6aca34ed98a097..a9f1eccae5089c 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -30,6 +30,7 @@ slow, torch_device, ) +from transformers.trainer_utils import set_seed from .test_pipelines_common import ANY @@ -174,6 +175,60 @@ def test_vits_model_pt(self): outputs = speech_generator(["This is a test", "This is a second test"], batch_size=2) self.assertEqual(ANY(np.ndarray), outputs[0]["audio"]) + @slow + @require_torch + def test_forward_model_kwargs(self): + # use vits - a forward model + speech_generator = pipeline(task="text-to-audio", model="kakao-enterprise/vits-vctk", framework="pt") + + # for reproducibility + set_seed(555) + outputs = speech_generator("This is a test", forward_params={"speaker_id": 5}) + audio = outputs["audio"] + + with self.assertRaises(TypeError): + # assert error if generate parameter + outputs = speech_generator("This is a test", forward_params={"speaker_id": 5, "do_sample": True}) + + forward_params = {"speaker_id": 5} + generate_kwargs = {"do_sample": True} + + with self.assertRaises(ValueError): + # assert error if generate_kwargs with forward-only models + outputs = speech_generator( + "This is a test", 
forward_params=forward_params, generate_kwargs=generate_kwargs + ) + self.assertTrue(np.abs(outputs["audio"] - audio).max() < 1e-5) + + @slow + @require_torch + def test_generative_model_kwargs(self): + # use musicgen - a generative model + music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt") + + forward_params = { + "do_sample": True, + "max_new_tokens": 250, + } + + # for reproducibility + set_seed(555) + outputs = music_generator("This is a test", forward_params=forward_params) + audio = outputs["audio"] + self.assertEqual(ANY(np.ndarray), audio) + + # make sure generate kwargs get priority over forward params + forward_params = { + "do_sample": False, + "max_new_tokens": 250, + } + generate_kwargs = {"do_sample": True} + + # for reproducibility + set_seed(555) + outputs = music_generator("This is a test", forward_params=forward_params, generate_kwargs=generate_kwargs) + self.assertListEqual(outputs["audio"].tolist(), audio.tolist()) + def get_test_pipeline(self, model, tokenizer, processor): speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer) return speech_generator, ["This is a test", "Another test"]