diff --git a/.github/workflows/llm_bench-python.yml b/.github/workflows/llm_bench-python.yml index 1999bafcfe..8356805e19 100644 --- a/.github/workflows/llm_bench-python.yml +++ b/.github/workflows/llm_bench-python.yml @@ -151,7 +151,7 @@ jobs: rm -rf ./ov_models/internvl2-1B - name: WWB Tests run: | - pip install git+https://github.com/huggingface/optimum-intel.git + pip install git+https://github.com/huggingface/optimum-intel.git@420fa87d039425a906b7f755e4562b65947f016a GIT_CLONE_PROTECTION_ACTIVE=false PIP_PRE=1 PIP_EXTRA_INDEX_URL=https://storage.openvinotoolkit.org/simple/wheels/nightly pip install ${{ env.WWB_PATH }} python -m pytest -v ${{ env.WWB_PATH }}/tests stateful: @@ -190,7 +190,7 @@ jobs: - name: WWB Tests run: | pip install pytest - pip install git+https://github.com/huggingface/optimum-intel.git + pip install git+https://github.com/huggingface/optimum-intel.git@420fa87d039425a906b7f755e4562b65947f016a GIT_CLONE_PROTECTION_ACTIVE=false PIP_PRE=1 PIP_EXTRA_INDEX_URL=https://storage.openvinotoolkit.org/simple/wheels/nightly pip install ${{ env.WWB_PATH }} python -m pytest -v ${{ env.WWB_PATH }}/tests diff --git a/README.md b/README.md index c2509528c3..be3de5e8ce 100644 --- a/README.md +++ b/README.md @@ -331,10 +331,14 @@ For more examples check out our [Generative AI workflow](https://docs.openvino.a NOTE: Whisper Pipeline requires preprocessing of audio input (to adjust sampling rate and normalize) - ### Converting and compressing image generation model from Hugging Face library + ### Converting and quantizing speech-to-text model from Hugging Face library ```sh #Download and convert to OpenVINO whisper-base model optimum-cli export openvino --trust-remote-code --model openai/whisper-base whisper-base + +#Download, convert and apply int8 static quantization to whisper-base model +optimum-cli export openvino --trust-remote-code --model openai/whisper-base \ +--quant-mode int8 --dataset librispeech --num-samples 32 whisper-base-int8 ``` ### Run generation using Whisper Pipeline API in Python diff --git a/samples/cpp/whisper_speech_recognition/README.md b/samples/cpp/whisper_speech_recognition/README.md index 773135b648..d649266613 100644 --- a/samples/cpp/whisper_speech_recognition/README.md +++ b/samples/cpp/whisper_speech_recognition/README.md @@ -33,6 +33,91 @@ timestamps: [0, 2] text: How are you doing today? See [SUPPORTED_MODELS.md](../../../src/docs/SUPPORTED_MODELS.md#whisper-models) for the list of supported models. +# Whisper pipeline usage + +```c++ +#include "openvino/genai/whisper_pipeline.hpp" + +ov::genai::WhisperPipeline pipeline(model_dir, "CPU"); +// Pipeline expects normalized audio with Sample Rate of 16kHz +ov::genai::RawSpeechInput raw_speech = read_wav("how_are_you_doing_today.wav"); +auto result = pipeline.generate(raw_speech); +// How are you doing today? +``` + +### Transcription + +Whisper pipeline predicts the language of the source audio automatically. + +```c++ +ov::genai::RawSpeechInput raw_speech = read_wav("how_are_you_doing_today.wav"); +auto result = pipeline.generate(raw_speech); +// How are you doing today? + +raw_speech = read_wav("fr_sample.wav"); +result = pipeline.generate(raw_speech); +// Il s'agit d'une entité très complexe qui consiste... +``` + +If the source audio languange is know in advance, it can be specified as an argument to `generate` method: + +```c++ +ov::genai::RawSpeechInput raw_speech = read_wav("how_are_you_doing_today.wav"); +auto result = pipeline.generate(raw_speech, ov::genai::language("<|en|>")); +// How are you doing today? + +raw_speech = read_wav("fr_sample.wav"); +result = pipeline.generate(raw_speech, ov::genai::language("<|fr|>")); +// Il s'agit d'une entité très complexe qui consiste... +``` + +### Translation + +By default, Whisper performs the task of speech transcription, where the source audio language is the same as the target text language. To perform speech translation, where the target text is in English, set the task to "translate": + +```c++ +ov::genai::RawSpeechInput raw_speech = read_wav("fr_sample.wav"); +auto result = pipeline.generate(raw_speech, ov::genai::task("translate")); +// It is a very complex entity that consists... +``` + +### Timestamps prediction + +The model can predict timestamps. For sentence-level timestamps, pass the `return_timestamps` argument: + +```C++ +ov::genai::RawSpeechInput raw_speech = read_wav("how_are_you_doing_today.wav"); +auto result = pipeline.generate(raw_speech, ov::genai::return_timestamps(true)); + +std::cout << std::setprecision(2); +for (auto& chunk : *result.chunks) { + std::cout << "timestamps: [" << chunk.start_ts << ", " << chunk.end_ts << "] text: " << chunk.text << "\n"; +} +// timestamps: [0, 2] text: How are you doing today? +``` + +### Long-Form audio Transcription + +The Whisper model is designed to work on audio samples of up to 30s in duration. Whisper pipeline uses sequential chunking algorithm to transcribe audio samples of arbitrary length. +Sequential chunking algorithm uses a "sliding window", transcribing 30-second slices one after the other. + +### Initial prompt and hotwords + +Whisper pipeline has `initial_prompt` and `hotwords` generate arguments: +* `initial_prompt`: initial prompt tokens passed as a previous transcription (after `<|startofprev|>` token) to the first processing window +* `hotwords`: hotwords tokens passed as a previous transcription (after `<|startofprev|>` token) to the all processing windows + +The Whisper model can use that context to better understand the speech and maintain a consistent writing style. However, prompts do not need to be genuine transcripts from prior audio segments. Such prompts can be used to steer the model to use particular spellings or styles: + +```c++ +auto result = pipeline.generate(raw_speech); +// He has gone and gone for good answered Paul Icrom who... + +result = pipeline.generate(raw_speech, ov::genai::initial_prompt("Polychrome")); +// He has gone and gone for good answered Polychrome who... +``` + + ### Troubleshooting #### Empty or rubbish output diff --git a/samples/cpp/whisper_speech_recognition/whisper_speech_recognition.cpp b/samples/cpp/whisper_speech_recognition/whisper_speech_recognition.cpp index 31d3f8c551..3df17a77f5 100644 --- a/samples/cpp/whisper_speech_recognition/whisper_speech_recognition.cpp +++ b/samples/cpp/whisper_speech_recognition/whisper_speech_recognition.cpp @@ -28,6 +28,7 @@ int main(int argc, char* argv[]) try { std::cout << result << "\n"; + std::cout << std::setprecision(2); for (auto& chunk : *result.chunks) { std::cout << "timestamps: [" << chunk.start_ts << ", " << chunk.end_ts << "] text: " << chunk.text << "\n"; } diff --git a/samples/export-requirements.txt b/samples/export-requirements.txt index 797b680b9a..d75fdbacee 100644 --- a/samples/export-requirements.txt +++ b/samples/export-requirements.txt @@ -2,7 +2,7 @@ --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly openvino-tokenizers~=2025.0.0.0.dev -optimum-intel @ git+https://github.com/huggingface/optimum-intel.git +optimum-intel @ git+https://github.com/huggingface/optimum-intel.git@420fa87d039425a906b7f755e4562b65947f016a numpy<2.0.0; sys_platform == 'darwin' einops==0.8.0 # For Qwen transformers_stream_generator==0.0.5 # For Qwen diff --git a/samples/python/whisper_speech_recognition/README.md b/samples/python/whisper_speech_recognition/README.md index 158bd18311..aeb46444bf 100644 --- a/samples/python/whisper_speech_recognition/README.md +++ b/samples/python/whisper_speech_recognition/README.md @@ -40,6 +40,93 @@ timestamps: [0, 2] text: How are you doing today? See [SUPPORTED_MODELS.md](../../../src/docs/SUPPORTED_MODELS.md#whisper-models) for the list of supported models. +# Whisper pipeline usage + +```python +import openvino_genai +import librosa + +def read_wav(filepath): + raw_speech, samplerate = librosa.load(filepath, sr=16000) + return raw_speech.tolist() + +pipe = openvino_genai.WhisperPipeline(model_dir, "CPU") +# Pipeline expects normalized audio with Sample Rate of 16kHz +raw_speech = read_wav('how_are_you_doing_today.wav') +result = pipe.generate(raw_speech) +# How are you doing today? +``` + +### Transcription + +Whisper pipeline predicts the language of the source audio automatically. + +```python +raw_speech = read_wav('how_are_you_doing_today.wav') +result = pipe.generate(raw_speech) +# How are you doing today? + +raw_speech = read_wav('fr_sample.wav') +result = pipe.generate(raw_speech) +# Il s'agit d'une entité très complexe qui consiste... +``` + +If the source audio languange is know in advance, it can be specified as an argument to `generate` method: + +```python +raw_speech = read_wav("how_are_you_doing_today.wav") +result = pipe.generate(raw_speech, language="<|en|>") +# How are you doing today? + +raw_speech = read_wav("fr_sample.wav") +result = pipe.generate(raw_speech, language="<|fr|>") +# Il s'agit d'une entité très complexe qui consiste... +``` + +### Translation + +By default, Whisper performs the task of speech transcription, where the source audio language is the same as the target text language. To perform speech translation, where the target text is in English, set the task to "translate": + +```python +raw_speech = read_wav("fr_sample.wav") +result = pipe.generate(raw_speech, task="translate") +# It is a very complex entity that consists... +``` + +### Timestamps prediction + +The model can predict timestamps. For sentence-level timestamps, pass the `return_timestamps` argument: + +```python +raw_speech = read_wav("how_are_you_doing_today.wav") +result = pipe.generate(raw_speech, return_timestamps=True) + +for chunk in result.chunks: + print(f"timestamps: [{chunk.start_ts:.2f}, {chunk.end_ts:.2f}] text: {chunk.text}") +# timestamps: [0.00, 2.00] text: How are you doing today? +``` + +### Long-Form audio Transcription + +The Whisper model is designed to work on audio samples of up to 30s in duration. Whisper pipeline uses sequential chunking algorithm to transcribe audio samples of arbitrary length. +Sequential chunking algorithm uses a "sliding window", transcribing 30-second slices one after the other. + +### Initial prompt and hotwords + +Whisper pipeline has `initial_prompt` and `hotwords` generate arguments: +* `initial_prompt`: initial prompt tokens passed as a previous transcription (after `<|startofprev|>` token) to the first processing window +* `hotwords`: hotwords tokens passed as a previous transcription (after `<|startofprev|>` token) to the all processing windows + +The Whisper model can use that context to better understand the speech and maintain a consistent writing style. However, prompts do not need to be genuine transcripts from prior audio segments. Such prompts can be used to steer the model to use particular spellings or styles: + +```python +result = pipe.generate(raw_speech) +# He has gone and gone for good answered Paul Icrom who... + +result = pipe.generate(raw_speech, initial_prompt="Polychrome") +# He has gone and gone for good answered Polychrome who... +``` + ### Troubleshooting #### Empty or rubbish output diff --git a/samples/python/whisper_speech_recognition/whisper_speech_recognition.py b/samples/python/whisper_speech_recognition/whisper_speech_recognition.py index 3fddfc8ffa..9cf3be5fa1 100755 --- a/samples/python/whisper_speech_recognition/whisper_speech_recognition.py +++ b/samples/python/whisper_speech_recognition/whisper_speech_recognition.py @@ -18,7 +18,7 @@ def main(): parser.add_argument("wav_file_path") args = parser.parse_args() - device = "CPU" # GPU can be used as well + device = "CPU" # GPU, NPU can be used as well pipe = openvino_genai.WhisperPipeline(args.model_dir, device) config = pipe.get_generation_config() @@ -34,8 +34,9 @@ def main(): print(result) - for chunk in result.chunks: - print(f"timestamps: [{chunk.start_ts}, {chunk.end_ts}] text: {chunk.text}") + if result.chunks: + for chunk in result.chunks: + print(f"timestamps: [{chunk.start_ts:.2f}, {chunk.end_ts:.2f}] text: {chunk.text}") if "__main__" == __name__: diff --git a/src/cpp/include/openvino/genai/image_generation/scheduler.hpp b/src/cpp/include/openvino/genai/image_generation/scheduler.hpp index 21c266aa50..25c5e07a2f 100644 --- a/src/cpp/include/openvino/genai/image_generation/scheduler.hpp +++ b/src/cpp/include/openvino/genai/image_generation/scheduler.hpp @@ -19,7 +19,8 @@ class OPENVINO_GENAI_EXPORTS Scheduler { DDIM, EULER_DISCRETE, FLOW_MATCH_EULER_DISCRETE, - PNDM + PNDM, + EULER_ANCESTRAL_DISCRETE }; static std::shared_ptr from_config(const std::filesystem::path& scheduler_config_path, diff --git a/src/cpp/include/openvino/genai/whisper_generation_config.hpp b/src/cpp/include/openvino/genai/whisper_generation_config.hpp index 37b23cde74..44d611923d 100644 --- a/src/cpp/include/openvino/genai/whisper_generation_config.hpp +++ b/src/cpp/include/openvino/genai/whisper_generation_config.hpp @@ -3,8 +3,8 @@ #pragma once -#include #include +#include #include "openvino/genai/tokenizer.hpp" #include "openvino/runtime/compiled_model.hpp" @@ -46,6 +46,9 @@ class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig { // Transcribe token id. int64_t transcribe_token_id = 50359; + // Corresponds to the ”<|startofprev|>” token. + int64_t prev_sot_token_id = 50361; + // No timestamps token id. int64_t no_timestamps_token_id = 50363; @@ -75,6 +78,32 @@ class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig { // Note that a segment of text refers to a sequence of one or more words, rather than individual words. bool return_timestamps = false; + /* + * Initial prompt tokens passed as a previous transcription (after `<|startofprev|>` token) to the first processing + * window. Can be used to steer the model to use particular spellings or styles. + * + * Example: + * auto result = pipeline.generate(raw_speech); + * // He has gone and gone for good answered Paul Icrom who... + * + * auto result = pipeline.generate(raw_speech, ov::genai::initial_prompt("Polychrome")); + * // He has gone and gone for good answered Polychrome who... + */ + std::optional initial_prompt = std::nullopt; + + /* + * Hotwords tokens passed as a previous transcription (after `<|startofprev|>` token) to the all processing windows. + * Can be used to steer the model to use particular spellings or styles. + * + * Example: + * auto result = pipeline.generate(raw_speech); + * // He has gone and gone for good answered Paul Icrom who... + * + * auto result = pipeline.generate(raw_speech, ov::genai::hotwords("Polychrome")); + * // He has gone and gone for good answered Polychrome who... + */ + std::optional hotwords = std::nullopt; + // A list containing tokens that will be suppressed at the beginning of the sampling process. std::vector begin_suppress_tokens; @@ -111,9 +140,12 @@ static constexpr ov::Property pad_token_id{"pad_token_id"}; static constexpr ov::Property transcribe_token_id{"transcribe_token_id"}; static constexpr ov::Property translate_token_id{"translate_token_id"}; static constexpr ov::Property no_timestamps_token_id{"no_timestamps_token_id"}; +static constexpr ov::Property prev_sot_token_id{"prev_sot_token_id"}; static constexpr ov::Property language{"language"}; static constexpr ov::Property task{"task"}; static constexpr ov::Property return_timestamps{"return_timestamps"}; +static constexpr ov::Property initial_prompt{"initial_prompt"}; +static constexpr ov::Property hotwords{"hotwords"}; static constexpr ov::Property> lang_to_id{"lang_to_id"}; } // namespace genai diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index 6e7e982a4c..e1ffd062de 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -22,7 +22,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl( m_tokenizer = tokenizer; m_generation_config = generation_config; m_is_validation_mode_enabled = is_validation_mode_enabled; - + ov::Core core; auto [core_properties, compile_properties] = utils::split_core_compile_config(properties); @@ -255,18 +255,6 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector generations; - for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { - OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); - generations.push_back(add_request(request_id, input_ids[request_id], sampling_params[request_id])); - } - - std::vector results; - results.reserve(m_awaiting_requests.size()); - auto drop_requests = [&] () { for (const std::shared_ptr request : m_requests) { for (const auto& sequence: request->get_sequences()) { @@ -279,25 +267,40 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector generations; + for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { + OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); + generations.push_back(add_request(request_id, input_ids[request_id], sampling_params[request_id])); + } + auto all_requests = m_awaiting_requests; // we need to store all requests to get results from them once generation has finished + bool continue_generation = true; while (has_non_finished_requests() && continue_generation) { try { step(); } catch (...) { - drop_requests(); + drop_requests(); // remove all requests from pipeline state in case of exception throw; } - if (streamer_ptr && generations.at(0)->can_read()) { - std::unordered_map token = generations.at(0).get()->back(); + + auto & generation = generations.at(0); + if (streamer_ptr && generation->can_read()) { + std::unordered_map token = generation->back(); for (const auto& gen_token : token.begin()->second.generated_ids) { - if (!streamer_ptr->put(gen_token)) { + continue_generation = !streamer_ptr->put(gen_token); + if (!continue_generation) { + generation->drop(); break; } } } } - if (streamer_ptr) { + if (streamer_ptr) { // push streamer's cache streamer_ptr->end(); } @@ -307,16 +310,32 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector results; + results.reserve(all_requests.size()); + + for (size_t request_id = 0; request_id < all_requests.size(); ++request_id) { + const auto& request = all_requests[request_id]; + auto sampling_params = request->get_sampling_parameters(); + const auto& sequences = request->get_finished_sequences(); + size_t num_outputs = std::min(sampling_params.num_return_sequences, sequences.size()); + EncodedGenerationResult result; - result.m_request_id = 1; - std::vector generation_outputs = generation->read_all(); - for (const auto& generation_output : generation_outputs) { - result.m_generation_ids.push_back(std::move(generation_output.generated_ids)); - result.m_scores.push_back(generation_output.score); + result.m_request_id = request_id; + result.m_generation_ids.resize(num_outputs); + result.m_scores.resize(num_outputs); + + for (size_t i = 0; i < num_outputs; ++i) { + const auto & sequence = sequences[i]; + const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_probs(); + const auto & generated_ids = sequence->get_generated_ids(); + + if (sampling_params.echo) + result.m_generation_ids[i] = request->get_prompt_ids(); + std::copy(generated_ids.begin(), generated_ids.end(), std::back_inserter(result.m_generation_ids[i])); + result.m_scores[i] = score; } - result.m_status = generation->get_status(); + + result.m_status = generations[request_id]->get_status(); results.push_back(std::move(result)); } @@ -408,7 +427,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs( for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) { SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id]; // requests not scheduled, in decoding phase or not echoing are not processed - if (!sequence_group->is_scheduled() || sequence_group->get_context_len() > sequence_group->get_prompt_len() || + if (!sequence_group->is_scheduled() || sequence_group->get_context_len() > sequence_group->get_prompt_len() || !sequence_group->get_sampling_parameters().echo) continue; @@ -421,10 +440,10 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs( size_t num_prompt_tokens_processed = sequence_group->get_num_processed_tokens(); OPENVINO_ASSERT(num_prompt_tokens_processed + actual_seq_len <= sequence_group->get_prompt_len()); - + // if we processed the whole prompt we don't include last logprob as it will be processed by the sampler (it's already completion) - // otherwise we include it as it will be used in the next part of the prompt - int exclude_last_logprob = 1; + // otherwise we include it as it will be used in the next part of the prompt + int exclude_last_logprob = 1; if (num_prompt_tokens_processed + actual_seq_len < sequence_group->get_prompt_len()) exclude_last_logprob = 0; @@ -435,7 +454,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs( for (int token_logits_offset = 0, token_id_offset = num_prompt_tokens_processed + 1; token_logits_offset < actual_seq_len - exclude_last_logprob; token_logits_offset++, token_id_offset++) { - + const float* token_logits = (sequence_group_logits_data + token_logits_offset * vocab_size); int64_t token_id = sequence_group->get_prompt_ids()[token_id_offset]; float token_logit = token_logits[token_id]; diff --git a/src/cpp/src/generation_handle.cpp b/src/cpp/src/generation_handle.cpp index a1dd467523..0f10a85a86 100644 --- a/src/cpp/src/generation_handle.cpp +++ b/src/cpp/src/generation_handle.cpp @@ -17,7 +17,7 @@ GenerationStatus GenerationHandleImpl::get_status() { } bool GenerationHandleImpl::can_read() { - return !is_dropped() && m_generation_stream->can_read(); + return !is_dropped() && m_generation_stream->can_read(); } bool GenerationHandleImpl::is_dropped() { diff --git a/src/cpp/src/generation_stream.hpp b/src/cpp/src/generation_stream.hpp index 4d41f160e4..518699ba36 100644 --- a/src/cpp/src/generation_stream.hpp +++ b/src/cpp/src/generation_stream.hpp @@ -14,8 +14,6 @@ class GenerationStream { GenerationStatus m_status = GenerationStatus::RUNNING; SynchronizedQueue m_output_queue; - std::vector last_sequence_ids; - public: using Ptr = std::shared_ptr; @@ -30,10 +28,11 @@ class GenerationStream { m_output_queue.push(std::move(outputs)); } - // Retrieving vector of pairs as we can generate multiple outputs for a single prompt + // Retrieving vector of pairs as we can generate multiple outputs for a single prompt GenerationOutputs back() { return m_output_queue.back(); } + GenerationOutputs read() { return m_output_queue.pull(); } diff --git a/src/cpp/src/group_beam_searcher.cpp b/src/cpp/src/group_beam_searcher.cpp deleted file mode 100644 index a0262c0dc8..0000000000 --- a/src/cpp/src/group_beam_searcher.cpp +++ /dev/null @@ -1,455 +0,0 @@ -// Copyright (C) 2023-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include - -#include "openvino/genai/llm_pipeline.hpp" -#include "utils.hpp" -#include "lm_encoding.hpp" - -namespace { - -// Modified Knuth–Morris–Pratt algorithm which returns tokens following after every needle occurrence in haystack -std::vector kmp_search(const std::vector& haystack, const std::vector& needle) { - if (needle.empty()) { // no_repeat_ngram_size == 1, ban every token - return {haystack.begin(), haystack.end()}; - } - std::vector partial_match_table(needle.size() + 1, -1); - int cnd = 0; - for (size_t pos = 1; pos < needle.size(); ++pos) { - if (needle.at(pos) == needle.at(size_t(cnd))) { - partial_match_table.at(pos) = partial_match_table.at(size_t(cnd)); - } else { - partial_match_table.at(pos) = cnd; - while (cnd >= 0 && needle.at(pos) != needle.at(size_t(cnd))) { - cnd = partial_match_table.at(size_t(cnd)); - } - } - ++cnd; - } - partial_match_table.back() = cnd; - std::vector res; - size_t haystack_id = 0; - int needle_id = 0; - while (haystack_id < haystack.size() - 1) { - if (needle.at(size_t(needle_id)) == haystack.at(haystack_id)) { - ++haystack_id; - ++needle_id; - if (needle_id == int(needle.size())) { - res.push_back(haystack.at(haystack_id)); - needle_id = partial_match_table.at(size_t(needle_id)); - } - } else { - needle_id = partial_match_table.at(size_t(needle_id)); - if (needle_id < 0) { - ++haystack_id; - ++needle_id; - } - } - } - return res; -} - -struct Token { - float log_prob; - int64_t idx; -}; - -std::vector log_softmax(const ov::Tensor& logits, const size_t batch_idx) { - if (logits.get_shape().at(0) <= batch_idx) { - throw std::runtime_error("logits batch size doesn't match the number of beams"); - } - size_t vocab_size = logits.get_shape().back(); - size_t batch_offset = batch_idx * logits.get_shape().at(1) * vocab_size; - size_t sequence_offset = (logits.get_shape().at(1) - 1) * vocab_size; - const float* beam_logits = logits.data() + batch_offset + sequence_offset; - float max_logit = *std::max_element(beam_logits, beam_logits + vocab_size); - float log_sum = std::log( - std::accumulate(beam_logits, beam_logits + vocab_size, 0.0f, [max_logit](float accumulated, float to_add) { - return accumulated + std::exp(to_add - max_logit); - })); - std::vector tokens; - tokens.reserve(vocab_size); - for (size_t idx = 0; idx < vocab_size; ++idx) { - tokens.push_back({beam_logits[idx] - max_logit - log_sum, int64_t(idx)}); - } - return tokens; -} - -struct Beam { - float score = -std::numeric_limits::infinity(); // The bigger, the better - std::vector tokens; - size_t global_beam_idx = 0; -}; - -bool greater(const Beam& left, const Beam& right) { - return left.score > right.score; -} - -struct Parameters { - std::vector> prompts; - int64_t eos_token_id; - size_t n_groups = 3; - size_t group_size = 5; - float diversity_penalty = 1.0; - size_t max_new_tokens = 20; - ov::genai::StopCriteria stop_criteria = ov::genai::StopCriteria::HEURISTIC; - float length_penalty = 1.0; - size_t no_repeat_ngram_size = std::numeric_limits::max(); - - std::function early_finish = [](const Beam&) { - return false; - }; -}; - -struct Group { - std::vector ongoing; // Best beams in front - std::vector min_heap; // The worst of the best completed beams is the first - bool done = false; - - void finish(Beam&& beam, const Parameters& parameters) { - beam.score /= std::pow(float(beam.tokens.size()), parameters.length_penalty); - - min_heap.push_back(std::move(beam)); - std::push_heap(min_heap.begin(), min_heap.end(), greater); - if (min_heap.size() > parameters.group_size) { - std::pop_heap(min_heap.begin(), min_heap.end(), greater); - min_heap.pop_back(); - } - } - void is_done(const Parameters& parameters) { - if (min_heap.size() < parameters.group_size) { - return; - } - size_t cur_len = ongoing.front().tokens.size(); - float best_sum_logprobs = ongoing.front().score; - float worst_score = min_heap.front().score; - switch (parameters.stop_criteria) { - case ov::genai::StopCriteria::EARLY: - done = true; - return; - case ov::genai::StopCriteria::HEURISTIC: { - float highest_attainable_score = best_sum_logprobs / std::pow(float(cur_len), parameters.length_penalty); - done = worst_score >= highest_attainable_score; - return; - } - case ov::genai::StopCriteria::NEVER: { - size_t length = parameters.length_penalty > 0.0 ? parameters.max_new_tokens : cur_len; - float highest_attainable_score = best_sum_logprobs / std::pow(float(length), parameters.length_penalty); - done = worst_score >= highest_attainable_score; - return; - } - default: - throw std::runtime_error("Never reached"); - } - } -}; - -// GroupBeamSearcher processes logits prduced by a language model and accumulates beams using group beam search -// algorithm. select_next_tokens() returns token ids selected by the algorithm and corresponding beam ids. These values -// are used for next inference. select_next_tokens() returns empty, if all groups are completed -struct GroupBeamSearcher { - Parameters parameters; - std::vector> prompts_groups; - - GroupBeamSearcher(Parameters parameters) : parameters{parameters}, prompts_groups{parameters.prompts.size()} { - if (parameters.no_repeat_ngram_size == 0) { - throw std::runtime_error("no_repeat_ngram_size must be positive"); - } - for (std::vector& prompts_groups : prompts_groups) { - prompts_groups.resize(parameters.n_groups); - for (Group& group : prompts_groups) { - group.ongoing.resize(parameters.group_size); - group.ongoing.front().score = 0.0; - } - } - } - - std::pair, std::vector> select_next_tokens(const ov::Tensor& logits) { - std::vector next_tokens; - std::vector next_beams; - - const size_t promts_size = parameters.prompts.size(); - - next_tokens.reserve(promts_size * parameters.n_groups * parameters.group_size); - next_beams.reserve(promts_size * parameters.n_groups * parameters.group_size); - - size_t beam_count = 0; - size_t prompt_id = 0; - for (std::vector& groups : prompts_groups) { - for (Group& group : groups) { - if (group.done) { - continue; - } - for (Beam& beam : group.ongoing) { - // beam.tokens.empty() holds for the first select_next_tokens() call. - // Every beam is constructed from the single batch at first call - if (beam.tokens.empty()) { - beam.global_beam_idx = prompt_id; - } else { - beam.global_beam_idx = beam_count; - ++beam_count; - } - } - } - - prompt_id += 1; - } - - for (int prompt_id = 0; prompt_id < promts_size; prompt_id++) { - const std::vector prompt = parameters.prompts[prompt_id]; - std::vector& groups = prompts_groups[prompt_id]; - auto [prompt_next_tokens, prompt_next_beams] = select_prompt_next_tokens(logits, prompt, groups); - - next_tokens.insert(next_tokens.end(), prompt_next_tokens.begin(), prompt_next_tokens.end()); - next_beams.insert(next_beams.end(), prompt_next_beams.begin(), prompt_next_beams.end()); - } - - return {next_tokens, next_beams}; - } - - std::pair, std::vector> select_prompt_next_tokens(const ov::Tensor& logits, - const std::vector& prompt, - std::vector& groups) { - std::vector next_tokens; - std::vector next_beams; - next_tokens.reserve(parameters.n_groups * parameters.group_size); - next_beams.reserve(parameters.n_groups * parameters.group_size); - - for (auto group = groups.begin(); group != groups.end(); ++group) { - if (group->done) { - continue; - } - std::vector candidates; - candidates.reserve(parameters.group_size * 2 * parameters.group_size); - for (const Beam& beam : group->ongoing) { - std::vector tokens = log_softmax(logits, beam.global_beam_idx); - for (auto prev_group = groups.cbegin(); prev_group != group; ++prev_group) { - for (const Beam& prev_beam : prev_group->ongoing) { - if (prev_beam.tokens.size() > beam.tokens.size()) { - tokens.at(size_t(prev_beam.tokens.back())).log_prob -= parameters.diversity_penalty; - } - } - } - std::vector full_text{prompt}; - full_text.insert(full_text.end(), beam.tokens.begin(), beam.tokens.end()); - if (full_text.size() > 1 && full_text.size() >= parameters.no_repeat_ngram_size) { - auto tail_start = full_text.end() - ptrdiff_t(parameters.no_repeat_ngram_size) + 1; - for (int64_t banned_token : kmp_search(full_text, {tail_start, full_text.end()})) { - tokens.at(size_t(banned_token)).log_prob = -std::numeric_limits::infinity(); - } - } - std::sort(tokens.begin(), tokens.end(), [](Token left, Token right) { - return left.log_prob > right.log_prob; // Most probable tokens in front - }); - size_t add_count = 0; - for (Token token : tokens) { - Beam new_candidate = beam; - new_candidate.score += token.log_prob; - new_candidate.tokens.push_back(token.idx); - if (parameters.early_finish(new_candidate)) { - group->finish(std::move(new_candidate), parameters); - } else { - candidates.push_back(std::move(new_candidate)); - ++add_count; - if (add_count == 2 * parameters.group_size) { - break; - } - } - } - } - // Sample 2 * group_size highest score tokens to get at least 1 non EOS token per beam - if (candidates.size() < 2 * parameters.group_size) { - throw std::runtime_error("No beams left to search"); - } - auto to_sort = candidates.begin() + ptrdiff_t(2 * parameters.group_size); - std::partial_sort(candidates.begin(), to_sort, candidates.end(), greater); - group->ongoing.clear(); - for (size_t cand_idx = 0; cand_idx < candidates.size(); ++cand_idx) { - if (parameters.eos_token_id == candidates.at(cand_idx).tokens.back()) { - // If beam_token does not belong to top num_beams tokens, it should not be added - if (cand_idx >= parameters.group_size) { - continue; - } - group->finish(std::move(candidates.at(cand_idx)), parameters); - } else { - group->ongoing.push_back(std::move(candidates.at(cand_idx))); - if (group->ongoing.size() == parameters.group_size) { - break; - } - } - } - group->is_done(parameters); - if (!group->done) { - for (const Beam& beam : group->ongoing) { - next_tokens.push_back(beam.tokens.back()); - next_beams.push_back(int32_t(beam.global_beam_idx)); - } - } - } - return {next_tokens, next_beams}; - } -}; - -// Consume group_beam_searcher because beams are consumed -std::vector>> finalize(GroupBeamSearcher&& group_beam_searcher) { - std::vector>> finalized; - finalized.resize(group_beam_searcher.prompts_groups.size()); - - for (size_t prompt_id = 0; prompt_id < group_beam_searcher.prompts_groups.size(); prompt_id++) { - std::vector& groups = group_beam_searcher.prompts_groups.at(prompt_id); - finalized.at(prompt_id).reserve(groups.size()); - - for (Group& group : groups) { - if (!group.done) { - for (Beam& beam : group.ongoing) { - group.finish(std::move(beam), group_beam_searcher.parameters); - } - } - finalized.at(prompt_id).push_back(std::move(group.min_heap)); - } - } - - return finalized; -} - -void reset_all_inputs_to_empty_tensors(ov::InferRequest& request) { - request.set_tensor("input_ids", ov::Tensor(ov::element::i64, {0, 0})); - request.set_tensor("beam_idx", ov::Tensor(ov::element::i32, {0})); - if (request.get_compiled_model().inputs().size() == 4) - request.set_tensor("position_ids", ov::Tensor(ov::element::i64, {0, 0})); -} -} // namespace - -namespace ov { -namespace genai { - -std::pair beam_search(ov::InferRequest& lm, - ov::Tensor input_ids, - ov::Tensor attention_mask, - GenerationConfig config, - std::optional position_ids, - std::optional selected_beam_idx) { - OPENVINO_ASSERT(config.num_beams % config.num_beam_groups == 0, - "number of beams should be divisible by number of groups"); - - auto batch_size = input_ids.get_shape().at(0); - auto sequence_length = input_ids.get_shape().at(1); - - // Initialize beam search. - const int64_t* prompt_data = input_ids.data(); - std::vector> prompts; - prompts.reserve(batch_size); - for (size_t batch = 0; batch < batch_size; batch++) { - size_t batch_offset = batch * sequence_length; - const int64_t* prompt_start = prompt_data + batch_offset; - prompts.push_back(std::vector{prompt_start, prompt_start + sequence_length}); - } - - lm.set_tensor("input_ids", input_ids); - lm.set_tensor("attention_mask", attention_mask); - if (position_ids.has_value()) - lm.set_tensor("position_ids", *position_ids); - - ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size}); - auto beam_data = beam_idx.data(); - if (selected_beam_idx.has_value()) - beam_data[0] = *selected_beam_idx; - else - std::fill_n(beam_data, batch_size, 0); - lm.set_tensor("beam_idx", beam_idx); - - Parameters parameters{std::move(prompts)}; - parameters.max_new_tokens = config.get_max_new_tokens(sequence_length); - parameters.eos_token_id = config.eos_token_id; - parameters.n_groups = config.num_beam_groups; - parameters.group_size = config.num_beams / config.num_beam_groups; - parameters.diversity_penalty = config.diversity_penalty; - parameters.length_penalty = config.length_penalty; - parameters.stop_criteria = config.stop_criteria; - parameters.no_repeat_ngram_size = config.no_repeat_ngram_size; - GroupBeamSearcher group_beam_searcher{parameters}; - - std::vector next_tokens; - std::vector next_beams; - - // Reserve for performance counters. - std::vector new_token_times; - std::vector batch_sizes; - new_token_times.reserve(parameters.max_new_tokens); - batch_sizes.reserve(parameters.max_new_tokens); - - for (size_t length_count = 0; ; ++length_count) { - lm.infer(); - - std::tie(next_tokens, next_beams) = group_beam_searcher.select_next_tokens(lm.get_tensor("logits")); - new_token_times.emplace_back(std::chrono::steady_clock::now()); - batch_sizes.emplace_back(batch_size); - - if (next_tokens.empty() || length_count == parameters.max_new_tokens - 1) { - // Break the cycle before masks are extended in update_attention_mask_with_beams. - // If generation is continued, attention_mask length should be equal to KV cache size. - break; - } - - size_t running_batch_size = next_tokens.size(); - // Set pointers - lm.set_tensor("input_ids", ov::Tensor{ov::element::i64, {running_batch_size, 1}, next_tokens.data()}); - lm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {running_batch_size}, next_beams.data()}); - - // Set auxiliary inputs - update_attention_mask_with_beams(lm.get_tensor("attention_mask"), next_beams); - if (position_ids.has_value()) - update_position_ids(lm.get_tensor("position_ids"), lm.get_tensor("attention_mask")); - } - - reset_all_inputs_to_empty_tensors(lm); - - auto scores_comparator = [](Beam& left, Beam& right) { - return (left.score > right.score); - }; - - auto result = finalize(std::move(group_beam_searcher)); - ov::genai::EncodedResults results; - int32_t res_selected_beam_idx = 0; - results.scores.reserve(config.num_return_sequences * result.size()); - results.tokens.reserve(config.num_return_sequences * result.size()); - auto& raw_perf_counters = results.perf_metrics.raw_metrics; - raw_perf_counters.m_new_token_times = new_token_times; - raw_perf_counters.m_batch_sizes = batch_sizes; - - // align output with HF - for (size_t prompt_id = 0; prompt_id < result.size(); prompt_id++) { - auto prompt_group = result.at(prompt_id); - std::vector> plain_beams; - plain_beams.reserve(parameters.n_groups * parameters.group_size); - for (std::vector& group : prompt_group) { - for (Beam& beam : group) { - plain_beams.push_back(beam); - } - } - assert(config.num_return_sequences <= plain_beams.size()); - std::partial_sort( - plain_beams.begin(), - plain_beams.begin() + config.num_return_sequences, - plain_beams.end(), - scores_comparator - ); - res_selected_beam_idx = plain_beams.at(0).get().global_beam_idx; - for ( - auto beam = plain_beams.begin(); - beam != plain_beams.begin() + config.num_return_sequences; - ++beam - ) { - results.scores.push_back(beam->get().score); - results.tokens.push_back(std::move(beam->get().tokens)); - } - } - - return {results, res_selected_beam_idx}; -} - -} // namespace genai -} // namespace ov diff --git a/src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.cpp b/src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.cpp new file mode 100644 index 0000000000..a63a073cfc --- /dev/null +++ b/src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.cpp @@ -0,0 +1,261 @@ +// Copyright (C) 2023-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include + +#include "image_generation/schedulers/euler_ancestral_discrete.hpp" +#include "image_generation/numpy_utils.hpp" + +namespace ov { +namespace genai { + +EulerAncestralDiscreteScheduler::Config::Config(const std::filesystem::path& scheduler_config_path) { + std::ifstream file(scheduler_config_path); + OPENVINO_ASSERT(file.is_open(), "Failed to open ", scheduler_config_path); + + nlohmann::json data = nlohmann::json::parse(file); + using utils::read_json_param; + + read_json_param(data, "num_train_timesteps", num_train_timesteps); + read_json_param(data, "beta_start", beta_start); + read_json_param(data, "beta_end", beta_end); + read_json_param(data, "beta_schedule", beta_schedule); + read_json_param(data, "trained_betas", trained_betas); + read_json_param(data, "steps_offset", steps_offset); + read_json_param(data, "prediction_type", prediction_type); + read_json_param(data, "timestep_spacing", timestep_spacing); + read_json_param(data, "rescale_betas_zero_snr", rescale_betas_zero_snr); +} + +EulerAncestralDiscreteScheduler::EulerAncestralDiscreteScheduler(const std::filesystem::path& scheduler_config_path) + : EulerAncestralDiscreteScheduler(Config(scheduler_config_path)) { +} + +EulerAncestralDiscreteScheduler::EulerAncestralDiscreteScheduler(const Config& scheduler_config): m_config(scheduler_config) { + std::vector alphas, betas; + + using numpy_utils::linspace; + + if (!m_config.trained_betas.empty()) { + betas = m_config.trained_betas; + } else if (m_config.beta_schedule == BetaSchedule::LINEAR) { + betas = linspace(m_config.beta_start, m_config.beta_end, m_config.num_train_timesteps); + } else if (m_config.beta_schedule == BetaSchedule::SCALED_LINEAR) { + float start = std::sqrt(m_config.beta_start); + float end = std::sqrt(m_config.beta_end); + betas = linspace(start, end, m_config.num_train_timesteps); + std::for_each(betas.begin(), betas.end(), [](float& x) { + x *= x; + }); + // TODO: else if beta_schedule == "squaredcos_cap_v2" + } else { + OPENVINO_THROW( + "'beta_schedule' must be one of 'LINEAR' or 'SCALED_LINEAR'. Please, add support of other types"); + } + + if (m_config.rescale_betas_zero_snr) { + using numpy_utils::rescale_zero_terminal_snr; + rescale_zero_terminal_snr(betas); + } + + std::transform(betas.begin(), betas.end(), std::back_inserter(alphas), [](float b) { + return 1.0f - b; + }); + + for (size_t i = 1; i <= alphas.size(); ++i) { + float alpha_cumprod = + std::accumulate(std::begin(alphas), std::begin(alphas) + i, 1.0, std::multiplies{}); + m_alphas_cumprod.push_back(alpha_cumprod); + } + + if (m_config.rescale_betas_zero_snr) { + m_alphas_cumprod.back() = std::pow(2, -24); + } + + for (auto it = m_alphas_cumprod.rbegin(); it != m_alphas_cumprod.rend(); ++it) { + float sigma = std::pow(((1 - (*it)) / (*it)), 0.5); + m_sigmas.push_back(sigma); + } + m_sigmas.push_back(0); + + // setable values + auto linspaced = + linspace(0.0f, static_cast(m_config.num_train_timesteps - 1), m_config.num_train_timesteps, true); + for (auto it = linspaced.rbegin(); it != linspaced.rend(); ++it) { + m_timesteps.push_back(static_cast(std::round(*it))); + } + m_num_inference_steps = -1; + m_step_index = -1; + m_begin_index = -1; + m_is_scale_input_called = false; +} + +void EulerAncestralDiscreteScheduler::set_timesteps(size_t num_inference_steps, float strength) { + m_timesteps.clear(); + m_sigmas.clear(); + m_step_index = m_begin_index = -1; + m_num_inference_steps = num_inference_steps; + std::vector sigmas; + + switch (m_config.timestep_spacing) { + case TimestepSpacing::LINSPACE: { + using numpy_utils::linspace; + float end = static_cast(m_config.num_train_timesteps - 1); + auto linspaced = linspace(0.0f, end, num_inference_steps, true); + for (auto it = linspaced.rbegin(); it != linspaced.rend(); ++it) { + m_timesteps.push_back(static_cast(std::round(*it))); + } + break; + } + case TimestepSpacing::LEADING: { + size_t step_ratio = m_config.num_train_timesteps / m_num_inference_steps; + for (size_t i = num_inference_steps - 1; i != -1; --i) { + m_timesteps.push_back(i * step_ratio + m_config.steps_offset); + } + break; + } + case TimestepSpacing::TRAILING: { + float step_ratio = static_cast(m_config.num_train_timesteps) / static_cast(m_num_inference_steps); + for (float i = m_config.num_train_timesteps; i > 0; i -= step_ratio) { + m_timesteps.push_back(static_cast(std::round(i)) - 1); + } + break; + } + default: + OPENVINO_THROW("Unsupported value for 'timestep_spacing'"); + } + + for (const float& i : m_alphas_cumprod) { + float sigma = std::pow(((1 - i) / i), 0.5); + sigmas.push_back(sigma); + } + + using numpy_utils::interp; + std::vector x_data_points(sigmas.size()); + std::iota(x_data_points.begin(), x_data_points.end(), 0); + m_sigmas = interp(m_timesteps, x_data_points, sigmas); + m_sigmas.push_back(0.0f); + + // apply 'strength' used in image generation + // in diffusers, it's https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L650 + { + size_t init_timestep = std::min(num_inference_steps * strength, num_inference_steps); + size_t t_start = std::max(num_inference_steps - init_timestep, 0); + // keep original timesteps + m_schedule_timesteps = m_timesteps; + // while return patched ones by 'strength' parameter + m_timesteps = std::vector(m_timesteps.begin() + t_start, m_timesteps.end()); + m_begin_index = t_start; + } +} + +std::map EulerAncestralDiscreteScheduler::step(ov::Tensor noise_pred, ov::Tensor latents, size_t inference_step, std::shared_ptr generator) { + // noise_pred - model_output + // latents - sample + // inference_step + + size_t timestep = m_timesteps[inference_step]; + + if (m_step_index == -1) + m_step_index = m_begin_index; + + float sigma = m_sigmas[m_step_index]; + + float* model_output_data = noise_pred.data(); + float* sample_data = latents.data(); + + ov::Tensor pred_original_sample(noise_pred.get_element_type(), noise_pred.get_shape()); + float* pred_original_sample_data = pred_original_sample.data(); + + switch (m_config.prediction_type) { + case PredictionType::EPSILON: + for (size_t i = 0; i < noise_pred.get_size(); ++i) { + pred_original_sample_data[i] = sample_data[i] - sigma * model_output_data[i]; + } + break; + case PredictionType::V_PREDICTION: + for (size_t i = 0; i < noise_pred.get_size(); ++i) { + pred_original_sample_data[i] = model_output_data[i] * (-sigma / std::pow((std::pow(sigma, 2) + 1), 0.5)) + + (sample_data[i] / (std::pow(sigma, 2) + 1)); + } + break; + default: + OPENVINO_THROW("Unsupported value for 'PredictionType': must be one of `epsilon`, or `v_prediction`"); + } + + float sigma_from = m_sigmas[m_step_index]; + float sigma_to = m_sigmas[m_step_index + 1]; + float sigma_up = std::sqrt(std::pow(sigma_to, 2) * (std::pow(sigma_from, 2) - std::pow(sigma_to, 2)) / std::pow(sigma_from, 2)); + float sigma_down = std::sqrt(std::pow(sigma_to, 2) - std::pow(sigma_up, 2)); + float dt = sigma_down - sigma; + + ov::Tensor prev_sample = ov::Tensor(latents.get_element_type(), latents.get_shape()); + float* prev_sample_data = prev_sample.data(); + + ov::Tensor noise = generator->randn_tensor(noise_pred.get_shape()); + const float* noise_data = noise.data(); + + for (size_t i = 0; i < prev_sample.get_size(); ++i) { + float derivative = (sample_data[i] - pred_original_sample_data[i]) / sigma; + prev_sample_data[i] = (sample_data[i] + derivative * dt) + noise_data[i] * sigma_up; + } + + m_step_index++; + + return {{"latent", prev_sample}, {"denoised", pred_original_sample}}; +} + +size_t EulerAncestralDiscreteScheduler::_index_for_timestep(int64_t timestep) const{ + for (size_t i = 0; i < m_schedule_timesteps.size(); ++i) { + if (timestep == m_schedule_timesteps[i]) { + return i; + } + } + + OPENVINO_THROW("Failed to find index for timestep ", timestep); +} + +void EulerAncestralDiscreteScheduler::add_noise(ov::Tensor init_latent, ov::Tensor noise, int64_t latent_timestep) const { + size_t index_for_timestep = _index_for_timestep(latent_timestep); + const float sigma = m_sigmas[index_for_timestep]; + + float * init_latent_data = init_latent.data(); + const float * noise_data = noise.data(); + + for (size_t i = 0; i < init_latent.get_size(); ++i) { + init_latent_data[i] = init_latent_data[i] + sigma * noise_data[i]; + } +} + +std::vector EulerAncestralDiscreteScheduler::get_timesteps() const { + return m_timesteps; +} + +void EulerAncestralDiscreteScheduler::scale_model_input(ov::Tensor sample, size_t inference_step) { + if (m_step_index == -1) + m_step_index = m_begin_index; + + float sigma = m_sigmas[m_step_index]; + float* sample_data = sample.data(); + for (size_t i = 0; i < sample.get_size(); i++) { + sample_data[i] /= std::pow((std::pow(sigma, 2) + 1), 0.5); + } + m_is_scale_input_called = true; +} + +float EulerAncestralDiscreteScheduler::get_init_noise_sigma() const { + float max_sigma = *std::max_element(m_sigmas.begin(), m_sigmas.end()); + + if (m_config.timestep_spacing == TimestepSpacing::LINSPACE || + m_config.timestep_spacing == TimestepSpacing::TRAILING) { + return max_sigma; + } + + return std::sqrt(std::pow(max_sigma, 2) + 1); +} + +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.hpp b/src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.hpp new file mode 100644 index 0000000000..9d82c9a0a9 --- /dev/null +++ b/src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.hpp @@ -0,0 +1,61 @@ +// Copyright (C) 2023-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "image_generation/schedulers/types.hpp" +#include "image_generation/schedulers/ischeduler.hpp" + +namespace ov { +namespace genai { + +class EulerAncestralDiscreteScheduler : public IScheduler { +public: + struct Config { + int32_t num_train_timesteps = 1000; + float beta_start = 0.0001f, beta_end = 0.02f; + BetaSchedule beta_schedule = BetaSchedule::LINEAR; + std::vector trained_betas = {}; + size_t steps_offset = 0; + PredictionType prediction_type = PredictionType::EPSILON; + TimestepSpacing timestep_spacing = TimestepSpacing::LEADING; + bool rescale_betas_zero_snr = false; + + Config() = default; + explicit Config(const std::filesystem::path& scheduler_config_path); + }; + + explicit EulerAncestralDiscreteScheduler(const std::filesystem::path& scheduler_config_path); + explicit EulerAncestralDiscreteScheduler(const Config& scheduler_config); + + void set_timesteps(size_t num_inference_steps, float strength) override; + + std::vector get_timesteps() const override; + + float get_init_noise_sigma() const override; + + void scale_model_input(ov::Tensor sample, size_t inference_step) override; + + std::map step(ov::Tensor noise_pred, ov::Tensor latents, size_t inference_step, std::shared_ptr generator) override; + + void add_noise(ov::Tensor init_latent, ov::Tensor noise, int64_t latent_timestep) const override; + +private: + Config m_config; + + std::vector m_alphas_cumprod, m_sigmas; + std::vector m_timesteps, m_schedule_timesteps; + size_t m_num_inference_steps; + + int m_step_index, m_begin_index; + bool m_is_scale_input_called; + + size_t _index_for_timestep(int64_t timestep) const; +}; + +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/image_generation/schedulers/scheduler.cpp b/src/cpp/src/image_generation/schedulers/scheduler.cpp index f9cd098346..868f6f05cf 100644 --- a/src/cpp/src/image_generation/schedulers/scheduler.cpp +++ b/src/cpp/src/image_generation/schedulers/scheduler.cpp @@ -11,6 +11,7 @@ #include "image_generation/schedulers/euler_discrete.hpp" #include "image_generation/schedulers/flow_match_euler_discrete.hpp" #include "image_generation/schedulers/pndm.hpp" +#include "image_generation/schedulers/euler_ancestral_discrete.hpp" namespace ov { namespace genai { @@ -41,6 +42,8 @@ std::shared_ptr Scheduler::from_config(const std::filesystem::path& s scheduler = std::make_shared(scheduler_config_path); } else if (scheduler_type == Scheduler::Type::PNDM) { scheduler = std::make_shared(scheduler_config_path); + } else if (scheduler_type == Scheduler::Type::EULER_ANCESTRAL_DISCRETE) { + scheduler = std::make_shared(scheduler_config_path); } else { OPENVINO_THROW("Unsupported scheduler type '", scheduler_type, ". Please, manually create scheduler via supported one"); } diff --git a/src/cpp/src/image_generation/schedulers/types.cpp b/src/cpp/src/image_generation/schedulers/types.cpp index 2f7c6d3f25..5a9e5b6865 100644 --- a/src/cpp/src/image_generation/schedulers/types.cpp +++ b/src/cpp/src/image_generation/schedulers/types.cpp @@ -57,6 +57,8 @@ void read_json_param(const nlohmann::json& data, const std::string& name, Schedu param = Scheduler::FLOW_MATCH_EULER_DISCRETE; else if (scheduler_type_str == "PNDMScheduler") param = Scheduler::PNDM; + else if (scheduler_type_str == "EulerAncestralDiscreteScheduler") + param = Scheduler::EULER_ANCESTRAL_DISCRETE; else if (!scheduler_type_str.empty()) { OPENVINO_THROW("Unsupported value for 'scheduler' ", scheduler_type_str); } diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index 6d9aae30fa..33180a9199 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -24,27 +24,23 @@ namespace ov { namespace genai { -std::pair beam_search( - ov::InferRequest& lm, - ov::Tensor prompts, - ov::Tensor attention_mask, - GenerationConfig config, - std::optional position_ids, - std::optional selected_beam_idx -); - class StatefulLLMPipeline final : public LLMPipelineImplBase { public: ov::InferRequest m_model_runner; bool is_chat_conversation = false; bool m_trust_encoded_history = true; - std::optional m_selected_beam = std::nullopt; ChatHistory m_history; std::string m_templated_chat_history = {}; std::vector m_tokenized_chat_history; ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; - size_t m_to_remove_from_hist = 0; size_t m_kv_cache_seq_length_axis = 2; + Sampler m_sampler; + // Tail of previous output in chat mode is missing in KV cache, let's keep it + std::optional m_last_disappeared_token = std::nullopt; + // If sequence contains some symbols, which could be ambiguously encoded by tokenizer, we need to trim kv cache + // If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history + // so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history + ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0}; StatefulLLMPipeline( const ov::InferRequest& request, @@ -75,7 +71,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { const std::string& device, const ov::AnyMap& config, const ov::genai::GenerationConfig& generation_config - ) : LLMPipelineImplBase(tokenizer, generation_config) { + ) : LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) { ov::Core core; ov::CompiledModel compiled_model; auto [core_plugin_config, plugin_config] = ov::genai::utils::split_core_compile_config(config); @@ -96,6 +92,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { // If eos_token_id was not provided, take value if (m_generation_config.eos_token_id == -1) m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id()); + + m_sampler.set_seed(m_generation_config.rng_seed); } StatefulLLMPipeline( @@ -151,35 +149,44 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { // some symbols combinations can be encoded by the tokenizer in different ways // if we met sequence with such combination of symbols, we cannot correctly subtract the new history from the old history // so let's check it out, find the trusted part and use it in on the next step - size_t last_same_hist_token = 0; + size_t trusted_history_length = 0; if (!m_tokenized_chat_history.empty()) { std::set stop_tokens = config.stop_token_ids; - last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens); - m_trust_encoded_history = last_same_hist_token == SIZE_MAX; + trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens); + m_trust_encoded_history = trusted_history_length == SIZE_MAX; } if (m_tokenized_chat_history.empty()) { encoded_input = new_chat_tokens; - } else if (last_same_hist_token != SIZE_MAX) { - m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token; + } else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) { + // does_kv_cache_need_to_update will be true here if beam search is activated + // in beam search mode we want to remove all history about last model answer from kv cache and add the best answer directly + // if we have difference in model answer and decoded answer it anyway will be less then entire history, so let's use data from m_kv_history_manager + if (m_kv_history_manager.does_kv_cache_need_to_update()) { + trusted_history_length = m_kv_history_manager.trusted_history_length; + } else { + m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length; + // if prev generation was finished because of max len was reached, kv cache is missed one last token, let's keep it + m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0; + } ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(), - {1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token}, - new_chat_tokens.input_ids.data() + last_same_hist_token); + {1, new_chat_tokens.input_ids.get_shape().at(1) - trusted_history_length}, + new_chat_tokens.input_ids.data() + trusted_history_length); ov::Tensor new_attention_mask(ov::element::i64, new_tensor.get_shape()); std::fill_n(new_attention_mask.data(), new_tensor.get_shape()[1], 1); encoded_input.input_ids = ov::Tensor(new_chat_tokens.input_ids.get_element_type(), - {1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token}); + {1, new_chat_tokens.input_ids.get_shape().at(1) - trusted_history_length}); new_tensor.copy_to(encoded_input.input_ids); encoded_input.attention_mask = new_attention_mask; - - m_selected_beam = std::nullopt; + m_last_disappeared_token = std::nullopt; } else { encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens); } m_templated_chat_history = new_templated_chat_history; + m_tokenized_chat_history.clear(); m_tokenized_chat_history.reserve(new_chat_tokens.input_ids.get_size()); std::copy_n(new_chat_tokens.input_ids.data(), new_chat_tokens.input_ids.get_size(), @@ -261,6 +268,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) std::copy(input_ids.data(), input_ids.data() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history)); + // Tail of previous output in chat mode is missing in KV cache. + if (m_last_disappeared_token.has_value()) { + attention_mask = ov::genai::utils::push_front_inputs(attention_mask, 1); + input_ids = ov::genai::utils::push_front_inputs(input_ids, *m_last_disappeared_token); + } + GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config; // If eos_token_id was not provided, take value from default m_generation_config @@ -281,10 +294,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { } auto batch_size = input_ids.get_shape().at(0); - if ((batch_size != 1 || !(config.is_greedy_decoding() || config.is_multinomial())) && streamer_ptr) { - OPENVINO_THROW("Currently streaming is possible only with batch size=1 and " - "only for greedy or multinomial decoding"); - } + OPENVINO_ASSERT(streamer_ptr == nullptr || batch_size == 1 && config.num_return_sequences == 1 && + (config.is_greedy_decoding() || config.is_multinomial()), + "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); auto num_inputs = m_model_runner.get_compiled_model().inputs().size(); OPENVINO_ASSERT(num_inputs == 4 || num_inputs == 3, "Model should have 3 or 4 inputs: " @@ -292,7 +304,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { "(input_ids, attention_mask, position_ids, beam_idx) " "but you have '" + std::to_string(num_inputs) + "' inputs"); - ov::genai::utils::trim_kv_cache(m_model_runner, m_to_remove_from_hist, m_kv_cache_seq_length_axis, m_adapter_controller); + ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache, m_kv_cache_seq_length_axis, m_adapter_controller); size_t kv_cache_len = 0; ov::Tensor concatenated_attention_mask; @@ -302,10 +314,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { // Between subsequent runs attention_mask should not be modified. auto atten_mask_history = m_model_runner.get_tensor("attention_mask"); auto prompt_len = attention_mask.get_shape()[1]; - kv_cache_len = atten_mask_history.get_shape()[1] - m_to_remove_from_hist; + + kv_cache_len = atten_mask_history.get_shape()[1] - m_kv_history_manager.num_tokens_to_remove_from_kv_cache; ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}}; - auto start_atten_hst = atten_mask_history.data() + kv_cache_len * (*m_selected_beam); + auto start_atten_hst = atten_mask_history.data(); + std::copy(start_atten_hst, start_atten_hst + kv_cache_len, new_atten_mask.data()); std::copy(attention_mask.data(), attention_mask.data() + prompt_len, @@ -315,6 +329,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { concatenated_attention_mask = attention_mask; } + size_t prev_attn_mask_size = concatenated_attention_mask.get_shape()[1]; + bool position_ids_available = (num_inputs == 4); std::optional position_ids = std::nullopt; if (position_ids_available) { @@ -328,48 +344,55 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { if (is_chat_conversation && !m_trust_encoded_history) { m_trust_encoded_history = true; - m_to_remove_from_hist = 0; + m_kv_history_manager.reset(); } - ov::genai::EncodedResults result; - if (config.is_beam_search() && is_chat_conversation) { - std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask, - config, position_ids, m_selected_beam); - } else { - std::vector requests; - size_t block_size = 1; - bool enable_prefix_caching = false; - - for (size_t request_id = 0; request_id < batch_size; request_id++) { - SequenceGroup::Ptr sequence_group; - if (is_chat_conversation) { - ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()); - sequence_group = std::make_shared(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching); - } else { - size_t seq_len = input_ids.get_shape().at(1); - size_t batch_offset = request_id * seq_len; - const int64_t* prompt_start = input_ids.data() + batch_offset; - std::vector tokenized_prompt(prompt_start, prompt_start + seq_len); + std::vector requests; + size_t block_size = 1; + bool enable_prefix_caching = false; - sequence_group = std::make_shared(request_id, tokenized_prompt, config, block_size, enable_prefix_caching); - } + for (size_t request_id = 0; request_id < batch_size; request_id++) { + SequenceGroup::Ptr sequence_group; + if (is_chat_conversation) { + ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()); + sequence_group = std::make_shared(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching); + } else { + size_t seq_len = input_ids.get_shape().at(1); + size_t batch_offset = request_id * seq_len; + const int64_t* prompt_start = input_ids.data() + batch_offset; + std::vector tokenized_prompt(prompt_start, prompt_start + seq_len); - sequence_group->set_sequence_group_ptr(sequence_group); - requests.push_back(sequence_group); + sequence_group = std::make_shared(request_id, tokenized_prompt, config, block_size, enable_prefix_caching); } - Sampler sampler = Sampler(m_tokenizer); - std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr, - sampler, requests, position_ids, std::nullopt, m_selected_beam); + sequence_group->set_sequence_group_ptr(sequence_group); + requests.push_back(sequence_group); } + if (m_sampler.get_seed() != config.rng_seed) { + m_sampler.set_seed(config.rng_seed); + } + + ov::genai::EncodedResults result; + std::tie(result, m_last_disappeared_token) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, + streamer_ptr, m_sampler, requests, position_ids, std::nullopt); + if (is_chat_conversation) { + // force remove from kv_cache last answer + if (config.is_beam_search() && m_chat_input_type != ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) { + m_kv_history_manager.trusted_history_length = m_tokenized_chat_history.size(); + m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size; + } + std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history)); } else { reset_kv_state(); - m_selected_beam = std::nullopt; + m_last_disappeared_token = std::nullopt; } + if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) + std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history)); + auto stop_time = std::chrono::steady_clock::now(); // If is called without tokenization then that stat will not be reported. @@ -383,10 +406,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { void start_chat(const std::string& system_message) override { is_chat_conversation = true; - m_selected_beam = std::nullopt; m_trust_encoded_history = true; - m_to_remove_from_hist = 0; + m_kv_history_manager.reset(); m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; + m_last_disappeared_token = std::nullopt; if (!m_tokenized_chat_history.empty()) { reset_kv_state(); m_history = {}; @@ -404,10 +427,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { void finish_chat() override { is_chat_conversation = false; - m_selected_beam = std::nullopt; m_trust_encoded_history = true; - m_to_remove_from_hist = 0; + m_kv_history_manager.reset(); m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; + m_last_disappeared_token = std::nullopt; if (!m_tokenized_chat_history.empty()) { reset_kv_state(); m_history.clear(); @@ -581,9 +604,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase { std::vector plain_replies; std::vector plain_scores; for (GenerationResult& res : generated) { - if (GenerationStatus::FINISHED != res.m_status) { - OPENVINO_THROW("Got unfinished GenerationStatus"); - } + OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus"); std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_replies)); std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores)); } @@ -639,9 +660,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase { std::vector> plain_tokens; std::vector plain_scores; for (EncodedGenerationResult& res : generated) { - if (GenerationStatus::FINISHED != res.m_status) { - OPENVINO_THROW("Got unfinished GenerationStatus"); - } + OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus"); std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_tokens)); std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores)); } diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp index 090aed9650..6f4f124894 100644 --- a/src/cpp/src/llm_pipeline_static.cpp +++ b/src/cpp/src/llm_pipeline_static.cpp @@ -407,7 +407,8 @@ ov::genai::ModelConfigDesc get_modeldesc_from_json(const std::filesystem::path& if (config_data.contains("_name_or_path")) { desc.name_or_path = config_data["_name_or_path"].get(); } - desc.num_key_value_heads = config_data["num_key_value_heads"].get(); + desc.num_key_value_heads = config_data.contains("num_key_value_heads") + ? config_data["num_key_value_heads"].get() : -1; return desc; } @@ -1102,6 +1103,11 @@ EncodedResults StaticLLMPipeline::generate( m_kvcache_request.get_tensor(output_name).copy_to(kvcache_in_slice); } } + + if (streamer_ptr) { + streamer_ptr->end(); + } + auto stop_time = std::chrono::steady_clock::now(); // If is called without tokenization then that stat will not be reported. auto& metrics = results.perf_metrics; diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp index 3ab041fa58..17a20dd961 100644 --- a/src/cpp/src/lm_encoding.cpp +++ b/src/cpp/src/lm_encoding.cpp @@ -9,12 +9,11 @@ #include #include +#include "utils.hpp" +#include "debug_utils.hpp" #include "lm_encoding.hpp" #include "openvino/genai/perf_metrics.hpp" -#include "debug_utils.hpp" - -#include "utils.hpp" namespace ov { namespace genai { @@ -51,7 +50,7 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector get_lm_encoded_results( +std::pair> get_lm_encoded_results( ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, @@ -59,41 +58,56 @@ std::pair get_lm_encoded_results( Sampler& sampler, std::vector sequence_groups, std::optional position_ids, - std::optional m_embedding, - std::optional selected_beam_idx + std::optional m_embedding ) { std::vector generations; for (SequenceGroup::Ptr sequence_group : sequence_groups) { generations.push_back(std::make_shared(sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters())); } + auto active_sequence_groups{sequence_groups}; + + auto stream_generated_tokens = [&streamer_ptr, &generations, &active_sequence_groups]() { + GenerationHandle& handle = generations.at(0); + if (streamer_ptr && handle->can_read()) { + std::unordered_map token = handle->back(); + for (const auto& gen_token : token.begin()->second.generated_ids) { + if (streamer_ptr->put(gen_token)) { + handle->drop(); + break; + } + } + } + + // free non running requests + auto removed_it = std::remove_if(active_sequence_groups.begin(), active_sequence_groups.end(), + [](SequenceGroup::Ptr sg) -> bool { + return sg->has_finished() || sg->out_of_memory() || sg->handle_dropped(); + }); + active_sequence_groups.erase(removed_it, active_sequence_groups.end()); + }; + ov::Shape prompts_shape = input_ids.get_shape(); const size_t batch_size = prompts_shape[0]; // Initialize results and performance metrics. + EncodedResults results; auto& raw_perf_counters = results.perf_metrics.raw_metrics; raw_perf_counters.m_inference_durations = {{ MicroSeconds(0.0f) }}; // Initialize inputs - if (m_embedding.has_value()) - m_llm.set_tensor("inputs_embeds", input_ids); - else - m_llm.set_tensor("input_ids", input_ids); - + m_llm.set_tensor(m_embedding.has_value() ? "inputs_embeds" : "input_ids", input_ids); m_llm.set_tensor("attention_mask", attention_mask); - if (position_ids.has_value()) m_llm.set_tensor("position_ids", *position_ids); ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size}); - auto beam_data = beam_idx.data(); - if (selected_beam_idx.has_value()) - beam_data[0] = *selected_beam_idx; - else - std::fill_n(beam_data, batch_size, 0); + std::fill_n(beam_idx.data(), batch_size, 0); m_llm.set_tensor("beam_idx", beam_idx); + // "Prompt" phase + const auto infer_start = std::chrono::steady_clock::now(); m_llm.infer(); const auto infer_end = std::chrono::steady_clock::now(); @@ -109,7 +123,6 @@ std::pair get_lm_encoded_results( for (auto& sequence_group : sequence_groups) { sequence_group->update_processed_tokens_num(sequence_group->get_prompt_len() - sequence_len); sequence_group->schedule_tokens(sequence_len); - } std::map beam_offets; @@ -117,27 +130,11 @@ std::pair get_lm_encoded_results( beam_offets.insert({sequence_groups.at(i)->get_request_id(), i}); SamplerOutput sampler_output = sampler.sample(sequence_groups, logits); + stream_generated_tokens(); - auto active_sequence_groups{sequence_groups}; - auto get_active_sequence_groups = [](SequenceGroup::Ptr sg) { return sg->has_finished(); }; - - active_sequence_groups.erase(std::remove_if(active_sequence_groups.begin(), - active_sequence_groups.end(), - get_active_sequence_groups), - active_sequence_groups.end()); - - auto stream_generated_tokens = [&streamer_ptr, &generations]() { - if (streamer_ptr && generations.at(0).get()->can_read()) { - std::unordered_map token = generations.at(0).get()->back(); - for (const auto& gen_token : token.begin()->second.generated_ids) { - if (!streamer_ptr->put(gen_token)) { - break; - } - } - } - }; + // "Generation" phase - while (active_sequence_groups.size() > 0) { + while (!active_sequence_groups.empty()) { size_t total_num_tokens = 0; for (auto& sequence_group : active_sequence_groups) { @@ -172,26 +169,19 @@ std::pair get_lm_encoded_results( // apply strides to shift to a next sequence input_ids_data += num_scheduled_tokens; - // for different sequences iteration of beams started from 0, but we collect it to one input_ids# + // for different sequences iteration of beams started from 0, but we collect it to one input_ids next_beams.push_back(beam_idxs[sequence->get_id()] + beam_offets.at(sequence_group->get_request_id())); } } - for (size_t i = 0; i < sequence_groups.size(); i++) { - if (i == 0) - beam_offets[sequence_groups.at(i)->get_request_id()] = 0; - else { - beam_offets[sequence_groups.at(i)->get_request_id()] = sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i -1]; - } + for (size_t i = 0; i < active_sequence_groups.size(); i++) { + beam_offets[active_sequence_groups.at(i)->get_request_id()] = i == 0 ? 0 : (active_sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i - 1]); } if (m_embedding.has_value()) { const ov::Tensor& embed_prompt_tensor = (*m_embedding).infer(new_input_ids); - - m_llm.get_tensor("inputs_embeds").set_shape(embed_prompt_tensor.get_shape()); m_llm.set_tensor("inputs_embeds", embed_prompt_tensor); } else { - m_llm.get_tensor("input_ids").set_shape(new_input_ids.get_shape()); m_llm.set_tensor("input_ids", new_input_ids); } @@ -201,7 +191,6 @@ std::pair get_lm_encoded_results( update_position_ids(m_llm.get_tensor("position_ids"), m_llm.get_tensor("attention_mask")); } - m_llm.get_tensor("beam_idx").set_shape({ total_num_tokens }); m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()}); const auto infer_start = std::chrono::steady_clock::now(); @@ -213,42 +202,38 @@ std::pair get_lm_encoded_results( raw_perf_counters.m_new_token_times.emplace_back(infer_end); raw_perf_counters.m_batch_sizes.emplace_back(batch_size); - stream_generated_tokens(); - sampler_output = sampler.sample(active_sequence_groups, m_llm.get_tensor("logits")); - - active_sequence_groups.erase(std::remove_if(active_sequence_groups.begin(), - active_sequence_groups.end(), - get_active_sequence_groups), - active_sequence_groups.end()); + stream_generated_tokens(); } - // to stream last token - stream_generated_tokens(); - if (streamer_ptr) { + if (streamer_ptr) { // push streamer's cache streamer_ptr->end(); } - - size_t next_selected_beam = 0; - for (size_t i = 0; i < sequence_groups.size(); i++) { - auto request = sequence_groups[i]; - auto generation_outputs = generations[i]->read_all(); - - std::sort(generation_outputs.begin(), generation_outputs.end(), [] (const GenerationOutput& r1, const GenerationOutput& r2) { - return r1.score > r2.score; - }); - - auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, generation_outputs.size()); - for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) { - const auto& generation_output = generation_outputs[generation_output_idx]; - results.tokens.push_back(std::move(generation_output.generated_ids)); - results.scores.push_back(generation_output.score); + + for (auto& sequence_group : sequence_groups) { + auto sampling_params = sequence_group->get_sampling_parameters(); + const auto& sequences = sequence_group->get_finished_sequences(); + size_t num_outputs = std::min(sequence_group->get_sampling_parameters().num_return_sequences, sequences.size()); + + for (size_t seq_id = 0; seq_id < num_outputs; ++seq_id) { + const auto & sequence = sequences[seq_id]; + const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_probs(); + + results.tokens.push_back(sequence->get_generated_ids()); + results.scores.push_back(score); } - // next_selected_beam = sampler.last_selected_beam(request); } - return {results, next_selected_beam}; + for (SequenceGroup::Ptr sequence_group : sequence_groups) + sampler.clear_request_info(sequence_group->get_request_id()); + + // it is not saved in KV cache, we need to add it for some cases + std::optional last_token_of_best_sequence = std::nullopt; + if (sequence_groups[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH || sequence_groups[0]->handle_dropped()) + last_token_of_best_sequence = results.tokens[0].back(); + + return {results, last_token_of_best_sequence}; } } // namespace genai -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/cpp/src/lm_encoding.hpp b/src/cpp/src/lm_encoding.hpp index fa6692ede0..c31cffb9bc 100644 --- a/src/cpp/src/lm_encoding.hpp +++ b/src/cpp/src/lm_encoding.hpp @@ -8,13 +8,9 @@ namespace ov { namespace genai { -std::pair get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, - const std::shared_ptr& streamer_ptr, Sampler& sampler, std::vector sequence_groups, - std::optional position_ids, std::optional m_embedding, std::optional selected_beam_idx); - -void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector next_beams); - -void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask); +std::pair> get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, + const std::shared_ptr& streamer_ptr, Sampler& sampler, std::vector sequence_groups, + std::optional position_ids, std::optional m_embedding); } } diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp index f77463d767..9c18dc7721 100644 --- a/src/cpp/src/sampler.cpp +++ b/src/cpp/src/sampler.cpp @@ -85,75 +85,63 @@ std::string clean_wrapped_text(const std::string& wrapped_text, const std::strin return clean_text; } +std::vector encode_and_process_string(const std::string& stop_string, ov::genai::Tokenizer& tokenizer) { + // encode stop_string + std::string stop_string_copy = stop_string; + ov::Tensor ov_encoded_stop_string = tokenizer.encode(stop_string_copy, ov::genai::add_special_tokens(false)).input_ids; + size_t tensor_size = ov_encoded_stop_string.get_size(); + std::vector encoded_stop_string(tensor_size); + std::copy_n(ov_encoded_stop_string.data(), tensor_size, encoded_stop_string.begin()); + return encoded_stop_string; +} + +struct MatchStopStringResult { + size_t to_remove = 0; + // int64_t last_token_id = 0; + // bool is_to_update_last_token = false; + bool is_matched = false; +}; + // Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned. -int match_stop_string(Tokenizer & tokenizer, const TokenIds & generated_tokens, const std::set & stop_strings) { - /* - For catching stop_string hit we run comparisons character-wise to catch cases where stop string - overlaps with part of another token on both sides or is just a part of a single token. - For every stop_string we iterate over generated tokens starting from the last one and going backwards. - Every token is wrapped with prefix tokens to ensure tokenizer doesn't remove prefix whitespace of the actual token. - After that all tokens are decoded and prefix is removed from the decoded text, so we end up with decoded token. - Its characters are compared to the stop_string character at a current_position - (position of a character in the stop_string counting from the last one) - at the beginning position is 0. - When characters match we increase current_position and check if we have a full match already, if not we continue. - If we have already matched some characters (current_position > 0) and next character is not matching - before we reach the full match, then we reset current_position to 0. - */ - std::string prefix = "a"; - auto prefix_ov = tokenizer.encode(prefix).input_ids; - std::vector prefix_tokens(prefix_ov.data(), prefix_ov.data() + prefix_ov.get_size()); - std::string suffix = "b"; - auto suffix_ov = tokenizer.encode(suffix).input_ids; - std::vector suffix_tokens(suffix_ov.data(), suffix_ov.data() + suffix_ov.get_size()); - - // Since whitespace can be added at the beginning of the suffix we also try to capture that behavior here - // and get suffix string that will actually be part of the decoded string so we can remove it correctly - auto wrapped_suffix_tokens = suffix_tokens; - wrapped_suffix_tokens.insert(wrapped_suffix_tokens.begin(), prefix_tokens.begin(), prefix_tokens.end()); - std::string wrapped_suffix = tokenizer.decode(wrapped_suffix_tokens); - auto wrapper_pos = wrapped_suffix.find(prefix); - suffix = wrapped_suffix.substr(wrapper_pos + prefix.size()); - - for (auto stop_string: stop_strings) { - int current_position = 0; - int num_matched_tokens = 0; - // Getting reverse iterator to check tokens starting from the last one generated and going backwards - auto generated_tokens_rit = generated_tokens.rbegin(); - std::vector tokens_buffer; - while (generated_tokens_rit != generated_tokens.rend()) { - num_matched_tokens++; - tokens_buffer.insert(tokens_buffer.begin(), *generated_tokens_rit); - - std::vector wrapped_tokens = wrap_tokens(tokens_buffer, prefix_tokens, suffix_tokens); - std::string wrapped_text = tokenizer.decode(wrapped_tokens); - std::string clean_text = clean_wrapped_text(wrapped_text, prefix, suffix); - - if (clean_text == "" || (clean_text.size() >= 3 && (clean_text.compare(clean_text.size() - 3, 3, "�") == 0))) { - generated_tokens_rit++; - continue; - } else { - tokens_buffer.clear(); - } - // Checking clean_text characters starting from the last one - for (auto clean_text_rit = clean_text.rbegin(); clean_text_rit != clean_text.rend(); clean_text_rit++) { - // On character match increment current_position for the next comparisons - if (*clean_text_rit == *(stop_string.rbegin() + current_position)) { - current_position++; - // If this is the last character from the stop_string we have a match - if ((stop_string.rbegin() + current_position) == stop_string.rend()) { - return num_matched_tokens; - } - } else if (current_position) { - // Already found matching characters, but the last one didn't match, so we reset current_position - current_position = 0; - // Looking for the match will start over from this character so we decrement iterator - clean_text_rit--; +MatchStopStringResult match_stop_string(Tokenizer& tokenizer, + const TokenIds& generated_tokens, + const std::pair>& stop_strings, + bool is_include_to_output) { + MatchStopStringResult result; + if (generated_tokens.size() >= stop_strings.first) { + size_t offset = generated_tokens.size() - stop_strings.first; + TokenIds buffer(generated_tokens.begin() + offset, generated_tokens.end()); + std::string decoded_buffer = tokenizer.decode(buffer); + for (const auto& stop_string : stop_strings.second) { + auto pos = decoded_buffer.find(stop_string); + if (pos != std::string::npos) { + result.is_matched = true; + + auto stop_string_len = is_include_to_output ? stop_string.length() : 0; + decoded_buffer = decoded_buffer.substr(0, pos + stop_string_len); + // to remove word splitting symbols from tail + while (decoded_buffer.back() == ' ' || decoded_buffer.back() == '\n') { + decoded_buffer.pop_back(); + } + if (decoded_buffer.empty()) { + result.to_remove = buffer.size(); + return result; } + + // find token cnt to be removed from sequence by decoding token by token + std::string decoded_partially_string = ""; + for (size_t i = 0; i < buffer.size(); ++i) { + decoded_partially_string += tokenizer.decode(TokenIds{buffer[i]}); + if (decoded_partially_string.find(decoded_buffer) != std::string::npos) { + result.to_remove = buffer.size() - i - 1; + break; + } + } + return result; } - generated_tokens_rit++; } } - return 0; + return result; } // Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned. @@ -245,7 +233,9 @@ std::map Sampler::GroupBeamSearcher::get_beam_idxs() { return next_beams; } -void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output) { +void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, + SamplerOutput& sampler_output, + const std::pair>& stop_strings) { assert(m_parameters.num_beams % m_parameters.num_beam_groups == 0 && "number of beams should be divisible by number of groups"); size_t group_size = m_parameters.num_beams / m_parameters.num_beam_groups; @@ -392,19 +382,17 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, Sa // There's probably a better way to do that, than copying whole vector... std::vector token_ids = candidate.m_sequence->get_generated_ids(); token_ids.push_back(candidate.m_token_id); - int num_last_matched_tokens = match_stop_string(m_tokenizer, token_ids, m_sequence_group->get_sampling_parameters().stop_strings); - if (num_last_matched_tokens) { + auto match_result = match_stop_string(m_tokenizer, token_ids, stop_strings, m_parameters.include_stop_str_in_output); + if (match_result.is_matched) { // If beam_token does not belong to top num_beams tokens, it should not be added if (cand_idx >= group_size) continue; - if(!m_parameters.include_stop_str_in_output) { - // remove tokens that match stop_string from output (last token is not included in candidate.m_sequence at this point) - candidate.m_sequence->remove_last_tokens(num_last_matched_tokens - 1); - } + // remove tokens that match stop_string from output (last token is not included in candidate.m_sequence at this point) + candidate.m_sequence->remove_last_tokens(match_result.to_remove); // try to finish candidate - try_to_finish_candidate(group, candidate, m_parameters.include_stop_str_in_output); + try_to_finish_candidate(group, candidate); continue; } } @@ -576,10 +564,11 @@ std::vector Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen } if (!sampling_params.stop_strings.empty()) { - int num_matched_last_tokens = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), sampling_params.stop_strings); - if (num_matched_last_tokens) { - if (!sampling_params.include_stop_str_in_output) - running_sequence->remove_last_tokens(num_matched_last_tokens); + auto& stop_strings = m_stop_strings.at(sequence_group->get_request_id()); + auto match_result = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), stop_strings, sampling_params.include_stop_str_in_output); + if (match_result.is_matched) { + running_sequence->remove_last_tokens(match_result.to_remove); + running_sequence->set_status(SequenceStatus::FINISHED); running_sequence->set_finish_reason(GenerationFinishReason::STOP); dropped_seq_ids.push_back(running_sequence->get_id()); @@ -741,6 +730,19 @@ float get_p_prime(Sequence::Ptr& running_sequence, return p_prime; } +std::pair> +process_stop_strings(const std::set& stop_strings, Tokenizer& tokenizer) { + std::pair> result; + for (const auto& stop_string : stop_strings) { + auto encoded_stop_string = encode_and_process_string(stop_string, tokenizer); + if (result.first < encoded_stop_string.size()) { + result.first = encoded_stop_string.size(); + } + result.second.insert(stop_string); + } + return result; +} + SamplerOutput Sampler::sample(std::vector & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled) { @@ -764,6 +766,12 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups, if (!m_logit_processors.count(request_id)) { m_logit_processors.insert({request_id, LogitProcessor(sampling_params, sequence_group->get_prompt_ids())}); } + if (!m_stop_strings.count(request_id)) { + auto processed_stop_string = process_stop_strings(sampling_params.stop_strings, m_tokenizer); + m_stop_strings.insert({request_id, processed_stop_string}); + sequence_group->set_stream_window_size(processed_stop_string.first); + } + auto& stop_strings = m_stop_strings.at(request_id); auto& logit_processor = m_logit_processors.at(request_id); const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens; ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data); @@ -873,7 +881,7 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups, } // current algorithm already adds new tokens to running sequences and - m_beam_search_info.at(request_id).select_next_tokens(sequence_group_logits, sampler_output); + m_beam_search_info.at(request_id).select_next_tokens(sequence_group_logits, sampler_output, stop_strings); // check max length stop criteria std::vector running_sequences = sequence_group->get_running_sequences(); @@ -886,8 +894,7 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups, // Notify handle after sampling is done. // For non-streaming this is effective only when the generation is finished. OPENVINO_ASSERT(num_tokens_to_process >= max_removed_tokens_per_request); - size_t num_output_token_to_push = num_tokens_to_process - max_removed_tokens_per_request + 1; - sequence_group->notify_handle(num_output_token_to_push); + sequence_group->notify_handle(); } else { // we are in prompt processing phase when prompt is split into chunks and processed step by step } @@ -926,6 +933,7 @@ void Sampler::create_logit_processor(uint64_t request_id, const GenerationConfig void Sampler::clear_request_info(uint64_t request_id) { m_beam_search_info.erase(request_id); m_logit_processors.erase(request_id); + m_stop_strings.erase(request_id); } int64_t Sampler::GroupBeamSearcher::Group::finish(Beam beam, const ov::genai::GenerationConfig& sampling_params) { diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index 0f7876cbf9..981e11560f 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -55,8 +55,11 @@ class Sampler { std::map m_beam_search_info; std::mt19937 rng_engine; + size_t seed = rng_engine.default_seed; // { request_id, logit_processor } std::map m_logit_processors; + // { request_id, { max_encoded_len, { stop_strings }}} + std::map>> m_stop_strings; Tokenizer m_tokenizer; @@ -65,7 +68,11 @@ class Sampler { Sampler(Tokenizer & tokenizer) : m_tokenizer(tokenizer) {}; SamplerOutput sample(std::vector & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false); - void set_seed(size_t seed) { rng_engine.seed(seed); } + void set_seed(size_t new_seed) { + rng_engine.seed(new_seed); + seed = new_seed; + } + size_t get_seed() { return seed; } void clear_request_info(uint64_t request_id); @@ -115,7 +122,7 @@ class Sampler::GroupBeamSearcher { public: explicit GroupBeamSearcher(SequenceGroup::Ptr sequence_group, Tokenizer tokenizer); - void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output); + void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output, const std::pair>& stop_strings); void finalize(SamplerOutput& sampler_output); std::map get_beam_idxs(); }; diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index 6755255fe8..220e93c032 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -126,23 +126,28 @@ class Sequence { } } - GenerationOutput get_last_generation_output(size_t token_cnt = 1) { + GenerationOutput get_last_generation_output(size_t token_cnt = 1, size_t num_token_to_ignore = 0) { GenerationOutput output; - OPENVINO_ASSERT(m_generated_ids.size()); - output.score = get_cumulative_log_probs(); + if (token_cnt > 0) { + OPENVINO_ASSERT(m_generated_ids.size()); + output.score = get_cumulative_log_probs(); - auto generated_token_id = get_generated_ids(); - auto generated_log_probs = get_generated_log_probs(); + auto generated_token_id = get_generated_ids(); + auto generated_log_probs = get_generated_log_probs(); - OPENVINO_ASSERT(get_generated_len() >= token_cnt); - auto offset = get_generated_len() - token_cnt; + OPENVINO_ASSERT(get_generated_len() >= token_cnt); + if (get_generated_len() > num_token_to_ignore) { + auto offset = get_generated_len() - token_cnt - num_token_to_ignore; + auto offset_back = get_generated_len() - num_token_to_ignore; - std::vector token_id(generated_token_id.begin() + offset, generated_token_id.end()); - std::vector log_probs(generated_log_probs.begin() + offset, generated_log_probs.end()); + std::vector token_id(generated_token_id.begin() + offset, generated_token_id.begin() + offset_back); + std::vector log_probs(generated_log_probs.begin() + offset, generated_log_probs.begin() + offset_back); - output.generated_ids = token_id; - output.generated_log_probs = log_probs; - output.finish_reason = get_finish_reason(); + output.generated_ids = token_id; + output.generated_log_probs = log_probs; + output.finish_reason = get_finish_reason(); + } + } return output; } @@ -173,8 +178,6 @@ class Sequence { return score; } - - // Each KV block can be uniquely identified by void set_sequence_group_ptr(std::shared_ptr sequence_group) { m_sequence_group = sequence_group; @@ -221,6 +224,8 @@ class SequenceGroup { // flag to enable/disable token generation, e.g. in speculative decoding scenario bool m_is_gen_paused = false; + size_t m_num_streamed_tokens = 0, m_stream_window_size = 0; + SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching) : m_request_id(request_id), @@ -332,14 +337,16 @@ class SequenceGroup { std::vector get_finished_sequences() const { std::vector finished_seqs; for (size_t seq_id = 0; seq_id < m_sequences.size(); ++seq_id) { - if (m_sequences[seq_id]->has_finished() || m_sequences[seq_id]->out_of_memory()) { + if (m_sequences[seq_id]->has_finished() || m_sequences[seq_id]->out_of_memory() || handle_dropped()) { finished_seqs.push_back(m_sequences[seq_id]); } } - // do we need to sort sequences here or sampler can handle it for us? - std::sort(finished_seqs.begin(), finished_seqs.end(), [=] (Sequence::CPtr s1, Sequence::CPtr s2) { - return s1->get_beam_search_score(m_sampling_params) > s2->get_beam_search_score(m_sampling_params); + std::sort(finished_seqs.begin(), finished_seqs.end(), [=] (Sequence::CPtr s1, Sequence::CPtr s2) -> bool { + bool is_beam_search = m_sampling_params.is_beam_search(); + const float score_1 = is_beam_search ? s1->get_beam_search_score(m_sampling_params) : s1->get_cumulative_log_probs(); + const float score_2 = is_beam_search ? s2->get_beam_search_score(m_sampling_params) : s2->get_cumulative_log_probs(); + return score_1 > score_2; }); return finished_seqs; @@ -454,6 +461,10 @@ class SequenceGroup { size_t get_num_tokens_to_validate() { return m_num_validation_tokens; } + + void set_stream_window_size(size_t k) { + m_stream_window_size = k; + } size_t get_num_available_tokens_for_batching() const { OPENVINO_ASSERT(!has_finished(), "Internal error: this function cannot be called on finished sequence group"); @@ -571,7 +582,7 @@ class SequenceGroup { m_generation_stream->set_generation_status(status); } - bool handle_dropped() { + bool handle_dropped() const { return m_generation_stream->get_status() == GenerationStatus::DROPPED_BY_HANDLE; } @@ -601,7 +612,7 @@ class SequenceGroup { for (auto& sequence : m_sequences) { // todo: check seq.is_finished() to generate without several // or is it ok to use padding? - auto output = sequence->get_last_generation_output(token_cnt); + auto output = sequence->get_last_generation_output(token_cnt, m_stream_window_size); if (m_sampling_params.echo && !m_has_echoed) { output.generated_ids.insert(output.generated_ids.begin(), m_prompt_ids.begin(), m_prompt_ids.end()); output.generated_log_probs.insert(output.generated_log_probs.begin(), m_prompt_log_probs.begin(), m_prompt_log_probs.end()); @@ -612,24 +623,36 @@ class SequenceGroup { m_generation_stream->push(std::move(outputs)); } - void notify_handle(size_t num_output_token_to_push = 0) { + void notify_handle() { if (out_of_memory()) { set_generation_status(GenerationStatus::IGNORED); } else if (has_finished()) { set_generation_status(GenerationStatus::FINISHED); } // For beam search streaming is not available, so we notify only upon finishing - if(m_sampling_params.is_beam_search()) { + if (m_sampling_params.is_beam_search()) { if (has_finished() || out_of_memory()) { push_outputs(); } } else if (m_sampling_params.is_greedy_decoding() || m_sampling_params.is_multinomial()) { // We can stream only when one sequence is returned and we don't use stop strings that would be excluded from the output // (after stop string is detected its tokens are already sent) - if (num_total_seqs() == 1 && - (m_sampling_params.stop_strings.empty() || m_sampling_params.include_stop_str_in_output)) { - if (num_output_token_to_push) - push_partial_outputs(num_output_token_to_push); + if (num_total_seqs() == 1) { + const auto generated_len = m_sequences.front()->get_generated_len(); + if (has_finished()) { + m_stream_window_size = 0; + } + if (generated_len <= (m_num_streamed_tokens + m_stream_window_size)) { + return; + } + // speculative decoding draft handling + if (generated_len < m_num_streamed_tokens) { + m_num_streamed_tokens = generated_len; + } + OPENVINO_ASSERT(generated_len >= (m_num_streamed_tokens + m_stream_window_size)); + size_t num_output_token_to_push = generated_len - m_num_streamed_tokens - m_stream_window_size; + push_partial_outputs(num_output_token_to_push); + m_num_streamed_tokens += (num_output_token_to_push); } else if (has_finished() || out_of_memory()) { push_outputs(); } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 4a0748b5c0..46b7b106a6 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -46,14 +46,15 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con draft_scheduler_config = is_scheduler_undefined ? main_scheduler_config : draft_model_desc.scheduler_config; if (is_scheduler_undefined) { // split KV cache to 2 caches for main and draft models - size_t main_model_cache_size = utils::get_kv_cache_size(main_model), - draft_model_cache_size = utils::get_kv_cache_size(draft_model); - auto k = static_cast(draft_model_cache_size) / (main_model_cache_size + draft_model_cache_size); + size_t main_model_hidden_size = utils::get_hidden_size(main_model), + draft_model_hidden_size = utils::get_hidden_size(draft_model); + auto k = static_cast(draft_model_hidden_size) / (main_model_hidden_size + draft_model_hidden_size); - size_t main_cache_size = main_scheduler_config.cache_size * (1 - k), + size_t main_cache_size = std::ceil(main_scheduler_config.cache_size * (1.f - k)), draft_cache_size = main_scheduler_config.cache_size - main_cache_size; + OPENVINO_ASSERT(main_cache_size > 0, "KV cache model cache size should be > 0"); if (draft_cache_size == 0) { - main_cache_size -= main_cache_size > 1 ? 1 : 0; + main_cache_size -= (main_cache_size > 1 ? 1 : 0); draft_cache_size = 1; } @@ -63,7 +64,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con ov::AnyMap draft_properties = draft_model_desc.properties == ov::AnyMap{} ? compile_properties : draft_model_desc.properties; - DeviceConfig main_device_config(core, main_scheduler_config, main_device, compile_properties), + DeviceConfig main_device_config(core, main_scheduler_config_updated, main_device, compile_properties), draft_device_config(core, draft_scheduler_config, draft_device, draft_properties); utils::set_kv_cache_type_and_shape(main_model, main_device_config); @@ -82,7 +83,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con // to create `main_pipeline` with enabled validation_mode and `draft_pipeline` with disabled validation mode m_main_pipeline = std::make_shared(core, main_model, main_model_tokenizer, main_model_desc.generation_config, - main_device_config, main_scheduler_config, main_device, compile_properties, true); + main_device_config, main_scheduler_config_updated, main_device, compile_properties, true); m_draft_pipeline = std::make_shared(core, draft_model, draft_model_tokenizer, draft_model_desc.generation_config, draft_device_config, draft_scheduler_config, draft_device, draft_properties, false); @@ -278,4 +279,4 @@ SpeculativeDecodingMetrics ContinuousBatchingPipeline::SpeculativeDecodingImpl::get_speculative_decoding_metrics() { return m_sd_metrics; }; -} \ No newline at end of file +} diff --git a/src/cpp/src/text_callback_streamer.cpp b/src/cpp/src/text_callback_streamer.cpp index 314a7ffa4d..5938b55f6c 100644 --- a/src/cpp/src/text_callback_streamer.cpp +++ b/src/cpp/src/text_callback_streamer.cpp @@ -52,4 +52,4 @@ void TextCallbackStreamer::end() { ov::genai::StreamerBase::~StreamerBase() = default; } // namespace genai -} // namespace ov +} // namespace ov \ No newline at end of file diff --git a/src/cpp/src/text_callback_streamer.hpp b/src/cpp/src/text_callback_streamer.hpp index a03b0deccb..6f0872ad1b 100644 --- a/src/cpp/src/text_callback_streamer.hpp +++ b/src/cpp/src/text_callback_streamer.hpp @@ -25,4 +25,4 @@ class TextCallbackStreamer: public StreamerBase { }; } // namespace genai -} // namespace ov +} // namespace ov \ No newline at end of file diff --git a/src/cpp/src/tokenizer.cpp b/src/cpp/src/tokenizer.cpp index 642236d32a..ed6fbc0a06 100644 --- a/src/cpp/src/tokenizer.cpp +++ b/src/cpp/src/tokenizer.cpp @@ -194,10 +194,16 @@ class Tokenizer::TokenizerImpl { void setupTokenizer(const std::pair, std::shared_ptr>& models, const ov::AnyMap& properties) { auto [ov_tokenizer, ov_detokenizer] = models; + OPENVINO_ASSERT(ov_tokenizer || ov_detokenizer, "Neither tokenizer nor detokenzier models were provided"); - m_older_than_24_5 = ov_tokenizer->get_rt_info().count("openvino_tokenizers_version") != 1; auto core = get_core_singleton(); std::string device = "CPU"; // only CPU is supported for now + + std::string version_str; + utils::read_rt_info(ov_tokenizer != nullptr ? ov_tokenizer: ov_detokenizer , "openvino_tokenizers_version", version_str); + // Saving IR version was added only in 24.5, so if it's empty, then it's older than 24.5 + m_older_than_24_5 = version_str.empty(); + if (ov_tokenizer) { ov::pass::Manager manager; manager.register_pass(); @@ -230,7 +236,8 @@ class Tokenizer::TokenizerImpl { if (m_tokenizer) { // TODO CVS-150630: Empty strings sporadically can fail, therefore use nonempty string for warmup. encode("non empty string").input_ids; - if (m_detokenizer) + } + if (m_detokenizer) { decode({1, 33, 199, 42, 42}); } @@ -377,6 +384,9 @@ class Tokenizer::TokenizerImpl { } TokenizedInputs encode(std::string prompt, const ov::AnyMap& tokenization_params = {}) { + OPENVINO_ASSERT(m_ireq_queue_tokenizer, "Either openvino_tokenizer.xml was not provided or it was not loaded correctly. " + "Tokenizer::encode is not available"); + CircularBufferQueueElementGuard infer_request_guard(this->m_ireq_queue_tokenizer.get()); set_state_if_necessary(infer_request_guard, tokenization_params); size_t batch_size = 1; @@ -390,6 +400,8 @@ class Tokenizer::TokenizerImpl { } TokenizedInputs encode(std::vector& prompts, const ov::AnyMap& tokenization_params = {}) { + OPENVINO_ASSERT(m_ireq_queue_tokenizer, "Either openvino_tokenizer.xml was not provided or it was not loaded correctly. " + "Tokenizer::encode is not available"); TokenizedInputs unpadded; { CircularBufferQueueElementGuard infer_request_guard(this->m_ireq_queue_tokenizer.get()); diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index 9fa14b7f9f..be9fc972dc 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -381,6 +381,14 @@ void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t se } } +ov::Tensor push_front_inputs(const ov::Tensor& base_tensor, int64_t add_to_front) { + ov::Tensor new_tensor = ov::Tensor{ov::element::i64, {base_tensor.get_shape().at(0), base_tensor.get_shape().at(1) + 1}}; + auto new_tensor_data = new_tensor.data(); + new_tensor_data[0] = add_to_front; + std::copy_n(base_tensor.data(), base_tensor.get_size(), new_tensor_data + 1); + return new_tensor; +} + void print_compiled_model_properties(ov::CompiledModel& compiled_Model, const char* model_title) { // Specify the name of the environment variable const char* env_var_name = "OPENVINO_LOG_LEVEL"; diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index 5342ac427c..57225e60ff 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -28,6 +28,21 @@ enum class GenerationChatInputsType { ENCODED_INPUTS = 2, // Type of inputs is EncodedInputs }; +struct HistoryRemoveManager +{ + size_t num_tokens_to_remove_from_kv_cache = 0; + size_t trusted_history_length = 0; + + bool does_kv_cache_need_to_update() { + return (trusted_history_length > 0 || num_tokens_to_remove_from_kv_cache > 0); + } + + void reset() { + num_tokens_to_remove_from_kv_cache = 0; + trusted_history_length = 0; + } +}; + Tensor init_attention_mask(const Tensor& position_ids); void print_tensor(const ov::Tensor& tensor); @@ -104,6 +119,8 @@ size_t get_seq_len_axis(std::shared_ptr model); void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional adapter_controller); +ov::Tensor push_front_inputs(const ov::Tensor& base_tensor, int64_t add_to_front); + void print_compiled_model_properties(ov::CompiledModel& compiled_Model, const char* model_title); } // namespace utils diff --git a/src/cpp/src/utils/paged_attention_transformations.cpp b/src/cpp/src/utils/paged_attention_transformations.cpp index 53690f770c..16c9556151 100644 --- a/src/cpp/src/utils/paged_attention_transformations.cpp +++ b/src/cpp/src/utils/paged_attention_transformations.cpp @@ -16,7 +16,7 @@ inline ov::PartialShape to_partial_with_dyn_0_dim(const ov::Shape& static_shape) return partial_shape; } -size_t get_kv_cache_size(const std::shared_ptr model) { +size_t get_hidden_size(const std::shared_ptr model) { const auto& parameters = model->get_parameters(); // extract num_kv_heads and head_size size_t kv_caches_inputs_offset = 2; diff --git a/src/cpp/src/utils/paged_attention_transformations.hpp b/src/cpp/src/utils/paged_attention_transformations.hpp index 3bc423d7bc..88ac0876c5 100644 --- a/src/cpp/src/utils/paged_attention_transformations.hpp +++ b/src/cpp/src/utils/paged_attention_transformations.hpp @@ -23,7 +23,7 @@ void apply_paged_attention_transformations(std::shared_ptr model, Dev void apply_paged_attention_transformations(std::shared_ptr model, bool per_layer_cache_control = false); -size_t get_kv_cache_size(const std::shared_ptr model); +size_t get_hidden_size(const std::shared_ptr model); void set_kv_cache_type_and_shape(std::shared_ptr model, DeviceConfig& device_config); diff --git a/src/cpp/src/visual_language/inputs_embedder.cpp b/src/cpp/src/visual_language/inputs_embedder.cpp index cf77dfce3c..e53be4e1cd 100644 --- a/src/cpp/src/visual_language/inputs_embedder.cpp +++ b/src/cpp/src/visual_language/inputs_embedder.cpp @@ -10,6 +10,7 @@ #include "utils.hpp" + namespace { constexpr size_t BATCH_SIZE = 1; @@ -40,10 +41,13 @@ class InputsEmbedder::IInputsEmbedder { // Templated chat history std::string m_templated_chat_history; // Tokenized chat history - std::vector m_tokenized_chat_history; - // The number of elements, which need to remove from the end of KV cache - // removed elements will be added to inputs_ids - size_t m_to_remove_from_hist = 0; + std::vector m_tokenized_history; + // Tail of previous output for LM in chat mode is missing in KV cache. + std::optional m_last_disappeared_token = std::nullopt; + // If sequence contains some symbols, which could be ambiguous encoded by tokenizer, we need to trim kv cache + // If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history + // so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history + ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0}; public: virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector& images, ov::genai::VLMPerfMetrics& metrics) = 0; @@ -56,26 +60,34 @@ class InputsEmbedder::IInputsEmbedder { return m_tokenizer; } - std::vector get_tokenized_chat_history() const { - return m_tokenized_chat_history; + std::vector get_tokenized_history() const { + return m_tokenized_history; } - size_t get_amount_to_remove_from_hist() const { - return m_to_remove_from_hist; + size_t get_num_tokens_to_remove_from_hist() const { + return m_kv_history_manager.num_tokens_to_remove_from_kv_cache; } - void update_tokenized_chat_history(std::vector encoded_result) { - std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_chat_history)); - m_to_remove_from_hist = 0; + void update_tokenized_history(const std::vector& encoded_result, std::optional last_disappeared_token, bool is_beam_search, size_t last_answer_len) { + if (is_beam_search) { + m_kv_history_manager.trusted_history_length = m_tokenized_history.size(); + m_kv_history_manager.num_tokens_to_remove_from_kv_cache = last_answer_len; + } else { + m_kv_history_manager.reset(); + } + + m_last_disappeared_token = last_disappeared_token; + + std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history)); } virtual void start_chat(const std::string& system_message) { m_is_chat_conversation = true; - m_to_remove_from_hist = 0; - if (!m_tokenized_chat_history.empty()) { + m_kv_history_manager.reset(); + if (!m_tokenized_history.empty()) { m_history.clear(); m_templated_chat_history.clear(); - m_tokenized_chat_history.clear(); + m_tokenized_history.clear(); } if (system_message.empty()) { return; @@ -94,11 +106,11 @@ class InputsEmbedder::IInputsEmbedder { virtual void finish_chat() { m_is_chat_conversation = false; - m_to_remove_from_hist = 0; + m_kv_history_manager.reset(); m_history.clear(); m_templated_chat_history.clear(); - m_tokenized_chat_history.clear(); + m_tokenized_history.clear(); } protected: @@ -164,38 +176,55 @@ class InputsEmbedder::IInputsEmbedder { // some symbols combinations can be encoded by the tokenizer in different ways // if we met sequence with such combination of symbols, we cannot correctly subtract the new history from the old history // so let's check it out, find the trusted part and use it in on the next step - size_t last_same_hist_token = 0; - if (!m_tokenized_chat_history.empty()) { + size_t trusted_history_length = 0; + if (!m_tokenized_history.empty()) { std::set stop_tokens = {m_tokenizer.get_eos_token_id()}; - last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens); + trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_history, stop_tokens); } - if (m_tokenized_chat_history.empty()) { + if (m_tokenized_history.empty()) { encoded_input_ids = new_chat_tokens; - } else if (last_same_hist_token != SIZE_MAX) { - m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token; + + } else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) { + // does_kv_cache_need_to_update will be true here if beam search is activated + // in beam search mode we want to remove all history about last model answer from kv cache and add the best answer directly + // if we have difference in model answer and decoded answer it anyway will be less then entire history, so let's use data from m_kv_history_manager + if (m_kv_history_manager.does_kv_cache_need_to_update()) { + trusted_history_length = m_kv_history_manager.trusted_history_length; + } else { + m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_history.size() - trusted_history_length; + // if prev generation was finished because of max len was reached, kv cache is missed one last token, let's keep it + m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0; + } ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.get_element_type(), - {1, new_chat_tokens.get_shape().at(1) - last_same_hist_token}, - new_chat_tokens.data() + last_same_hist_token); - encoded_input_ids = new_tensor; + {1, new_chat_tokens.get_shape().at(1) - trusted_history_length}, + new_chat_tokens.data() + trusted_history_length); + encoded_input_ids = ov::Tensor(new_chat_tokens.get_element_type(), + {1, new_chat_tokens.get_shape().at(1) - trusted_history_length}); + new_tensor.copy_to(encoded_input_ids); } else { encoded_input_ids = utils::subtract_chat_tokenized_inputs( {new_chat_tokens}, prev_chat_tokens ).input_ids; + + if (m_last_disappeared_token.has_value()) + encoded_input_ids = ov::genai::utils::push_front_inputs(encoded_input_ids, *m_last_disappeared_token); } auto end_tokenizer_time = std::chrono::steady_clock::now(); metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time)); m_templated_chat_history = std::move(new_templated_chat_history); - m_tokenized_chat_history.clear(); - std::copy(new_chat_tokens.data(), new_chat_tokens.data() + new_chat_tokens.get_size(), - std::back_inserter(m_tokenized_chat_history)); + m_tokenized_history.clear(); + std::copy_n(new_chat_tokens.data(), new_chat_tokens.get_size(), std::back_inserter(m_tokenized_history)); } else { auto start_tokenizer_time = std::chrono::steady_clock::now(); encoded_input_ids = m_tokenizer.encode(prompt).input_ids; auto end_tokenizer_time = std::chrono::steady_clock::now(); metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time)); + m_tokenized_history.clear(); + std::copy_n(encoded_input_ids.data(), encoded_input_ids.get_size(), std::back_inserter(m_tokenized_history)); } + return encoded_input_ids; } @@ -1172,16 +1201,16 @@ EmbeddingsModel InputsEmbedder::get_embedding_model() const { return m_impl->get_embedding_model(); } -std::vector InputsEmbedder::get_tokenized_chat_history() const { - return m_impl->get_tokenized_chat_history(); +std::vector InputsEmbedder::get_tokenized_history() const { + return m_impl->get_tokenized_history(); } -void InputsEmbedder::update_tokenized_chat_history(std::vector encoded_result) { - return m_impl->update_tokenized_chat_history(encoded_result); +void InputsEmbedder::update_tokenized_history(const std::vector& encoded_result, std::optional last_disappeared_token, bool is_beam_search, size_t last_answer_len) { + return m_impl->update_tokenized_history(encoded_result, last_disappeared_token, is_beam_search, last_answer_len); } -size_t InputsEmbedder::get_amount_to_remove_from_hist() const { - return m_impl->get_amount_to_remove_from_hist(); +size_t InputsEmbedder::get_num_tokens_to_remove_from_hist() const { + return m_impl->get_num_tokens_to_remove_from_hist(); } Tokenizer InputsEmbedder::get_tokenizer() const { diff --git a/src/cpp/src/visual_language/inputs_embedder.hpp b/src/cpp/src/visual_language/inputs_embedder.hpp index 5c5b9d2b81..1d72b742ab 100644 --- a/src/cpp/src/visual_language/inputs_embedder.hpp +++ b/src/cpp/src/visual_language/inputs_embedder.hpp @@ -41,16 +41,20 @@ class InputsEmbedder { Tokenizer get_tokenizer() const; // returns tokenized chat history - std::vector get_tokenized_chat_history() const; - // add new results to tokenized chat history - void update_tokenized_chat_history(std::vector encoded_result); + std::vector get_tokenized_history() const; + + // add new results to tokenized history + void update_tokenized_history(const std::vector& encoded_result, std::optional last_disappeared_token, bool is_beam_search, size_t last_answer_len); + // returns amount of elements, which need to remove from the end of the KV cache - size_t get_amount_to_remove_from_hist() const; + size_t get_num_tokens_to_remove_from_hist() const; // starts chat and adds optional system_message to chat history void start_chat(const std::string& system_message); + // adds currently generated text to chat history void update_chat_history(const std::string& decoded_results); + // finishes chat and clears a chat history void finish_chat(); private: diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp index 1ce0cbf210..d625485205 100644 --- a/src/cpp/src/visual_language/pipeline.cpp +++ b/src/cpp/src/visual_language/pipeline.cpp @@ -17,6 +17,7 @@ #include "utils.hpp" #include "lm_encoding.hpp" + using namespace ov::genai; namespace { @@ -66,6 +67,8 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { float m_load_time_ms = 0; // Axis num in kv cache from m_language model, which contains information about history len size_t m_kv_cache_seq_length_axis = 2; + // Component for applying sampling to lm outputs + Sampler m_sampler; VLMPipelineImpl( const std::filesystem::path& models_dir, @@ -104,6 +107,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { if (m_generation_config.eos_token_id == -1) { m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id()); } + + m_sampler = Sampler(m_tokenizer); + m_sampler.set_seed(m_generation_config.rng_seed); } VLMPipelineImpl( @@ -139,6 +145,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { if (m_generation_config.eos_token_id == -1) { m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id()); } + + m_sampler = Sampler(m_tokenizer); + m_sampler.set_seed(m_generation_config.rng_seed); } VLMDecodedResults generate( @@ -160,22 +169,21 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { ov::Tensor inputs_embeds = m_inputs_embedder->get_inputs_embeds(prompt, rgbs, perf_metrics); auto end_get_inputs_embeds = std::chrono::steady_clock::now(); - auto to_remove_from_hist = m_inputs_embedder->get_amount_to_remove_from_hist(); + auto to_remove_from_hist = m_inputs_embedder->get_num_tokens_to_remove_from_hist(); ov::genai::utils::trim_kv_cache(m_language, to_remove_from_hist, m_kv_cache_seq_length_axis, std::nullopt); - Sampler sampler = Sampler(m_tokenizer); - std::vector requests; size_t request_id = 0; size_t block_size = 1; // not used bool enable_prefix_caching = false; - auto tokenized_chat_history = m_inputs_embedder->get_tokenized_chat_history(); size_t history_size = m_language.get_tensor("attention_mask").get_shape().at(1) - to_remove_from_hist; size_t inputs_embeds_size = inputs_embeds.get_shape().at(1); + auto tokenized_history = m_inputs_embedder->get_tokenized_history(); ov::Tensor prompt_ids(ov::element::i64, { history_size + inputs_embeds_size }); - std::fill_n(prompt_ids.data(), prompt_ids.get_size(), 0); + std::fill_n(prompt_ids.data(), prompt_ids.get_size(), m_tokenizer.get_pad_token_id()); + std::copy(tokenized_history.begin(), tokenized_history.end(), prompt_ids.data()); SequenceGroup::Ptr sequence_group = std::make_shared(request_id, prompt_ids, generation_config, block_size, enable_prefix_caching); sequence_group->set_sequence_group_ptr(sequence_group); @@ -195,8 +203,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { }, }, streamer); - OPENVINO_ASSERT((generation_config.is_greedy_decoding() || generation_config.is_multinomial() || !streamer_ptr), - "Currently streaming is possible only for greedy or multinomial decoding"); + OPENVINO_ASSERT(streamer_ptr == nullptr || generation_config.num_return_sequences == 1 && + (generation_config.is_greedy_decoding() || generation_config.is_multinomial()), + "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, { 1, history_size + inputs_embeds_size }}; std::fill_n(new_atten_mask.data(), new_atten_mask.get_size(), 1); @@ -204,10 +213,14 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { ov::Tensor position_ids = ov::Tensor{ov::element::i64, { 1, inputs_embeds_size }}; std::iota(position_ids.data(), position_ids.data() + position_ids.get_size(), history_size); + if (m_sampler.get_seed() != generation_config.rng_seed) { + m_sampler.set_seed(generation_config.rng_seed); + } + ov::genai::EncodedResults encoded_result; - int32_t m_selected_beam = 0; - std::tie(encoded_result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, sampler, requests, - position_ids, m_embedding, std::nullopt); + std::optional last_disappeared_token; + std::tie(encoded_result, last_disappeared_token) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests, + position_ids, m_embedding); auto decode_start_time = std::chrono::steady_clock::now(); VLMDecodedResults decoded; @@ -217,6 +230,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { } auto decode_end_time = std::chrono::steady_clock::now(); + m_inputs_embedder->update_tokenized_history(encoded_result.tokens[0], last_disappeared_token, generation_config.is_beam_search(), + m_language.get_tensor("attention_mask").get_shape()[1] - (history_size + inputs_embeds_size)); + std::string decoded_results = decoded.texts.at(0); if (m_is_chat_conversation) { m_inputs_embedder->update_chat_history(decoded_results); @@ -243,8 +259,6 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { decoded.perf_metrics.m_evaluated = false; decoded.perf_metrics.evaluate_statistics(generate_start_time); - m_inputs_embedder->update_tokenized_chat_history(encoded_result.tokens[0]); - return decoded; } diff --git a/src/cpp/src/whisper/context_tokens.cpp b/src/cpp/src/whisper/context_tokens.cpp new file mode 100644 index 0000000000..75ee442551 --- /dev/null +++ b/src/cpp/src/whisper/context_tokens.cpp @@ -0,0 +1,89 @@ +// Copyright (C) 2023-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "context_tokens.hpp" + +namespace { +std::pair, float> tokenize(std::string&& text, + const ov::genai::WhisperGenerationConfig& config, + ov::genai::Tokenizer& tokenizer) { + if (text.empty()) { + return {{}, 0.0f}; + } + + auto start_time = std::chrono::steady_clock::now(); + auto encoded = tokenizer.encode(text, ov::genai::add_special_tokens(false)); + auto duration = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - start_time); + + auto input_ids = encoded.input_ids; + auto input_ids_data = input_ids.data(); + + std::vector prompt_tokens; + prompt_tokens.reserve(input_ids.get_size()); + + // even with ov::genai::add_special_tokens(false) tokenizer adds next special tokens. Ticket: 159569 + std::set special_tokens{config.decoder_start_token_id, config.eos_token_id, config.no_timestamps_token_id}; + + for (size_t i = 0; i < input_ids.get_size(); i++) { + if (special_tokens.count(input_ids_data[i])) { + continue; + } + + prompt_tokens.emplace_back(input_ids_data[i]); + } + + return {prompt_tokens, duration}; +} +} // namespace + +namespace ov { +namespace genai { + +std::pair prepare_context_tokens(const WhisperGenerationConfig& config, + Tokenizer& tokenizer) { + WhisperContextTokens context_tokens; + float duration = 0.0f; + + if (config.initial_prompt.has_value()) { + auto [initial_prompt_tokens, initial_prompt_duration] = + tokenize(" " + *config.initial_prompt, config, tokenizer); + context_tokens.initial_prompt = std::move(initial_prompt_tokens); + duration += initial_prompt_duration; + } + + if (config.hotwords.has_value()) { + auto [hotwords_tokens, hotwords_duration] = tokenize(" " + *config.hotwords, config, tokenizer); + context_tokens.hotwords = std::move(hotwords_tokens); + duration += hotwords_duration; + } + + return {context_tokens, duration}; +} + +std::vector get_prompt_tokens(const WhisperContextTokens& context_tokens, + const WhisperGenerationConfig& config, + size_t chunk_offset) { + bool should_add_initial_prompt = !context_tokens.initial_prompt.empty() && chunk_offset == 0; + bool should_add_hotwords = !context_tokens.hotwords.empty(); + + if (!should_add_initial_prompt && !should_add_hotwords) { + return {}; + } + + std::vector prompt_tokens{config.prev_sot_token_id}; + + if (should_add_initial_prompt) { + prompt_tokens.insert(prompt_tokens.end(), + context_tokens.initial_prompt.begin(), + context_tokens.initial_prompt.end()); + } + + if (should_add_hotwords) { + prompt_tokens.insert(prompt_tokens.end(), context_tokens.hotwords.begin(), context_tokens.hotwords.end()); + } + + return prompt_tokens; +} + +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/whisper/context_tokens.hpp b/src/cpp/src/whisper/context_tokens.hpp new file mode 100644 index 0000000000..0042ba8136 --- /dev/null +++ b/src/cpp/src/whisper/context_tokens.hpp @@ -0,0 +1,25 @@ +// Copyright (C) 2023-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "openvino/genai/perf_metrics.hpp" +#include "openvino/genai/whisper_generation_config.hpp" + +namespace ov { +namespace genai { + +struct WhisperContextTokens { + std::vector initial_prompt; + std::vector hotwords; +}; + +std::pair prepare_context_tokens(const WhisperGenerationConfig& config, + Tokenizer& tokenizer); + +std::vector get_prompt_tokens(const WhisperContextTokens& context_tokens, + const WhisperGenerationConfig& config, + size_t chunk_offset); + +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/whisper/whisper.cpp b/src/cpp/src/whisper/whisper.cpp index 355ccc619b..9d6aa698ce 100644 --- a/src/cpp/src/whisper/whisper.cpp +++ b/src/cpp/src/whisper/whisper.cpp @@ -8,6 +8,7 @@ #include #include +#include "context_tokens.hpp" #include "logit_processor.hpp" #include "openvino/genai/perf_metrics.hpp" #include "openvino/genai/whisper_generation_config.hpp" @@ -175,11 +176,11 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state, return output_token; } -std::vector prepare_init_ids(ov::Tensor& encoder_hidden_state, - ov::InferRequest decoder, - const ov::genai::WhisperGenerationConfig& config, - const bool return_timestamps, - ov::genai::RawPerfMetrics& raw_metrics) { +std::vector prepare_init_tokens(ov::Tensor& encoder_hidden_state, + ov::InferRequest decoder, + const ov::genai::WhisperGenerationConfig& config, + const bool return_timestamps, + ov::genai::RawPerfMetrics& raw_metrics) { if (!config.is_multilingual) { if (return_timestamps) { return std::vector{config.decoder_start_token_id}; @@ -290,6 +291,7 @@ namespace genai { WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& config, const ov::genai::WhisperConfig& model_config, + const WhisperContextTokens& context_tokens, const RawSpeechInput& raw_speech, ov::genai::WhisperInitializedModels& models, WhisperFeatureExtractor& feature_extractor, @@ -313,7 +315,7 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& // long-form audio processing requires timestamps to be enabled const bool return_timestamps = config.return_timestamps || !is_shortform; - std::vector init_ids; + std::vector init_tokens; std::vector& output_tokens = result.output_tokens; std::vector segments; @@ -335,14 +337,18 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& raw_metrics); // prepare init_ids just once for whole input - if (init_ids.empty()) { - init_ids = prepare_init_ids(hidden_state_tensor, models.decoder, config, return_timestamps, raw_metrics); + if (init_tokens.empty()) { + init_tokens = + prepare_init_tokens(hidden_state_tensor, models.decoder, config, return_timestamps, raw_metrics); } + std::vector chunk_init_tokens = ov::genai::get_prompt_tokens(context_tokens, config, chunk_offset); + chunk_init_tokens.insert(chunk_init_tokens.end(), init_tokens.begin(), init_tokens.end()); + auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor, config, models, - init_ids, + chunk_init_tokens, max_new_tokens - output_tokens.size(), return_timestamps, raw_metrics, diff --git a/src/cpp/src/whisper/whisper.hpp b/src/cpp/src/whisper/whisper.hpp index 4904edf925..81f559db9f 100644 --- a/src/cpp/src/whisper/whisper.hpp +++ b/src/cpp/src/whisper/whisper.hpp @@ -5,6 +5,7 @@ #include +#include "context_tokens.hpp" #include "openvino/genai/whisper_generation_config.hpp" #include "openvino/genai/whisper_pipeline.hpp" #include "whisper_config.hpp" @@ -28,6 +29,7 @@ struct WhisperGenerateResult { WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& config, const ov::genai::WhisperConfig& model_config, + const WhisperContextTokens& context_tokens, const ov::genai::RawSpeechInput& raw_speech, ov::genai::WhisperInitializedModels& models, ov::genai::WhisperFeatureExtractor& feature_extractor, diff --git a/src/cpp/src/whisper_generation_config.cpp b/src/cpp/src/whisper_generation_config.cpp index 0fba4e962f..beb663caaf 100644 --- a/src/cpp/src/whisper_generation_config.cpp +++ b/src/cpp/src/whisper_generation_config.cpp @@ -8,8 +8,8 @@ #include #include -#include "utils.hpp" #include "json_utils.hpp" +#include "utils.hpp" namespace ov { namespace genai { @@ -31,6 +31,7 @@ WhisperGenerationConfig::WhisperGenerationConfig(const std::filesystem::path& js read_json_param(data, "pad_token_id", pad_token_id); read_json_param(data, "no_timestamps_token_id", no_timestamps_token_id); read_json_param(data, "max_initial_timestamp_index", max_initial_timestamp_index); + read_json_param(data, "prev_sot_token_id", prev_sot_token_id); read_json_param(data, "is_multilingual", is_multilingual); if (is_multilingual) { @@ -73,6 +74,8 @@ void WhisperGenerationConfig::update_generation_config(const ov::AnyMap& config_ read_anymap_param(config_map, "lang_to_id", lang_to_id); read_anymap_param(config_map, "task", task); read_anymap_param(config_map, "return_timestamps", return_timestamps); + read_anymap_param(config_map, "initial_prompt", initial_prompt); + read_anymap_param(config_map, "hotwords", hotwords); } size_t WhisperGenerationConfig::get_max_new_tokens(size_t prompt_length) const { diff --git a/src/cpp/src/whisper_pipeline.cpp b/src/cpp/src/whisper_pipeline.cpp index d472a20238..f0fb34cdf6 100644 --- a/src/cpp/src/whisper_pipeline.cpp +++ b/src/cpp/src/whisper_pipeline.cpp @@ -9,6 +9,7 @@ #include #include "utils.hpp" +#include "whisper/context_tokens.hpp" #include "whisper/streamer.hpp" #include "whisper/whisper.hpp" #include "whisper/whisper_config.hpp" @@ -91,8 +92,11 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi streamer_ptr = std::make_shared(m_tokenizer, *callback); } + auto [context_tokens, tokenization_duration_microseconds] = prepare_context_tokens(config, m_tokenizer); + auto generate_result = ov::genai::whisper_generate(config, m_model_config, + context_tokens, raw_speech_input, m_models, m_feature_extractor, @@ -102,6 +106,8 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi generate_result.perf_metrics.raw_metrics.detokenization_durations.emplace_back( PerfMetrics::get_microsec(std::chrono::steady_clock::now() - decode_start_time)); + result.perf_metrics.raw_metrics.tokenization_durations.emplace_back(tokenization_duration_microseconds); + result.perf_metrics = generate_result.perf_metrics; auto& segments = generate_result.segments; diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp index 136819fa01..dc26789846 100644 --- a/src/cpp/src/whisper_pipeline_static.cpp +++ b/src/cpp/src/whisper_pipeline_static.cpp @@ -579,6 +579,9 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate( WhisperGenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config; config.validate(); + OPENVINO_ASSERT(!config.initial_prompt.has_value(), "'initial_prompt' parameter is not supported on NPU device."); + OPENVINO_ASSERT(!config.hotwords.has_value(), "'hotwords' parameter is not supported on NPU device."); + std::shared_ptr streamer_ptr; if (auto streamer_obj = std::get_if(&streamer)) { streamer_ptr = nullptr; diff --git a/src/docs/SUPPORTED_MODELS.md b/src/docs/SUPPORTED_MODELS.md index 8c922ee644..9762874596 100644 --- a/src/docs/SUPPORTED_MODELS.md +++ b/src/docs/SUPPORTED_MODELS.md @@ -217,6 +217,7 @@ The pipeline can work with other similar topologies produced by `optimum-intel` diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index 524ff0f921..3d27b23052 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -1343,15 +1343,18 @@ class Scheduler: FLOW_MATCH_EULER_DISCRETE PNDM + + EULER_ANCESTRAL_DISCRETE """ AUTO: typing.ClassVar[Scheduler.Type] # value = DDIM: typing.ClassVar[Scheduler.Type] # value = + EULER_ANCESTRAL_DISCRETE: typing.ClassVar[Scheduler.Type] # value = EULER_DISCRETE: typing.ClassVar[Scheduler.Type] # value = FLOW_MATCH_EULER_DISCRETE: typing.ClassVar[Scheduler.Type] # value = LCM: typing.ClassVar[Scheduler.Type] # value = LMS_DISCRETE: typing.ClassVar[Scheduler.Type] # value = PNDM: typing.ClassVar[Scheduler.Type] # value = - __members__: typing.ClassVar[dict[str, Scheduler.Type]] # value = {'AUTO': , 'LCM': , 'LMS_DISCRETE': , 'DDIM': , 'EULER_DISCRETE': , 'FLOW_MATCH_EULER_DISCRETE': , 'PNDM': } + __members__: typing.ClassVar[dict[str, Scheduler.Type]] # value = {'AUTO': , 'LCM': , 'LMS_DISCRETE': , 'DDIM': , 'EULER_DISCRETE': , 'FLOW_MATCH_EULER_DISCRETE': , 'PNDM': , 'EULER_ANCESTRAL_DISCRETE': } def __eq__(self, other: typing.Any) -> bool: ... def __getstate__(self) -> int: @@ -1945,6 +1948,9 @@ class WhisperGenerationConfig: :param no_timestamps_token_id: No timestamps token id. :type no_timestamps_token_id: int + :param prev_sot_token_id: Corresponds to the ”<|startofprev|>” token. + :type prev_sot_token_id: int + :param is_multilingual: :type is_multilingual: bool @@ -1973,10 +1979,34 @@ class WhisperGenerationConfig: then it means the model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds. Note that a segment of text refers to a sequence of one or more words, rather than individual words. :type return_timestamps: bool + + :param initial_prompt: Initial prompt tokens passed as a previous transcription (after `<|startofprev|>` token) to the first processing + window. Can be used to steer the model to use particular spellings or styles. + + Example: + auto result = pipeline.generate(raw_speech); + // He has gone and gone for good answered Paul Icrom who... + + auto result = pipeline.generate(raw_speech, ov::genai::initial_prompt("Polychrome")); + // He has gone and gone for good answered Polychrome who... + :type initial_prompt: Optional[str] + + :param hotwords: Hotwords tokens passed as a previous transcription (after `<|startofprev|>` token) to the all processing windows. + Can be used to steer the model to use particular spellings or styles. + + Example: + auto result = pipeline.generate(raw_speech); + // He has gone and gone for good answered Paul Icrom who... + + auto result = pipeline.generate(raw_speech, ov::genai::hotwords("Polychrome")); + // He has gone and gone for good answered Polychrome who... + :type hotwords: Optional[str] """ begin_suppress_tokens: list[int] decoder_start_token_id: int eos_token_id: int + hotwords: str | None + initial_prompt: str | None is_multilingual: bool lang_to_id: dict[str, int] language: str | None @@ -1985,6 +2015,7 @@ class WhisperGenerationConfig: max_new_tokens: int no_timestamps_token_id: int pad_token_id: int + prev_sot_token_id: int return_timestamps: bool suppress_tokens: list[int] task: str | None @@ -2077,6 +2108,9 @@ class WhisperPipeline: :param no_timestamps_token_id: No timestamps token id. :type no_timestamps_token_id: int + :param prev_sot_token_id: Corresponds to the ”<|startofprev|>” token. + :type prev_sot_token_id: int + :param is_multilingual: :type is_multilingual: bool @@ -2105,6 +2139,28 @@ class WhisperPipeline: then it means the model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds. Note that a segment of text refers to a sequence of one or more words, rather than individual words. :type return_timestamps: bool + + :param initial_prompt: Initial prompt tokens passed as a previous transcription (after `<|startofprev|>` token) to the first processing + window. Can be used to steer the model to use particular spellings or styles. + + Example: + auto result = pipeline.generate(raw_speech); + // He has gone and gone for good answered Paul Icrom who... + + auto result = pipeline.generate(raw_speech, ov::genai::initial_prompt("Polychrome")); + // He has gone and gone for good answered Polychrome who... + :type initial_prompt: Optional[str] + + :param hotwords: Hotwords tokens passed as a previous transcription (after `<|startofprev|>` token) to the all processing windows. + Can be used to steer the model to use particular spellings or styles. + + Example: + auto result = pipeline.generate(raw_speech); + // He has gone and gone for good answered Paul Icrom who... + + auto result = pipeline.generate(raw_speech, ov::genai::hotwords("Polychrome")); + // He has gone and gone for good answered Polychrome who... + :type hotwords: Optional[str] """ def get_generation_config(self) -> WhisperGenerationConfig: ... diff --git a/src/python/py_image_generation_pipelines.cpp b/src/python/py_image_generation_pipelines.cpp index f5347c279d..311f3f3760 100644 --- a/src/python/py_image_generation_pipelines.cpp +++ b/src/python/py_image_generation_pipelines.cpp @@ -198,7 +198,8 @@ void init_image_generation_pipelines(py::module_& m) { .value("DDIM", ov::genai::Scheduler::Type::DDIM) .value("EULER_DISCRETE", ov::genai::Scheduler::Type::EULER_DISCRETE) .value("FLOW_MATCH_EULER_DISCRETE", ov::genai::Scheduler::Type::FLOW_MATCH_EULER_DISCRETE) - .value("PNDM", ov::genai::Scheduler::Type::PNDM); + .value("PNDM", ov::genai::Scheduler::Type::PNDM) + .value("EULER_ANCESTRAL_DISCRETE", ov::genai::Scheduler::Type::EULER_ANCESTRAL_DISCRETE); image_generation_scheduler.def_static("from_config", &ov::genai::Scheduler::from_config, py::arg("scheduler_config_path"), diff --git a/src/python/py_whisper_pipeline.cpp b/src/python/py_whisper_pipeline.cpp index 49152c03f4..cd42dcf58d 100644 --- a/src/python/py_whisper_pipeline.cpp +++ b/src/python/py_whisper_pipeline.cpp @@ -103,6 +103,9 @@ auto whisper_generation_config_docstring = R"( :param no_timestamps_token_id: No timestamps token id. :type no_timestamps_token_id: int + :param prev_sot_token_id: Corresponds to the ”<|startofprev|>” token. + :type prev_sot_token_id: int + :param is_multilingual: :type is_multilingual: bool @@ -131,6 +134,28 @@ auto whisper_generation_config_docstring = R"( then it means the model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds. Note that a segment of text refers to a sequence of one or more words, rather than individual words. :type return_timestamps: bool + + :param initial_prompt: Initial prompt tokens passed as a previous transcription (after `<|startofprev|>` token) to the first processing + window. Can be used to steer the model to use particular spellings or styles. + + Example: + auto result = pipeline.generate(raw_speech); + // He has gone and gone for good answered Paul Icrom who... + + auto result = pipeline.generate(raw_speech, ov::genai::initial_prompt("Polychrome")); + // He has gone and gone for good answered Polychrome who... + :type initial_prompt: Optional[str] + + :param hotwords: Hotwords tokens passed as a previous transcription (after `<|startofprev|>` token) to the all processing windows. + Can be used to steer the model to use particular spellings or styles. + + Example: + auto result = pipeline.generate(raw_speech); + // He has gone and gone for good answered Paul Icrom who... + + auto result = pipeline.generate(raw_speech, ov::genai::hotwords("Polychrome")); + // He has gone and gone for good answered Polychrome who... + :type hotwords: Optional[str] )"; auto streamer_base_docstring = R"( @@ -262,11 +287,14 @@ void init_whisper_pipeline(py::module_& m) { .def_readwrite("transcribe_token_id", &WhisperGenerationConfig::transcribe_token_id) .def_readwrite("max_initial_timestamp_index", &WhisperGenerationConfig::max_initial_timestamp_index) .def_readwrite("no_timestamps_token_id", &WhisperGenerationConfig::no_timestamps_token_id) + .def_readwrite("prev_sot_token_id", &WhisperGenerationConfig::prev_sot_token_id) .def_readwrite("is_multilingual", &WhisperGenerationConfig::is_multilingual) .def_readwrite("language", &WhisperGenerationConfig::language) .def_readwrite("lang_to_id", &WhisperGenerationConfig::lang_to_id) .def_readwrite("task", &WhisperGenerationConfig::task) .def_readwrite("return_timestamps", &WhisperGenerationConfig::return_timestamps) + .def_readwrite("initial_prompt", &WhisperGenerationConfig::initial_prompt) + .def_readwrite("hotwords", &WhisperGenerationConfig::hotwords) .def("set_eos_token_id", &WhisperGenerationConfig::set_eos_token_id, py::arg("tokenizer_eos_token_id")); py::class_(m, "WhisperRawPerfMetrics", raw_perf_metrics_docstring) diff --git a/tests/python_tests/common.py b/tests/python_tests/common.py index 50ee452f5c..163a00192e 100644 --- a/tests/python_tests/common.py +++ b/tests/python_tests/common.py @@ -125,6 +125,34 @@ def get_beam_search_with_multiple_stop_strings_no_match() -> GenerationConfig: generation_config.include_stop_str_in_output = True return generation_config +def get_greedy_stop_strings_exclude_from_output() -> GenerationConfig: + generation_config = GenerationConfig() + generation_config.max_new_tokens = 30 + generation_config.stop_strings = { "machines" } + generation_config.include_stop_str_in_output = False + return generation_config + +def get_greedy_stop_strings_include_to_output() -> GenerationConfig: + generation_config = GenerationConfig() + generation_config.max_new_tokens = 30 + generation_config.stop_strings = { "machines" } + generation_config.include_stop_str_in_output = True + return generation_config + +def get_greedy_n_stop_strings_exclude_from_output() -> GenerationConfig: + generation_config = GenerationConfig() + generation_config.max_new_tokens = 30 + generation_config.stop_strings = { "machines", "manage" } + generation_config.include_stop_str_in_output = False + return generation_config + +def get_greedy_n_stop_strings_include_to_output() -> GenerationConfig: + generation_config = GenerationConfig() + generation_config.max_new_tokens = 30 + generation_config.stop_strings = { "machines", "manage" } + generation_config.include_stop_str_in_output = True + return generation_config + def get_multinomial_temperature() -> GenerationConfig: generation_config = GenerationConfig() generation_config.do_sample = True @@ -359,9 +387,14 @@ def compare_results(hf_result: GenerationResult, ov_result: GenerationResult, ge # Note, that for fp32 / fp16 models scores are different less than 0.001 assert abs(hf_score - ov_score) < 0.02 - assert len(hf_result.m_generation_ids) == len(ov_result.m_generation_ids) - for hf_text, ov_text in zip(hf_result.m_generation_ids, ov_result.m_generation_ids): - assert hf_text == ov_text + if not generation_config.include_stop_str_in_output and len(generation_config.stop_strings) > 0: + assert len(hf_result.m_generation_ids) >= len(ov_result.m_generation_ids) + for hf_text, ov_text in zip(hf_result.m_generation_ids, ov_result.m_generation_ids): + assert ov_text in hf_text + else: + assert len(hf_result.m_generation_ids) == len(ov_result.m_generation_ids) + for hf_text, ov_text in zip(hf_result.m_generation_ids, ov_result.m_generation_ids): + assert hf_text == ov_text def save_ov_model_from_optimum(model, hf_tokenizer, models_path: Path): model.save_pretrained(models_path) diff --git a/tests/python_tests/requirements.txt b/tests/python_tests/requirements.txt index 3dac3f8b00..bc5324b211 100644 --- a/tests/python_tests/requirements.txt +++ b/tests/python_tests/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://download.pytorch.org/whl/cpu -optimum-intel @ git+https://github.com/huggingface/optimum-intel.git +optimum-intel @ git+https://github.com/huggingface/optimum-intel.git@420fa87d039425a906b7f755e4562b65947f016a numpy<2.0.0; sys_platform == 'darwin' onnx==1.17.0 pytest diff --git a/tests/python_tests/test_chat_generate_api.py b/tests/python_tests/test_chat_generate_api.py index 9260e671d6..d9661e538b 100644 --- a/tests/python_tests/test_chat_generate_api.py +++ b/tests/python_tests/test_chat_generate_api.py @@ -187,10 +187,13 @@ def test_set_chat_template(): model_descr = get_chat_models_list()[0] model_id, path, tokenizer, model_opt, pipe = read_model((model_descr[0], model_descr[1] / '_test_chat')) pipe.get_tokenizer().set_chat_template("{% for message in messages %}{{ message['content'] }}{% endfor %}") + config = ov_genai.GenerationConfig() + config.max_new_tokens = 1 + config.do_sample = False pipe.start_chat() - generated = pipe.generate("a", max_new_tokens=1) + generated = pipe.generate("a", config) pipe.finish_chat() - reference = pipe.generate("a", max_new_tokens=1) + reference = pipe.generate("a", config) assert generated == reference prompts = [ diff --git a/tests/python_tests/test_sampling.py b/tests/python_tests/test_sampling.py index 9aa6931d85..d5df28bfd6 100644 --- a/tests/python_tests/test_sampling.py +++ b/tests/python_tests/test_sampling.py @@ -21,6 +21,8 @@ get_beam_search, get_beam_search_min_and_max_tokens, get_beam_search_with_single_stop_string, \ get_beam_search_with_multiple_stop_strings, get_beam_search_with_multiple_stop_strings_no_match, get_multinomial_max_and_min_token, \ get_multinomial_temperature_and_frequence_penalty, get_multinomial_temperature_and_presence_penalty, \ + get_greedy_stop_strings_exclude_from_output, get_greedy_stop_strings_include_to_output, \ + get_greedy_n_stop_strings_exclude_from_output, get_greedy_n_stop_strings_include_to_output, \ generate_and_compare_with_hf, get_multinomial_temperature_and_repetition_penalty, get_scheduler_config, \ run_continuous_batching @@ -77,7 +79,9 @@ def test_eos_greedy(tmp_path): @pytest.mark.precommit @pytest.mark.parametrize("generation_config", [get_greedy(), get_greedy_with_min_and_max_tokens(), get_greedy_with_repetition_penalty(), get_greedy_with_single_stop_string(), get_greedy_with_multiple_stop_strings(), get_greedy_with_multiple_stop_strings_no_match(), - get_beam_search(), get_beam_search_min_and_max_tokens(), get_beam_search_with_multiple_stop_strings_no_match(), ], + get_beam_search(), get_beam_search_min_and_max_tokens(), get_beam_search_with_multiple_stop_strings_no_match(), + get_greedy_stop_strings_exclude_from_output(), get_greedy_stop_strings_include_to_output(), + get_greedy_n_stop_strings_exclude_from_output(), get_greedy_n_stop_strings_include_to_output() ], ids=[ "greedy", "greedy_with_min_and_max_tokens", @@ -88,6 +92,10 @@ def test_eos_greedy(tmp_path): "beam", "beam_search_min_and_max_tokens", "beam_search_with_multiple_stop_strings_no_match", + "get_greedy_stop_strings_exclude_from_output", + "get_greedy_stop_strings_include_to_output", + "get_greedy_n_stop_strings_exclude_from_output", + "get_greedy_n_stop_strings_include_to_output" ]) def test_individual_generation_configs_deterministic(tmp_path, generation_config): prompts = [ diff --git a/tests/python_tests/test_whisper_generate_api.py b/tests/python_tests/test_whisper_generate_api.py index 5a68dd98b6..1450ef1f2e 100644 --- a/tests/python_tests/test_whisper_generate_api.py +++ b/tests/python_tests/test_whisper_generate_api.py @@ -25,7 +25,9 @@ def run_gc_after_test(): yield gc.collect() -@functools.lru_cache(1) +# used whisper models are relatively small +# cache them in memory to speedup tests +@functools.lru_cache(3) def read_whisper_model(params, **tokenizer_kwargs): model_id, path = params @@ -568,6 +570,31 @@ def test_longform_audio(model_descr, test_sample): assert genai_result.chunks == None +@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True)) +@pytest.mark.parametrize( + "test_sample", + get_samples_from_dataset(length=1), +) +@pytest.mark.precommit +def test_initial_prompt_hotwords(model_descr, test_sample): + model_id, path, opt_pipe, pipe = read_whisper_model(model_descr) + + result = pipe.generate(test_sample) + + assert "Joel Keaton" in result.texts[0] + assert "Joel Kyton" not in result.texts[0] + + result = pipe.generate(test_sample, initial_prompt="Joel Kyton") + + assert "Joel Keaton" not in result.texts[0] + assert "Joel Kyton" in result.texts[0] + + result = pipe.generate(test_sample, hotwords="Joel Kyton") + + assert "Joel Keaton" not in result.texts[0] + assert "Joel Kyton" in result.texts[0] + + @pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True)) @pytest.mark.parametrize( "test_sample", diff --git a/tools/llm_bench/llm_bench_utils/ov_utils.py b/tools/llm_bench/llm_bench_utils/ov_utils.py index c3df84925b..316c9d0b89 100644 --- a/tools/llm_bench/llm_bench_utils/ov_utils.py +++ b/tools/llm_bench/llm_bench_utils/ov_utils.py @@ -421,7 +421,7 @@ def get_vae_decoder_step_count(self): scheduler_type = data.get("scheduler", ["", ""])[1] if (scheduler_type not in ["LCMScheduler", "DDIMScheduler", "PNDMScheduler", "LMSDiscreteScheduler", "EulerDiscreteScheduler", - "FlowMatchEulerDiscreteScheduler"]): + "FlowMatchEulerDiscreteScheduler", "EulerAncestralDiscreteScheduler"]): scheduler = openvino_genai.Scheduler.from_config(model_path / "scheduler/scheduler_config.json", openvino_genai.Scheduler.Type.DDIM) log.warning(f'Type of scheduler {scheduler_type} is unsupported. Please, be aware that it will be replaced to DDIMScheduler') diff --git a/tools/llm_bench/requirements.txt b/tools/llm_bench/requirements.txt index f5f4a3fdeb..acbc668c52 100644 --- a/tools/llm_bench/requirements.txt +++ b/tools/llm_bench/requirements.txt @@ -10,7 +10,7 @@ torch transformers>=4.40.0 diffusers>=0.22.0 #optimum is in dependency list of optimum-intel -git+https://github.com/huggingface/optimum-intel.git@main#egg=optimum-intel +git+https://github.com/huggingface/optimum-intel.git@420fa87d039425a906b7f755e4562b65947f016a#egg=optimum-intel git+https://github.com/openvinotoolkit/nncf.git@develop#egg=nncf packaging psutil diff --git a/tools/llm_bench/task/speech_to_text_generation.py b/tools/llm_bench/task/speech_to_text_generation.py index f1e7ac54a0..15a47a8b6a 100644 --- a/tools/llm_bench/task/speech_to_text_generation.py +++ b/tools/llm_bench/task/speech_to_text_generation.py @@ -57,7 +57,7 @@ def run_speech_2_txt_generation(input_param, args, md5_list, iter_data_list): - np.array(perf_metrics.raw_metrics.m_new_token_times[:-1]) ).tolist() tm_list = (np.array([first_token_time] + second_tokens_durations) / 1000).tolist() - tm_infer_list = None + tm_infer_list = (np.array(perf_metrics.raw_metrics.token_infer_durations) / 1000 / 1000).tolist() result_text = result_text.texts[0] else: start = time.perf_counter() diff --git a/tools/llm_bench/task/text_generation.py b/tools/llm_bench/task/text_generation.py index 3f5b5ed301..485de94996 100644 --- a/tools/llm_bench/task/text_generation.py +++ b/tools/llm_bench/task/text_generation.py @@ -302,6 +302,7 @@ def token_printer(): ).tolist() tm_list = np.array([first_token_time] + second_tokens_durations) / 1000 + inference_durations = (np.array(perf_metrics.raw_metrics.token_infer_durations) / 1000 / 1000).tolist() log.debug('latency of all tokens:') [log.debug('[{}]{:.4f}'.format(idx, tm)) for idx, tm in enumerate(tm_list)] iter_data = gen_output_data.gen_iterate_data( @@ -323,7 +324,7 @@ def token_printer(): num, iter_data, tm_list.tolist(), - None, + inference_durations.tolist(), warm_up=(num == 0), max_rss_mem=max_rss_mem_consumption, max_shared_mem=max_shared_mem_consumption, diff --git a/tools/llm_bench/task/visual_language_generation.py b/tools/llm_bench/task/visual_language_generation.py index c4144366b4..068ae0cf60 100644 --- a/tools/llm_bench/task/visual_language_generation.py +++ b/tools/llm_bench/task/visual_language_generation.py @@ -268,11 +268,12 @@ def run_visual_language_generation_genai( mm_embeddings_preparation_time=perf_metrics.get_prepare_embeddings_duration().mean ) iter_data_list.append(iter_data) + inference_durations = np.array(perf_metrics.raw_metrics.token_infer_durations) / 1000 / 1000 metrics_print.print_metrics( num, iter_data, tm_list.tolist(), - None, + inference_durations.tolist(), warm_up=(num == 0), max_rss_mem=max_rss_mem_consumption, max_shared_mem=max_shared_mem_consumption, diff --git a/tools/who_what_benchmark/whowhatbench/wwb.py b/tools/who_what_benchmark/whowhatbench/wwb.py index 026a6cc69b..04813f5fd8 100644 --- a/tools/who_what_benchmark/whowhatbench/wwb.py +++ b/tools/who_what_benchmark/whowhatbench/wwb.py @@ -1,7 +1,3 @@ -from .utils import patch_diffusers - -patch_diffusers() - import argparse import difflib import numpy as np