improve docs/errors
sanchit-gandhi committed Sep 28, 2023
1 parent 098c3f4 commit ae8e277
Showing 3 changed files with 60 additions and 38 deletions.
87 changes: 52 additions & 35 deletions docs/source/en/pipeline_tutorial.md
@@ -30,29 +30,35 @@ Take a look at the [`pipeline`] documentation for a complete list of supported tasks

## Pipeline usage

While each task has an associated [`pipeline`], it is simpler to use the general [`pipeline`] abstraction which contains all the task-specific pipelines. The [`pipeline`] automatically loads a default model and a preprocessing class capable of inference for your task.
While each task has an associated [`pipeline`], it is simpler to use the general [`pipeline`] abstraction which contains
all the task-specific pipelines. The [`pipeline`] automatically loads a default model and a preprocessing class capable
of inference for your task. Let's take the example of using the [`pipeline`] for automatic speech recognition (ASR), or
speech-to-text.

1. Start by creating a [`pipeline`] and specify an inference task:

1. Start by creating a [`pipeline`] and specify the inference task:

```py
>>> from transformers import pipeline

>>> generator = pipeline(task="automatic-speech-recognition")
>>> transcriber = pipeline(task="automatic-speech-recognition")
```

2. Pass your input text to the [`pipeline`]:
2. Pass your input to the [`pipeline`]. In the case of speech recognition, this is an audio input file:

```py
>>> generator("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
>>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
{'text': 'I HAVE A DREAM BUT ONE DAY THIS NATION WILL RISE UP LIVE UP THE TRUE MEANING OF ITS TREES'}
```

Not the result you had in mind? Check out some of the [most downloaded automatic speech recognition models](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&sort=downloads) on the Hub to see if you can get a better transcription.
Let's try [openai/whisper-large](https://huggingface.co/openai/whisper-large):
Not the result you had in mind? Check out some of the [most downloaded automatic speech recognition models](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&sort=trending)
on the Hub to see if you can get a better transcription.

Let's try the [Whisper large-v2](https://huggingface.co/openai/whisper-large-v2) model from OpenAI:

```py
>>> generator = pipeline(model="openai/whisper-large")
>>> generator("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
>>> transcriber = pipeline(model="openai/whisper-large-v2")
>>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.'}
```

@@ -65,30 +71,30 @@ And if you don't find a model for your use case, you can always start [training]
If you have several inputs, you can pass your input as a list:

```py
generator(
transcriber(
[
"https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac",
"https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac",
]
)
```

If you want to iterate over a whole dataset, or want to use it for inference in a webserver, check out dedicated parts

[Using pipelines on a dataset](#using-pipelines-on-a-dataset)

[Using pipelines for a webserver](./pipeline_webserver)
If you want to iterate over a whole dataset, or want to use it for inference in a webserver, check out the dedicated parts
of the docs:
* [Using pipelines on a dataset](#using-pipelines-on-a-dataset)
* [Using pipelines for a webserver](./pipeline_webserver)

## Parameters

[`pipeline`] supports many parameters; some are task specific, and some are general to all pipelines.
In general you can specify parameters anywhere you want:
In general, you can specify parameters anywhere you want:

```py
generator = pipeline(model="openai/whisper-large", my_parameter=1)
out = generator(...) # This will use `my_parameter=1`.
out = generator(..., my_parameter=2) # This will override and use `my_parameter=2`.
out = generator(...) # This will go back to using `my_parameter=1`.
transcriber = pipeline(model="openai/whisper-large-v2", my_parameter=1)

out = transcriber(...) # This will use `my_parameter=1`.
out = transcriber(..., my_parameter=2) # This will override and use `my_parameter=2`.
out = transcriber(...) # This will go back to using `my_parameter=1`.
```

Let's check out 3 important ones:
@@ -99,14 +105,20 @@ If you use `device=n`, the pipeline automatically puts the model on the specified device.
This will work regardless of whether you are using PyTorch or TensorFlow.

```py
generator = pipeline(model="openai/whisper-large", device=0)
transcriber = pipeline(model="openai/whisper-large-v2", device=0)
```

If the model is too large for a single GPU, you can set `device_map="auto"` to allow 🤗 [Accelerate](https://huggingface.co/docs/accelerate) to automatically determine how to load and store the model weights.
Ensure Accelerate is installed before using an auto-map:

```bash
pip install --upgrade accelerate
```

The following code then automatically loads and stores model weights across devices:

```py
#!pip install accelerate
generator = pipeline(model="openai/whisper-large", device_map="auto")
transcriber = pipeline(model="openai/whisper-large-v2", device_map="auto")
```

Note that if `device_map="auto"` is passed, there is no need to add the argument `device=device` when instantiating your `pipeline`, as combining the two may lead to unexpected behavior!
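For example, a minimal sketch (reusing the model from the snippets above) of choosing one placement strategy rather than combining both:

```py
# Pick one device strategy, not both:
transcriber = pipeline(model="openai/whisper-large-v2", device_map="auto")  # let Accelerate place the weights
# transcriber = pipeline(model="openai/whisper-large-v2", device_map="auto", device=0)  # avoid: may behave unexpectedly
```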
@@ -118,12 +130,12 @@ By default, pipelines will not batch inference for reasons explained in detail
But if it works in your use case, you can use:

```py
generator = pipeline(model="openai/whisper-large", device=0, batch_size=2)
audio_filenames = [f"audio_{i}.flac" for i in range(10)]
texts = generator(audio_filenames)
transcriber = pipeline(model="openai/whisper-large-v2", device=0, batch_size=2)
audio_filenames = [f"https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/{i}.flac" for i in range(1, 5)]
texts = transcriber(audio_filenames)
```

This runs the pipeline on the 10 provided audio files, but it will pass them in batches of 2
This runs the pipeline on the 4 provided audio files, but it will pass them in batches of 2
to the model (which is on a GPU, where batching is more likely to help) without requiring any further code from you.
The output should always match what you would have received without batching. It is only meant as a way to help you get more speed out of a pipeline.

@@ -136,18 +148,23 @@ For instance, the [`transformers.AutomaticSpeechRecognitionPipeline.__call__`] method


```py
>>> # Not using whisper, as it cannot provide timestamps.
>>> generator = pipeline(model="facebook/wav2vec2-large-960h-lv60-self", return_timestamps="word")
>>> generator("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
{'text': 'I HAVE A DREAM BUT ONE DAY THIS NATION WILL RISE UP AND LIVE OUT THE TRUE MEANING OF ITS CREED', 'chunks': [{'text': 'I', 'timestamp': (1.22, 1.24)}, {'text': 'HAVE', 'timestamp': (1.42, 1.58)}, {'text': 'A', 'timestamp': (1.66, 1.68)}, {'text': 'DREAM', 'timestamp': (1.76, 2.14)}, {'text': 'BUT', 'timestamp': (3.68, 3.8)}, {'text': 'ONE', 'timestamp': (3.94, 4.06)}, {'text': 'DAY', 'timestamp': (4.16, 4.3)}, {'text': 'THIS', 'timestamp': (6.36, 6.54)}, {'text': 'NATION', 'timestamp': (6.68, 7.1)}, {'text': 'WILL', 'timestamp': (7.32, 7.56)}, {'text': 'RISE', 'timestamp': (7.8, 8.26)}, {'text': 'UP', 'timestamp': (8.38, 8.48)}, {'text': 'AND', 'timestamp': (10.08, 10.18)}, {'text': 'LIVE', 'timestamp': (10.26, 10.48)}, {'text': 'OUT', 'timestamp': (10.58, 10.7)}, {'text': 'THE', 'timestamp': (10.82, 10.9)}, {'text': 'TRUE', 'timestamp': (10.98, 11.18)}, {'text': 'MEANING', 'timestamp': (11.26, 11.58)}, {'text': 'OF', 'timestamp': (11.66, 11.7)}, {'text': 'ITS', 'timestamp': (11.76, 11.88)}, {'text': 'CREED', 'timestamp': (12.0, 12.38)}]}
>>> transcriber = pipeline(model="openai/whisper-large-v2", return_timestamps=True)
>>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.', 'chunks': [{'timestamp': (0.0, 11.88), 'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its'}, {'timestamp': (11.88, 12.38), 'text': ' creed.'}]}
```

As you can see, the model inferred the text and also outputted **when** the various words were pronounced
in the sentence.
As you can see, the model inferred the text and also outputted **when** the various sentences were pronounced.

There are many parameters available for each task, so check out each task's API reference to see what you can tinker with!
For instance, the [`~transformers.AutomaticSpeechRecognitionPipeline`] has a `chunk_length_s` parameter which is helpful for working on really long audio files (for example, subtitling entire movies or hour-long videos) that a model typically cannot handle on its own.

For instance, the [`~transformers.AutomaticSpeechRecognitionPipeline`] has a `chunk_length_s` parameter which is helpful
for working on really long audio files (for example, subtitling entire movies or hour-long videos) that a model typically
cannot handle on its own:

```python
>>> transcriber = pipeline(model="openai/whisper-large-v2", chunk_length_s=30, return_timestamps=True)
>>> transcriber("https://huggingface.co/datasets/sanchit-gandhi/librispeech_long/resolve/main/audio.wav")
{'text': " Chapter 16. I might have told you of the beginning of this liaison in a few lines, but I wanted you to see every step by which we came. I, too, agree to whatever Marguerite wished, Marguerite to be unable to live apart from me. It was the day after the evening...
```

If you can't find a parameter that would really help you out, feel free to [request it](https://github.com/huggingface/transformers/issues/new?assignees=&labels=feature&template=feature-request.yml)!

6 changes: 5 additions & 1 deletion src/transformers/pipelines/audio_utils.py
@@ -38,7 +38,11 @@ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
out_bytes = output_stream[0]
audio = np.frombuffer(out_bytes, np.float32)
if audio.shape[0] == 0:
raise ValueError("Malformed soundfile")
raise ValueError(
"Soundfile is either not in the correct format or is malformed. Ensure that the soundfile has "
"a valid audio file extension (e.g. wav, flac or mp3) and is not corrupted. If reading from a remote "
"URL, ensure that the URL is the full address to **download** the audio file."
)
return audio
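
# A minimal usage sketch (not part of this diff): decoding a remote audio file with
# `ffmpeg_read`, assuming *ffmpeg* is installed and the URL points directly at the file.
import requests

from transformers.pipelines.audio_utils import ffmpeg_read

audio_bytes = requests.get("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac").content
waveform = ffmpeg_read(audio_bytes, sampling_rate=16_000)  # raises the ValueError above if decoding fails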


5 changes: 3 additions & 2 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -303,8 +303,9 @@ def __call__(
Args:
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
The input is either:
- `str` that is the filename of the audio file, the file will be read at the correct sampling rate
to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
- `str` that is either the filename of a local audio file, or a public URL address to download the
audio file. The file will be read at the correct sampling rate to get the waveform using
*ffmpeg*. This requires *ffmpeg* to be installed on the system.
- `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
same way.
- (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
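
# A minimal usage sketch (not part of this diff) of the string input form described above,
# assuming the Whisper checkpoint used elsewhere in this commit:
from transformers import pipeline

transcriber = pipeline(model="openai/whisper-large-v2")
transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")  # public URL, downloaded and decoded with *ffmpeg*
transcriber("my_recording.flac")  # hypothetical local filename, read the same way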
