-
Notifications
You must be signed in to change notification settings - Fork 198
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Whisper pipeline: implement 'initial_prompt' and 'hotwords' parameters (
#1378) Adds: * `initial_prompt` parameter ([faster_whisper reference](https://github.com/SYSTRAN/faster-whisper/blob/203dddb047fd2c3ed2a520fe1416467a527e0f37/faster_whisper/transcribe.py#L732)) - injects initial prompt tokens as a previous transcription into the first processing window * `hotwords` parameter ([faster_whisper reference](https://github.com/SYSTRAN/faster-whisper/blob/203dddb047fd2c3ed2a520fe1416467a527e0f37/faster_whisper/transcribe.py#L768)) - injects hotwords tokens as a previous transcription into all processing windows * Whisper pipeline usage notes in samples Closes #1150 Ticket: 156888
- Loading branch information
1 parent
7a02d2b
commit 6b92532
Showing
15 changed files
with
460 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// Copyright (C) 2023-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "context_tokens.hpp" | ||
|
||
namespace { | ||
std::pair<std::vector<int64_t>, float> tokenize(std::string&& text, | ||
const ov::genai::WhisperGenerationConfig& config, | ||
ov::genai::Tokenizer& tokenizer) { | ||
if (text.empty()) { | ||
return {{}, 0.0f}; | ||
} | ||
|
||
auto start_time = std::chrono::steady_clock::now(); | ||
auto encoded = tokenizer.encode(text, ov::genai::add_special_tokens(false)); | ||
auto duration = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - start_time); | ||
|
||
auto input_ids = encoded.input_ids; | ||
auto input_ids_data = input_ids.data<int64_t>(); | ||
|
||
std::vector<int64_t> prompt_tokens; | ||
prompt_tokens.reserve(input_ids.get_size()); | ||
|
||
// even with ov::genai::add_special_tokens(false) tokenizer adds next special tokens. Ticket: 159569 | ||
std::set<int64_t> special_tokens{config.decoder_start_token_id, config.eos_token_id, config.no_timestamps_token_id}; | ||
|
||
for (size_t i = 0; i < input_ids.get_size(); i++) { | ||
if (special_tokens.count(input_ids_data[i])) { | ||
continue; | ||
} | ||
|
||
prompt_tokens.emplace_back(input_ids_data[i]); | ||
} | ||
|
||
return {prompt_tokens, duration}; | ||
} | ||
} // namespace | ||
|
||
namespace ov { | ||
namespace genai { | ||
|
||
std::pair<WhisperContextTokens, float> prepare_context_tokens(const WhisperGenerationConfig& config, | ||
Tokenizer& tokenizer) { | ||
WhisperContextTokens context_tokens; | ||
float duration = 0.0f; | ||
|
||
if (config.initial_prompt.has_value()) { | ||
auto [initial_prompt_tokens, initial_prompt_duration] = | ||
tokenize(" " + *config.initial_prompt, config, tokenizer); | ||
context_tokens.initial_prompt = std::move(initial_prompt_tokens); | ||
duration += initial_prompt_duration; | ||
} | ||
|
||
if (config.hotwords.has_value()) { | ||
auto [hotwords_tokens, hotwords_duration] = tokenize(" " + *config.hotwords, config, tokenizer); | ||
context_tokens.hotwords = std::move(hotwords_tokens); | ||
duration += hotwords_duration; | ||
} | ||
|
||
return {context_tokens, duration}; | ||
} | ||
|
||
std::vector<int64_t> get_prompt_tokens(const WhisperContextTokens& context_tokens, | ||
const WhisperGenerationConfig& config, | ||
size_t chunk_offset) { | ||
bool should_add_initial_prompt = !context_tokens.initial_prompt.empty() && chunk_offset == 0; | ||
bool should_add_hotwords = !context_tokens.hotwords.empty(); | ||
|
||
if (!should_add_initial_prompt && !should_add_hotwords) { | ||
return {}; | ||
} | ||
|
||
std::vector<int64_t> prompt_tokens{config.prev_sot_token_id}; | ||
|
||
if (should_add_initial_prompt) { | ||
prompt_tokens.insert(prompt_tokens.end(), | ||
context_tokens.initial_prompt.begin(), | ||
context_tokens.initial_prompt.end()); | ||
} | ||
|
||
if (should_add_hotwords) { | ||
prompt_tokens.insert(prompt_tokens.end(), context_tokens.hotwords.begin(), context_tokens.hotwords.end()); | ||
} | ||
|
||
return prompt_tokens; | ||
} | ||
|
||
} // namespace genai | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Copyright (C) 2023-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "openvino/genai/perf_metrics.hpp" | ||
#include "openvino/genai/whisper_generation_config.hpp" | ||
|
||
namespace ov { | ||
namespace genai { | ||
|
||
// Token id sequences used as context when building Whisper prompt tokens.
struct WhisperContextTokens {
    // Token ids for the initial prompt; empty when there are none.
    std::vector<int64_t> initial_prompt{};
    // Token ids for the hotwords; empty when there are none.
    std::vector<int64_t> hotwords{};
};
|
||
// Tokenizes the `initial_prompt` and `hotwords` values of `config` (when set) using `tokenizer`.
// Returns the resulting context tokens paired with the accumulated tokenization duration.
std::pair<WhisperContextTokens, float> prepare_context_tokens(const WhisperGenerationConfig& config, | ||
Tokenizer& tokenizer); | ||
|
||
// Builds the prompt tokens for the window at `chunk_offset`: config.prev_sot_token_id followed by
// the initial prompt tokens (only when chunk_offset == 0) and the hotwords tokens.
// Returns an empty vector when neither set of tokens applies.
std::vector<int64_t> get_prompt_tokens(const WhisperContextTokens& context_tokens, | ||
const WhisperGenerationConfig& config, | ||
size_t chunk_offset); | ||
|
||
} // namespace genai | ||
} // namespace ov |
Oops, something went wrong.