Skip to content

Commit

Permalink
Enrich TTS pipeline parameters naming (#26473)
Browse files Browse the repository at this point in the history
* enrich TTS pipeline docstring for clearer forward_params use

* change token leghts

* update Pipeline parameters

* correct docstring and make style

* fix tests

* make style

* change music prompt

Co-authored-by: Sanchit Gandhi <[email protected]>

* Apply suggestions from code review

Co-authored-by: Arthur <[email protected]>
Co-authored-by: Sanchit Gandhi <[email protected]>

* raise errors if generate_kwargs with forward-only models

* make style

---------

Co-authored-by: Sanchit Gandhi <[email protected]>
Co-authored-by: Arthur <[email protected]>
  • Loading branch information
3 people authored Nov 2, 2023
1 parent 147e8ce commit 0ed6729
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 7 deletions.
62 changes: 55 additions & 7 deletions src/transformers/pipelines/text_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,29 @@ class TextToAudioPipeline(Pipeline):
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
<Tip>
You can specify parameters passed to the model by using [`TextToAudioPipeline.__call__.forward_params`] or
[`TextToAudioPipeline.__call__.generate_kwargs`].
Example:
```python
>>> from transformers import pipeline
>>> music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
>>> # diversify the music generation by adding randomness with a high temperature and set a maximum music length
>>> generate_kwargs = {
... "do_sample": True,
... "temperature": 0.7,
... "max_new_tokens": 35,
... }
>>> outputs = music_generator("Techno music with high melodic riffs", generate_kwargs=generate_kwargs)
```
</Tip>
This pipeline can currently be loaded from [`pipeline`] using the following task identifiers: `"text-to-speech"` or
`"text-to-audio"`.
Expand Down Expand Up @@ -107,11 +130,26 @@ def preprocess(self, text, **kwargs):
def _forward(self, model_inputs, **kwargs):
# we expect some kwargs to be additional tensors which need to be on the right device
kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
forward_params = kwargs["forward_params"]
generate_kwargs = kwargs["generate_kwargs"]

if self.model.can_generate():
output = self.model.generate(**model_inputs, **kwargs)
# we expect some kwargs to be additional tensors which need to be on the right device
generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device)

# generate_kwargs get priority over forward_params
forward_params.update(generate_kwargs)

output = self.model.generate(**model_inputs, **forward_params)
else:
output = self.model(**model_inputs, **kwargs)[0]
if len(generate_kwargs):
raise ValueError(
f"""You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non empty.
For forward-only TTA models, please use `forward_params` instead of of
`generate_kwargs`. For reference, here are the `generate_kwargs` used here:
{generate_kwargs.keys()}"""
)
output = self.model(**model_inputs, **forward_params)[0]

if self.vocoder is not None:
# in that case, the output is a spectrogram that needs to be converted into a waveform
Expand All @@ -126,8 +164,14 @@ def __call__(self, text_inputs: Union[str, List[str]], **forward_params):
Args:
text_inputs (`str` or `List[str]`):
The text(s) to generate.
forward_params (*optional*):
Parameters passed to the model generation/forward method.
forward_params (`dict`, *optional*):
Parameters passed to the model generation/forward method. `forward_params` are always passed to the
underlying model.
generate_kwargs (`dict`, *optional*):
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
complete overview of generate, check the [following
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). `generate_kwargs` are
only passed to the underlying model if the latter is a generative model.
Return:
A `dict` or a list of `dict`: The dictionaries have two keys:
Expand All @@ -141,14 +185,18 @@ def _sanitize_parameters(
self,
preprocess_params=None,
forward_params=None,
generate_kwargs=None,
):
params = {
"forward_params": forward_params if forward_params else {},
"generate_kwargs": generate_kwargs if generate_kwargs else {},
}

if preprocess_params is None:
preprocess_params = {}
if forward_params is None:
forward_params = {}
postprocess_params = {}

return preprocess_params, forward_params, postprocess_params
return preprocess_params, params, postprocess_params

def postprocess(self, waveform):
output_dict = {}
Expand Down
55 changes: 55 additions & 0 deletions tests/pipelines/test_pipelines_text_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
slow,
torch_device,
)
from transformers.trainer_utils import set_seed

from .test_pipelines_common import ANY

Expand Down Expand Up @@ -174,6 +175,60 @@ def test_vits_model_pt(self):
outputs = speech_generator(["This is a test", "This is a second test"], batch_size=2)
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])

@slow
@require_torch
def test_forward_model_kwargs(self):
# use vits - a forward model
speech_generator = pipeline(task="text-to-audio", model="kakao-enterprise/vits-vctk", framework="pt")

# for reproducibility
set_seed(555)
outputs = speech_generator("This is a test", forward_params={"speaker_id": 5})
audio = outputs["audio"]

with self.assertRaises(TypeError):
# assert error if generate parameter
outputs = speech_generator("This is a test", forward_params={"speaker_id": 5, "do_sample": True})

forward_params = {"speaker_id": 5}
generate_kwargs = {"do_sample": True}

with self.assertRaises(ValueError):
# assert error if generate_kwargs with forward-only models
outputs = speech_generator(
"This is a test", forward_params=forward_params, generate_kwargs=generate_kwargs
)
self.assertTrue(np.abs(outputs["audio"] - audio).max() < 1e-5)

@slow
@require_torch
def test_generative_model_kwargs(self):
# use musicgen - a generative model
music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")

forward_params = {
"do_sample": True,
"max_new_tokens": 250,
}

# for reproducibility
set_seed(555)
outputs = music_generator("This is a test", forward_params=forward_params)
audio = outputs["audio"]
self.assertEqual(ANY(np.ndarray), audio)

# make sure generate kwargs get priority over forward params
forward_params = {
"do_sample": False,
"max_new_tokens": 250,
}
generate_kwargs = {"do_sample": True}

# for reproducibility
set_seed(555)
outputs = music_generator("This is a test", forward_params=forward_params, generate_kwargs=generate_kwargs)
self.assertListEqual(outputs["audio"].tolist(), audio.tolist())

def get_test_pipeline(self, model, tokenizer, processor):
speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer)
return speech_generator, ["This is a test", "Another test"]
Expand Down

0 comments on commit 0ed6729

Please sign in to comment.