Whisper pipeline: implement 'initial_prompt' and 'hotwords' parameters (#1378)

Adds:
* `initial_prompt` parameter ([faster_whisper
reference](https://github.com/SYSTRAN/faster-whisper/blob/203dddb047fd2c3ed2a520fe1416467a527e0f37/faster_whisper/transcribe.py#L732))
- injects initial prompt tokens as a previous transcription into the
first processing window
* `hotwords` parameter ([faster_whisper
reference](https://github.com/SYSTRAN/faster-whisper/blob/203dddb047fd2c3ed2a520fe1416467a527e0f37/faster_whisper/transcribe.py#L768))
- injects hotwords tokens as a previous transcription into all processing windows
* Whisper pipeline usage notes in samples

Closes #1150
Ticket: 156888
as-suvorov authored Dec 19, 2024
1 parent 7a02d2b commit 6b92532
Showing 15 changed files with 460 additions and 14 deletions.
85 changes: 85 additions & 0 deletions samples/cpp/whisper_speech_recognition/README.md
@@ -33,6 +33,91 @@ timestamps: [0, 2] text: How are you doing today?

See [SUPPORTED_MODELS.md](../../../src/docs/SUPPORTED_MODELS.md#whisper-models) for the list of supported models.

# Whisper pipeline usage

```c++
#include "openvino/genai/whisper_pipeline.hpp"

ov::genai::WhisperPipeline pipeline(model_dir, "CPU");
// Pipeline expects normalized audio with a sample rate of 16 kHz
ov::genai::RawSpeechInput raw_speech = read_wav("how_are_you_doing_today.wav");
auto result = pipeline.generate(raw_speech);
// How are you doing today?
```
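
`read_wav` in these snippets is assumed to be the WAV-reading helper shipped with this sample, returning audio normalized and resampled to 16 kHz.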

### Transcription

Whisper pipeline predicts the language of the source audio automatically.

```c++
ov::genai::RawSpeechInput raw_speech = read_wav("how_are_you_doing_today.wav");
auto result = pipeline.generate(raw_speech);
// How are you doing today?
raw_speech = read_wav("fr_sample.wav");
result = pipeline.generate(raw_speech);
// Il s'agit d'une entité très complexe qui consiste...
```

If the source audio language is known in advance, it can be specified as an argument to the `generate` method:

```c++
ov::genai::RawSpeechInput raw_speech = read_wav("how_are_you_doing_today.wav");
auto result = pipeline.generate(raw_speech, ov::genai::language("<|en|>"));
// How are you doing today?

raw_speech = read_wav("fr_sample.wav");
result = pipeline.generate(raw_speech, ov::genai::language("<|fr|>"));
// Il s'agit d'une entité très complexe qui consiste...
```

### Translation

By default, Whisper performs the task of speech transcription, where the source audio language is the same as the target text language. To perform speech translation, where the target text is in English, set the task to "translate":

```c++
ov::genai::RawSpeechInput raw_speech = read_wav("fr_sample.wav");
auto result = pipeline.generate(raw_speech, ov::genai::task("translate"));
// It is a very complex entity that consists...
```

### Timestamps prediction

The model can predict timestamps. For sentence-level timestamps, pass the `return_timestamps` argument:

```c++
ov::genai::RawSpeechInput raw_speech = read_wav("how_are_you_doing_today.wav");
auto result = pipeline.generate(raw_speech, ov::genai::return_timestamps(true));

std::cout << std::setprecision(2);
for (auto& chunk : *result.chunks) {
    std::cout << "timestamps: [" << chunk.start_ts << ", " << chunk.end_ts << "] text: " << chunk.text << "\n";
}
// timestamps: [0, 2] text: How are you doing today?
```

### Long-form audio transcription

The Whisper model is designed to work on audio samples of up to 30 seconds in duration. The Whisper pipeline transcribes audio of arbitrary length with a sequential chunking algorithm, which uses a "sliding window" to transcribe 30-second slices one after the other.
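
No extra arguments are needed for longer inputs; a minimal sketch (assuming a longer recording in a hypothetical `long_audio.wav`):

```c++
ov::genai::RawSpeechInput raw_speech = read_wav("long_audio.wav");
// The pipeline splits the input into 30-second windows and transcribes them sequentially
auto result = pipeline.generate(raw_speech, ov::genai::return_timestamps(true));
```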

### Initial prompt and hotwords

The Whisper pipeline accepts `initial_prompt` and `hotwords` generate arguments:
* `initial_prompt`: initial prompt tokens passed as a previous transcription (after the `<|startofprev|>` token) to the first processing window
* `hotwords`: hotwords tokens passed as a previous transcription (after the `<|startofprev|>` token) to all processing windows
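
With either argument set, the decoder context for a processing window conceptually begins as follows (an illustrative sketch of the token layout, not exact model input):

```
<|startofprev|> [initial_prompt and/or hotwords tokens] <|startoftranscript|><|en|><|transcribe|> ...
```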

The Whisper model can use that context to better understand the speech and maintain a consistent writing style. However, prompts do not need to be genuine transcripts from prior audio segments. Such prompts can be used to steer the model to use particular spellings or styles:

```c++
auto result = pipeline.generate(raw_speech);
// He has gone and gone for good answered Paul Icrom who...

result = pipeline.generate(raw_speech, ov::genai::initial_prompt("Polychrome"));
// He has gone and gone for good answered Polychrome who...
```
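
`hotwords` can be passed the same way; unlike `initial_prompt`, it is injected into every processing window (same assumptions as above):

```c++
result = pipeline.generate(raw_speech, ov::genai::hotwords("Polychrome"));
// He has gone and gone for good answered Polychrome who...
```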


### Troubleshooting

#### Empty or rubbish output
@@ -28,6 +28,7 @@ int main(int argc, char* argv[]) try {

    std::cout << result << "\n";

    std::cout << std::setprecision(2);
    for (auto& chunk : *result.chunks) {
        std::cout << "timestamps: [" << chunk.start_ts << ", " << chunk.end_ts << "] text: " << chunk.text << "\n";
    }
87 changes: 87 additions & 0 deletions samples/python/whisper_speech_recognition/README.md
@@ -40,6 +40,93 @@ timestamps: [0, 2] text: How are you doing today?

See [SUPPORTED_MODELS.md](../../../src/docs/SUPPORTED_MODELS.md#whisper-models) for the list of supported models.

# Whisper pipeline usage

```python
import openvino_genai
import librosa

def read_wav(filepath):
    raw_speech, samplerate = librosa.load(filepath, sr=16000)
    return raw_speech.tolist()

pipe = openvino_genai.WhisperPipeline(model_dir, "CPU")
# Pipeline expects normalized audio with a sample rate of 16 kHz
raw_speech = read_wav('how_are_you_doing_today.wav')
result = pipe.generate(raw_speech)
# How are you doing today?
```

### Transcription

Whisper pipeline predicts the language of the source audio automatically.

```python
raw_speech = read_wav('how_are_you_doing_today.wav')
result = pipe.generate(raw_speech)
# How are you doing today?

raw_speech = read_wav('fr_sample.wav')
result = pipe.generate(raw_speech)
# Il s'agit d'une entité très complexe qui consiste...
```

If the source audio language is known in advance, it can be specified as an argument to the `generate` method:

```python
raw_speech = read_wav("how_are_you_doing_today.wav")
result = pipe.generate(raw_speech, language="<|en|>")
# How are you doing today?

raw_speech = read_wav("fr_sample.wav")
result = pipe.generate(raw_speech, language="<|fr|>")
# Il s'agit d'une entité très complexe qui consiste...
```

### Translation

By default, Whisper performs the task of speech transcription, where the source audio language is the same as the target text language. To perform speech translation, where the target text is in English, set the task to "translate":

```python
raw_speech = read_wav("fr_sample.wav")
result = pipe.generate(raw_speech, task="translate")
# It is a very complex entity that consists...
```

### Timestamps prediction

The model can predict timestamps. For sentence-level timestamps, pass the `return_timestamps` argument:

```python
raw_speech = read_wav("how_are_you_doing_today.wav")
result = pipe.generate(raw_speech, return_timestamps=True)

for chunk in result.chunks:
    print(f"timestamps: [{chunk.start_ts:.2f}, {chunk.end_ts:.2f}] text: {chunk.text}")
# timestamps: [0.00, 2.00] text: How are you doing today?
```

### Long-form audio transcription

The Whisper model is designed to work on audio samples of up to 30 seconds in duration. The Whisper pipeline transcribes audio of arbitrary length with a sequential chunking algorithm, which uses a "sliding window" to transcribe 30-second slices one after the other.
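
No extra arguments are needed for longer inputs; a minimal sketch (assuming a longer recording in a hypothetical `long_audio.wav`):

```python
raw_speech = read_wav("long_audio.wav")
# The pipeline splits the input into 30-second windows and transcribes them sequentially
result = pipe.generate(raw_speech, return_timestamps=True)
```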

### Initial prompt and hotwords

The Whisper pipeline accepts `initial_prompt` and `hotwords` generate arguments:
* `initial_prompt`: initial prompt tokens passed as a previous transcription (after the `<|startofprev|>` token) to the first processing window
* `hotwords`: hotwords tokens passed as a previous transcription (after the `<|startofprev|>` token) to all processing windows
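
With either argument set, the decoder context for a processing window conceptually begins as follows (an illustrative sketch of the token layout, not exact model input):

```
<|startofprev|> [initial_prompt and/or hotwords tokens] <|startoftranscript|><|en|><|transcribe|> ...
```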

The Whisper model can use that context to better understand the speech and maintain a consistent writing style. However, prompts do not need to be genuine transcripts from prior audio segments. Such prompts can be used to steer the model to use particular spellings or styles:

```python
result = pipe.generate(raw_speech)
# He has gone and gone for good answered Paul Icrom who...

result = pipe.generate(raw_speech, initial_prompt="Polychrome")
# He has gone and gone for good answered Polychrome who...
```
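
`hotwords` can be passed the same way; unlike `initial_prompt`, it is injected into every processing window (same assumptions as above):

```python
result = pipe.generate(raw_speech, hotwords="Polychrome")
# He has gone and gone for good answered Polychrome who...
```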

### Troubleshooting

#### Empty or rubbish output
@@ -18,7 +18,7 @@ def main():
    parser.add_argument("wav_file_path")
    args = parser.parse_args()

-    device = "CPU"  # GPU can be used as well
+    device = "CPU"  # GPU, NPU can be used as well
    pipe = openvino_genai.WhisperPipeline(args.model_dir, device)

    config = pipe.get_generation_config()
@@ -34,8 +34,9 @@ def main():

    print(result)

-    for chunk in result.chunks:
-        print(f"timestamps: [{chunk.start_ts}, {chunk.end_ts}] text: {chunk.text}")
+    if result.chunks:
+        for chunk in result.chunks:
+            print(f"timestamps: [{chunk.start_ts:.2f}, {chunk.end_ts:.2f}] text: {chunk.text}")


if "__main__" == __name__:
34 changes: 33 additions & 1 deletion src/cpp/include/openvino/genai/whisper_generation_config.hpp
@@ -3,8 +3,8 @@

#pragma once

#include <filesystem>
#include <optional>

#include "openvino/genai/tokenizer.hpp"
#include "openvino/runtime/compiled_model.hpp"
@@ -46,6 +46,9 @@ class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig {
// Transcribe token id.
int64_t transcribe_token_id = 50359;

// Corresponds to the "<|startofprev|>" token.
int64_t prev_sot_token_id = 50361;

// No timestamps token id.
int64_t no_timestamps_token_id = 50363;

@@ -75,6 +78,32 @@ class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig {
// Note that a segment of text refers to a sequence of one or more words, rather than individual words.
bool return_timestamps = false;

/*
 * Initial prompt tokens passed as a previous transcription (after the `<|startofprev|>` token) to the first
 * processing window. Can be used to steer the model to use particular spellings or styles.
 *
 * Example:
 *   auto result = pipeline.generate(raw_speech);
 *   // He has gone and gone for good answered Paul Icrom who...
 *
 *   auto result = pipeline.generate(raw_speech, ov::genai::initial_prompt("Polychrome"));
 *   // He has gone and gone for good answered Polychrome who...
 */
std::optional<std::string> initial_prompt = std::nullopt;

/*
 * Hotwords tokens passed as a previous transcription (after the `<|startofprev|>` token) to all processing windows.
 * Can be used to steer the model to use particular spellings or styles.
 *
 * Example:
 *   auto result = pipeline.generate(raw_speech);
 *   // He has gone and gone for good answered Paul Icrom who...
 *
 *   auto result = pipeline.generate(raw_speech, ov::genai::hotwords("Polychrome"));
 *   // He has gone and gone for good answered Polychrome who...
 */
std::optional<std::string> hotwords = std::nullopt;

// A list containing tokens that will be suppressed at the beginning of the sampling process.
std::vector<int64_t> begin_suppress_tokens;

@@ -111,9 +140,12 @@ static constexpr ov::Property<int64_t> pad_token_id{"pad_token_id"};
static constexpr ov::Property<int64_t> transcribe_token_id{"transcribe_token_id"};
static constexpr ov::Property<int64_t> translate_token_id{"translate_token_id"};
static constexpr ov::Property<int64_t> no_timestamps_token_id{"no_timestamps_token_id"};
static constexpr ov::Property<int64_t> prev_sot_token_id{"prev_sot_token_id"};
static constexpr ov::Property<std::string> language{"language"};
static constexpr ov::Property<std::string> task{"task"};
static constexpr ov::Property<bool> return_timestamps{"return_timestamps"};
static constexpr ov::Property<std::string> initial_prompt{"initial_prompt"};
static constexpr ov::Property<std::string> hotwords{"hotwords"};
static constexpr ov::Property<std::map<std::string, int64_t>> lang_to_id{"lang_to_id"};

} // namespace genai
89 changes: 89 additions & 0 deletions src/cpp/src/whisper/context_tokens.cpp
@@ -0,0 +1,89 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "context_tokens.hpp"

namespace {
std::pair<std::vector<int64_t>, float> tokenize(std::string&& text,
                                                const ov::genai::WhisperGenerationConfig& config,
                                                ov::genai::Tokenizer& tokenizer) {
    if (text.empty()) {
        return {{}, 0.0f};
    }

    auto start_time = std::chrono::steady_clock::now();
    auto encoded = tokenizer.encode(text, ov::genai::add_special_tokens(false));
    auto duration = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - start_time);

    auto input_ids = encoded.input_ids;
    auto input_ids_data = input_ids.data<int64_t>();

    std::vector<int64_t> prompt_tokens;
    prompt_tokens.reserve(input_ids.get_size());

    // Even with ov::genai::add_special_tokens(false), the tokenizer still adds the following special tokens. Ticket: 159569
    std::set<int64_t> special_tokens{config.decoder_start_token_id, config.eos_token_id, config.no_timestamps_token_id};

    for (size_t i = 0; i < input_ids.get_size(); i++) {
        if (special_tokens.count(input_ids_data[i])) {
            continue;
        }

        prompt_tokens.emplace_back(input_ids_data[i]);
    }

    return {prompt_tokens, duration};
}
} // namespace

namespace ov {
namespace genai {

std::pair<WhisperContextTokens, float> prepare_context_tokens(const WhisperGenerationConfig& config,
                                                              Tokenizer& tokenizer) {
    WhisperContextTokens context_tokens;
    float duration = 0.0f;

    if (config.initial_prompt.has_value()) {
        auto [initial_prompt_tokens, initial_prompt_duration] =
            tokenize(" " + *config.initial_prompt, config, tokenizer);
        context_tokens.initial_prompt = std::move(initial_prompt_tokens);
        duration += initial_prompt_duration;
    }

    if (config.hotwords.has_value()) {
        auto [hotwords_tokens, hotwords_duration] = tokenize(" " + *config.hotwords, config, tokenizer);
        context_tokens.hotwords = std::move(hotwords_tokens);
        duration += hotwords_duration;
    }

    return {context_tokens, duration};
}

std::vector<int64_t> get_prompt_tokens(const WhisperContextTokens& context_tokens,
                                       const WhisperGenerationConfig& config,
                                       size_t chunk_offset) {
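    // initial_prompt is injected only into the first window (chunk_offset == 0); hotwords into every window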
    bool should_add_initial_prompt = !context_tokens.initial_prompt.empty() && chunk_offset == 0;
    bool should_add_hotwords = !context_tokens.hotwords.empty();

    if (!should_add_initial_prompt && !should_add_hotwords) {
        return {};
    }

    std::vector<int64_t> prompt_tokens{config.prev_sot_token_id};

    if (should_add_initial_prompt) {
        prompt_tokens.insert(prompt_tokens.end(),
                             context_tokens.initial_prompt.begin(),
                             context_tokens.initial_prompt.end());
    }

    if (should_add_hotwords) {
        prompt_tokens.insert(prompt_tokens.end(), context_tokens.hotwords.begin(), context_tokens.hotwords.end());
    }

    return prompt_tokens;
}

} // namespace genai
} // namespace ov
25 changes: 25 additions & 0 deletions src/cpp/src/whisper/context_tokens.hpp
@@ -0,0 +1,25 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "openvino/genai/perf_metrics.hpp"
#include "openvino/genai/whisper_generation_config.hpp"

namespace ov {
namespace genai {

struct WhisperContextTokens {
std::vector<int64_t> initial_prompt;
std::vector<int64_t> hotwords;
};

std::pair<WhisperContextTokens, float> prepare_context_tokens(const WhisperGenerationConfig& config,
                                                              Tokenizer& tokenizer);

std::vector<int64_t> get_prompt_tokens(const WhisperContextTokens& context_tokens,
                                       const WhisperGenerationConfig& config,
                                       size_t chunk_offset);

} // namespace genai
} // namespace ov
