diff --git a/.github/workflows/llm_bench-python.yml b/.github/workflows/llm_bench-python.yml
index 1999bafcfe..8356805e19 100644
--- a/.github/workflows/llm_bench-python.yml
+++ b/.github/workflows/llm_bench-python.yml
@@ -151,7 +151,7 @@ jobs:
           rm -rf ./ov_models/internvl2-1B
       - name: WWB Tests
         run: |
-          pip install git+https://github.com/huggingface/optimum-intel.git
+          pip install git+https://github.com/huggingface/optimum-intel.git@420fa87d039425a906b7f755e4562b65947f016a
           GIT_CLONE_PROTECTION_ACTIVE=false PIP_PRE=1 PIP_EXTRA_INDEX_URL=https://storage.openvinotoolkit.org/simple/wheels/nightly pip install ${{ env.WWB_PATH }}
           python -m pytest -v ${{ env.WWB_PATH }}/tests
   stateful:
@@ -190,7 +190,7 @@ jobs:
       - name: WWB Tests
         run: |
           pip install pytest
-          pip install git+https://github.com/huggingface/optimum-intel.git
+          pip install git+https://github.com/huggingface/optimum-intel.git@420fa87d039425a906b7f755e4562b65947f016a
           GIT_CLONE_PROTECTION_ACTIVE=false PIP_PRE=1 PIP_EXTRA_INDEX_URL=https://storage.openvinotoolkit.org/simple/wheels/nightly pip install ${{ env.WWB_PATH }}
           python -m pytest -v ${{ env.WWB_PATH }}/tests
diff --git a/samples/export-requirements.txt b/samples/export-requirements.txt
index 797b680b9a..d75fdbacee 100644
--- a/samples/export-requirements.txt
+++ b/samples/export-requirements.txt
@@ -2,7 +2,7 @@
 --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release
 --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
 openvino-tokenizers~=2025.0.0.0.dev
-optimum-intel @ git+https://github.com/huggingface/optimum-intel.git
+optimum-intel @ git+https://github.com/huggingface/optimum-intel.git@420fa87d039425a906b7f755e4562b65947f016a
 numpy<2.0.0; sys_platform == 'darwin'
 einops==0.8.0 # For Qwen
 transformers_stream_generator==0.0.5 # For Qwen
diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp
index 6e7e982a4c..e1ffd062de 100644
--- a/src/cpp/src/continuous_batching_impl.cpp
+++ b/src/cpp/src/continuous_batching_impl.cpp
@@ -22,7 +22,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
     m_tokenizer = tokenizer;
     m_generation_config = generation_config;
     m_is_validation_mode_enabled = is_validation_mode_enabled;
-    
+
     ov::Core core;
 
     auto [core_properties, compile_properties] = utils::split_core_compile_config(properties);
@@ -255,18 +255,6 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<ov::Tensor>& input_ids,
-    std::vector<GenerationHandle> generations;
-    for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) {
-        OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch.");
-        generations.push_back(add_request(request_id, input_ids[request_id], sampling_params[request_id]));
-    }
-
-    std::vector<EncodedGenerationResult> results;
-    results.reserve(m_awaiting_requests.size());
-
     auto drop_requests = [&] () {
         for (const std::shared_ptr<SequenceGroup> request : m_requests) {
             for (const auto& sequence: request->get_sequences()) {
@@ -279,25 +267,40 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<ov::Tensor>& input_ids,
+    std::vector<GenerationHandle> generations;
+    for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) {
+        OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch.");
+        generations.push_back(add_request(request_id, input_ids[request_id], sampling_params[request_id]));
+    }
+    auto all_requests = m_awaiting_requests; // we need to store all requests to get results from them once generation has finished
+
     bool continue_generation = true;
     while (has_non_finished_requests() && continue_generation) {
         try {
             step();
         } catch (...) {
-            drop_requests();
+            drop_requests(); // remove all requests from pipeline state in case of exception
             throw;
         }
-        if (streamer_ptr && generations.at(0)->can_read()) {
-            std::unordered_map<uint64_t, GenerationOutput> token = generations.at(0).get()->back();
+
+        auto & generation = generations.at(0);
+        if (streamer_ptr && generation->can_read()) {
+            std::unordered_map<uint64_t, GenerationOutput> token = generation->back();
             for (const auto& gen_token : token.begin()->second.generated_ids) {
-                if (!streamer_ptr->put(gen_token)) {
+                continue_generation = !streamer_ptr->put(gen_token);
+                if (!continue_generation) {
+                    generation->drop();
                     break;
                 }
             }
         }
     }

-    if (streamer_ptr) {
+    if (streamer_ptr) { // push streamer's cache
        streamer_ptr->end();
    }
@@ -307,16 +310,32 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<ov::Tensor>& input_ids,
+    std::vector<EncodedGenerationResult> results;
+    results.reserve(all_requests.size());
+
+    for (size_t request_id = 0; request_id < all_requests.size(); ++request_id) {
+        const auto& request = all_requests[request_id];
+        auto sampling_params = request->get_sampling_parameters();
+        const auto& sequences = request->get_finished_sequences();
+        size_t num_outputs = std::min(sampling_params.num_return_sequences, sequences.size());
+
         EncodedGenerationResult result;
-        result.m_request_id = 1;
-        std::vector<GenerationOutput> generation_outputs = generation->read_all();
-        for (const auto& generation_output : generation_outputs) {
-            result.m_generation_ids.push_back(std::move(generation_output.generated_ids));
-            result.m_scores.push_back(generation_output.score);
+        result.m_request_id = request_id;
+        result.m_generation_ids.resize(num_outputs);
+        result.m_scores.resize(num_outputs);
+
+        for (size_t i = 0; i < num_outputs; ++i) {
+            const auto & sequence = sequences[i];
+            const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_probs();
+            const auto & generated_ids = sequence->get_generated_ids();
+
+            if (sampling_params.echo)
+                result.m_generation_ids[i] = request->get_prompt_ids();
+            std::copy(generated_ids.begin(), generated_ids.end(), std::back_inserter(result.m_generation_ids[i]));
+            result.m_scores[i] = score;
         }
-        result.m_status = generation->get_status();
+
+        result.m_status = generations[request_id]->get_status();
         results.push_back(std::move(result));
     }
@@ -408,7 +427,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(
     for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
         SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
         // requests not scheduled, in decoding phase or not echoing are not processed
-        if (!sequence_group->is_scheduled() || sequence_group->get_context_len() > sequence_group->get_prompt_len() || 
+        if (!sequence_group->is_scheduled() || sequence_group->get_context_len() > sequence_group->get_prompt_len() ||
             !sequence_group->get_sampling_parameters().echo)
             continue;
@@ -421,10 +440,10 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(
         size_t num_prompt_tokens_processed = sequence_group->get_num_processed_tokens();
         OPENVINO_ASSERT(num_prompt_tokens_processed + actual_seq_len <= sequence_group->get_prompt_len());
-        
+
         // if we processed the whole prompt we don't include last logprob as it will be processed by the sampler (it's already completion)
-        // otherwise we include it as it will be used in the next part of the prompt 
-        int exclude_last_logprob = 1; 
+        // otherwise we include it as it will be used in the next part of the prompt
+        int exclude_last_logprob = 1;
         if (num_prompt_tokens_processed + actual_seq_len < sequence_group->get_prompt_len())
             exclude_last_logprob = 0;
@@ -435,7 +454,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(
         for (int token_logits_offset = 0, token_id_offset = num_prompt_tokens_processed + 1;
             token_logits_offset < actual_seq_len - exclude_last_logprob; token_logits_offset++, token_id_offset++) {
-            
+
             const float* token_logits = (sequence_group_logits_data + token_logits_offset * vocab_size);
             int64_t token_id = sequence_group->get_prompt_ids()[token_id_offset];
             float token_logit = token_logits[token_id];
diff --git a/src/cpp/src/generation_handle.cpp b/src/cpp/src/generation_handle.cpp
index a1dd467523..0f10a85a86 100644
--- a/src/cpp/src/generation_handle.cpp
+++ b/src/cpp/src/generation_handle.cpp
@@ -17,7 +17,7 @@ GenerationStatus GenerationHandleImpl::get_status() {
 }

 bool GenerationHandleImpl::can_read() {
-    return !is_dropped() && m_generation_stream->can_read(); 
+    return !is_dropped() && m_generation_stream->can_read();
 }

 bool GenerationHandleImpl::is_dropped() {
diff --git a/src/cpp/src/generation_stream.hpp b/src/cpp/src/generation_stream.hpp
index 4d41f160e4..518699ba36 100644
--- a/src/cpp/src/generation_stream.hpp
+++ b/src/cpp/src/generation_stream.hpp
@@ -14,8 +14,6 @@ class GenerationStream {
     GenerationStatus m_status = GenerationStatus::RUNNING;
     SynchronizedQueue<GenerationOutputs> m_output_queue;

-    std::vector<uint64_t> last_sequence_ids;
-
 public:
     using Ptr = std::shared_ptr<GenerationStream>;
@@ -30,10 +28,11 @@ class GenerationStream {
         m_output_queue.push(std::move(outputs));
     }

-    // Retrieving vector of pairs as we can generate
multiple outputs for a single prompt GenerationOutputs back() { return m_output_queue.back(); } + GenerationOutputs read() { return m_output_queue.pull(); } diff --git a/src/cpp/src/group_beam_searcher.cpp b/src/cpp/src/group_beam_searcher.cpp deleted file mode 100644 index a0262c0dc8..0000000000 --- a/src/cpp/src/group_beam_searcher.cpp +++ /dev/null @@ -1,455 +0,0 @@ -// Copyright (C) 2023-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include - -#include "openvino/genai/llm_pipeline.hpp" -#include "utils.hpp" -#include "lm_encoding.hpp" - -namespace { - -// Modified Knuth–Morris–Pratt algorithm which returns tokens following after every needle occurrence in haystack -std::vector kmp_search(const std::vector& haystack, const std::vector& needle) { - if (needle.empty()) { // no_repeat_ngram_size == 1, ban every token - return {haystack.begin(), haystack.end()}; - } - std::vector partial_match_table(needle.size() + 1, -1); - int cnd = 0; - for (size_t pos = 1; pos < needle.size(); ++pos) { - if (needle.at(pos) == needle.at(size_t(cnd))) { - partial_match_table.at(pos) = partial_match_table.at(size_t(cnd)); - } else { - partial_match_table.at(pos) = cnd; - while (cnd >= 0 && needle.at(pos) != needle.at(size_t(cnd))) { - cnd = partial_match_table.at(size_t(cnd)); - } - } - ++cnd; - } - partial_match_table.back() = cnd; - std::vector res; - size_t haystack_id = 0; - int needle_id = 0; - while (haystack_id < haystack.size() - 1) { - if (needle.at(size_t(needle_id)) == haystack.at(haystack_id)) { - ++haystack_id; - ++needle_id; - if (needle_id == int(needle.size())) { - res.push_back(haystack.at(haystack_id)); - needle_id = partial_match_table.at(size_t(needle_id)); - } - } else { - needle_id = partial_match_table.at(size_t(needle_id)); - if (needle_id < 0) { - ++haystack_id; - ++needle_id; - } - } - } - return res; -} - -struct Token { - float log_prob; - int64_t idx; -}; - -std::vector log_softmax(const ov::Tensor& logits, const size_t batch_idx) { - if (logits.get_shape().at(0) <= batch_idx) { - throw std::runtime_error("logits batch size doesn't match the number of beams"); - } - size_t vocab_size = logits.get_shape().back(); - size_t batch_offset = batch_idx * logits.get_shape().at(1) * vocab_size; - size_t sequence_offset = (logits.get_shape().at(1) - 1) * vocab_size; - const float* beam_logits = logits.data() + batch_offset + sequence_offset; - float max_logit = *std::max_element(beam_logits, beam_logits + vocab_size); - float log_sum = std::log( - std::accumulate(beam_logits, beam_logits + vocab_size, 0.0f, [max_logit](float accumulated, float to_add) { - return accumulated + std::exp(to_add - max_logit); - })); - std::vector tokens; - tokens.reserve(vocab_size); - for (size_t idx = 0; idx < vocab_size; ++idx) { - tokens.push_back({beam_logits[idx] - max_logit - log_sum, int64_t(idx)}); - } - return tokens; -} - -struct Beam { - float score = -std::numeric_limits::infinity(); // The bigger, the better - std::vector tokens; - size_t global_beam_idx = 0; -}; - -bool greater(const Beam& left, const Beam& right) { - return left.score > right.score; -} - -struct Parameters { - std::vector> prompts; - int64_t eos_token_id; - size_t n_groups = 3; - size_t group_size = 5; - float diversity_penalty = 1.0; - size_t max_new_tokens = 20; - ov::genai::StopCriteria stop_criteria = ov::genai::StopCriteria::HEURISTIC; - float length_penalty = 1.0; - size_t no_repeat_ngram_size = std::numeric_limits::max(); - - std::function early_finish = [](const Beam&) { - return 
false; - }; -}; - -struct Group { - std::vector ongoing; // Best beams in front - std::vector min_heap; // The worst of the best completed beams is the first - bool done = false; - - void finish(Beam&& beam, const Parameters& parameters) { - beam.score /= std::pow(float(beam.tokens.size()), parameters.length_penalty); - - min_heap.push_back(std::move(beam)); - std::push_heap(min_heap.begin(), min_heap.end(), greater); - if (min_heap.size() > parameters.group_size) { - std::pop_heap(min_heap.begin(), min_heap.end(), greater); - min_heap.pop_back(); - } - } - void is_done(const Parameters& parameters) { - if (min_heap.size() < parameters.group_size) { - return; - } - size_t cur_len = ongoing.front().tokens.size(); - float best_sum_logprobs = ongoing.front().score; - float worst_score = min_heap.front().score; - switch (parameters.stop_criteria) { - case ov::genai::StopCriteria::EARLY: - done = true; - return; - case ov::genai::StopCriteria::HEURISTIC: { - float highest_attainable_score = best_sum_logprobs / std::pow(float(cur_len), parameters.length_penalty); - done = worst_score >= highest_attainable_score; - return; - } - case ov::genai::StopCriteria::NEVER: { - size_t length = parameters.length_penalty > 0.0 ? parameters.max_new_tokens : cur_len; - float highest_attainable_score = best_sum_logprobs / std::pow(float(length), parameters.length_penalty); - done = worst_score >= highest_attainable_score; - return; - } - default: - throw std::runtime_error("Never reached"); - } - } -}; - -// GroupBeamSearcher processes logits prduced by a language model and accumulates beams using group beam search -// algorithm. select_next_tokens() returns token ids selected by the algorithm and corresponding beam ids. These values -// are used for next inference. select_next_tokens() returns empty, if all groups are completed -struct GroupBeamSearcher { - Parameters parameters; - std::vector> prompts_groups; - - GroupBeamSearcher(Parameters parameters) : parameters{parameters}, prompts_groups{parameters.prompts.size()} { - if (parameters.no_repeat_ngram_size == 0) { - throw std::runtime_error("no_repeat_ngram_size must be positive"); - } - for (std::vector& prompts_groups : prompts_groups) { - prompts_groups.resize(parameters.n_groups); - for (Group& group : prompts_groups) { - group.ongoing.resize(parameters.group_size); - group.ongoing.front().score = 0.0; - } - } - } - - std::pair, std::vector> select_next_tokens(const ov::Tensor& logits) { - std::vector next_tokens; - std::vector next_beams; - - const size_t promts_size = parameters.prompts.size(); - - next_tokens.reserve(promts_size * parameters.n_groups * parameters.group_size); - next_beams.reserve(promts_size * parameters.n_groups * parameters.group_size); - - size_t beam_count = 0; - size_t prompt_id = 0; - for (std::vector& groups : prompts_groups) { - for (Group& group : groups) { - if (group.done) { - continue; - } - for (Beam& beam : group.ongoing) { - // beam.tokens.empty() holds for the first select_next_tokens() call. 
- // Every beam is constructed from the single batch at first call - if (beam.tokens.empty()) { - beam.global_beam_idx = prompt_id; - } else { - beam.global_beam_idx = beam_count; - ++beam_count; - } - } - } - - prompt_id += 1; - } - - for (int prompt_id = 0; prompt_id < promts_size; prompt_id++) { - const std::vector prompt = parameters.prompts[prompt_id]; - std::vector& groups = prompts_groups[prompt_id]; - auto [prompt_next_tokens, prompt_next_beams] = select_prompt_next_tokens(logits, prompt, groups); - - next_tokens.insert(next_tokens.end(), prompt_next_tokens.begin(), prompt_next_tokens.end()); - next_beams.insert(next_beams.end(), prompt_next_beams.begin(), prompt_next_beams.end()); - } - - return {next_tokens, next_beams}; - } - - std::pair, std::vector> select_prompt_next_tokens(const ov::Tensor& logits, - const std::vector& prompt, - std::vector& groups) { - std::vector next_tokens; - std::vector next_beams; - next_tokens.reserve(parameters.n_groups * parameters.group_size); - next_beams.reserve(parameters.n_groups * parameters.group_size); - - for (auto group = groups.begin(); group != groups.end(); ++group) { - if (group->done) { - continue; - } - std::vector candidates; - candidates.reserve(parameters.group_size * 2 * parameters.group_size); - for (const Beam& beam : group->ongoing) { - std::vector tokens = log_softmax(logits, beam.global_beam_idx); - for (auto prev_group = groups.cbegin(); prev_group != group; ++prev_group) { - for (const Beam& prev_beam : prev_group->ongoing) { - if (prev_beam.tokens.size() > beam.tokens.size()) { - tokens.at(size_t(prev_beam.tokens.back())).log_prob -= parameters.diversity_penalty; - } - } - } - std::vector full_text{prompt}; - full_text.insert(full_text.end(), beam.tokens.begin(), beam.tokens.end()); - if (full_text.size() > 1 && full_text.size() >= parameters.no_repeat_ngram_size) { - auto tail_start = full_text.end() - ptrdiff_t(parameters.no_repeat_ngram_size) + 1; - for (int64_t banned_token : kmp_search(full_text, {tail_start, full_text.end()})) { - tokens.at(size_t(banned_token)).log_prob = -std::numeric_limits::infinity(); - } - } - std::sort(tokens.begin(), tokens.end(), [](Token left, Token right) { - return left.log_prob > right.log_prob; // Most probable tokens in front - }); - size_t add_count = 0; - for (Token token : tokens) { - Beam new_candidate = beam; - new_candidate.score += token.log_prob; - new_candidate.tokens.push_back(token.idx); - if (parameters.early_finish(new_candidate)) { - group->finish(std::move(new_candidate), parameters); - } else { - candidates.push_back(std::move(new_candidate)); - ++add_count; - if (add_count == 2 * parameters.group_size) { - break; - } - } - } - } - // Sample 2 * group_size highest score tokens to get at least 1 non EOS token per beam - if (candidates.size() < 2 * parameters.group_size) { - throw std::runtime_error("No beams left to search"); - } - auto to_sort = candidates.begin() + ptrdiff_t(2 * parameters.group_size); - std::partial_sort(candidates.begin(), to_sort, candidates.end(), greater); - group->ongoing.clear(); - for (size_t cand_idx = 0; cand_idx < candidates.size(); ++cand_idx) { - if (parameters.eos_token_id == candidates.at(cand_idx).tokens.back()) { - // If beam_token does not belong to top num_beams tokens, it should not be added - if (cand_idx >= parameters.group_size) { - continue; - } - group->finish(std::move(candidates.at(cand_idx)), parameters); - } else { - group->ongoing.push_back(std::move(candidates.at(cand_idx))); - if (group->ongoing.size() == 
parameters.group_size) { - break; - } - } - } - group->is_done(parameters); - if (!group->done) { - for (const Beam& beam : group->ongoing) { - next_tokens.push_back(beam.tokens.back()); - next_beams.push_back(int32_t(beam.global_beam_idx)); - } - } - } - return {next_tokens, next_beams}; - } -}; - -// Consume group_beam_searcher because beams are consumed -std::vector>> finalize(GroupBeamSearcher&& group_beam_searcher) { - std::vector>> finalized; - finalized.resize(group_beam_searcher.prompts_groups.size()); - - for (size_t prompt_id = 0; prompt_id < group_beam_searcher.prompts_groups.size(); prompt_id++) { - std::vector& groups = group_beam_searcher.prompts_groups.at(prompt_id); - finalized.at(prompt_id).reserve(groups.size()); - - for (Group& group : groups) { - if (!group.done) { - for (Beam& beam : group.ongoing) { - group.finish(std::move(beam), group_beam_searcher.parameters); - } - } - finalized.at(prompt_id).push_back(std::move(group.min_heap)); - } - } - - return finalized; -} - -void reset_all_inputs_to_empty_tensors(ov::InferRequest& request) { - request.set_tensor("input_ids", ov::Tensor(ov::element::i64, {0, 0})); - request.set_tensor("beam_idx", ov::Tensor(ov::element::i32, {0})); - if (request.get_compiled_model().inputs().size() == 4) - request.set_tensor("position_ids", ov::Tensor(ov::element::i64, {0, 0})); -} -} // namespace - -namespace ov { -namespace genai { - -std::pair beam_search(ov::InferRequest& lm, - ov::Tensor input_ids, - ov::Tensor attention_mask, - GenerationConfig config, - std::optional position_ids, - std::optional selected_beam_idx) { - OPENVINO_ASSERT(config.num_beams % config.num_beam_groups == 0, - "number of beams should be divisible by number of groups"); - - auto batch_size = input_ids.get_shape().at(0); - auto sequence_length = input_ids.get_shape().at(1); - - // Initialize beam search. - const int64_t* prompt_data = input_ids.data(); - std::vector> prompts; - prompts.reserve(batch_size); - for (size_t batch = 0; batch < batch_size; batch++) { - size_t batch_offset = batch * sequence_length; - const int64_t* prompt_start = prompt_data + batch_offset; - prompts.push_back(std::vector{prompt_start, prompt_start + sequence_length}); - } - - lm.set_tensor("input_ids", input_ids); - lm.set_tensor("attention_mask", attention_mask); - if (position_ids.has_value()) - lm.set_tensor("position_ids", *position_ids); - - ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size}); - auto beam_data = beam_idx.data(); - if (selected_beam_idx.has_value()) - beam_data[0] = *selected_beam_idx; - else - std::fill_n(beam_data, batch_size, 0); - lm.set_tensor("beam_idx", beam_idx); - - Parameters parameters{std::move(prompts)}; - parameters.max_new_tokens = config.get_max_new_tokens(sequence_length); - parameters.eos_token_id = config.eos_token_id; - parameters.n_groups = config.num_beam_groups; - parameters.group_size = config.num_beams / config.num_beam_groups; - parameters.diversity_penalty = config.diversity_penalty; - parameters.length_penalty = config.length_penalty; - parameters.stop_criteria = config.stop_criteria; - parameters.no_repeat_ngram_size = config.no_repeat_ngram_size; - GroupBeamSearcher group_beam_searcher{parameters}; - - std::vector next_tokens; - std::vector next_beams; - - // Reserve for performance counters. 
- std::vector new_token_times; - std::vector batch_sizes; - new_token_times.reserve(parameters.max_new_tokens); - batch_sizes.reserve(parameters.max_new_tokens); - - for (size_t length_count = 0; ; ++length_count) { - lm.infer(); - - std::tie(next_tokens, next_beams) = group_beam_searcher.select_next_tokens(lm.get_tensor("logits")); - new_token_times.emplace_back(std::chrono::steady_clock::now()); - batch_sizes.emplace_back(batch_size); - - if (next_tokens.empty() || length_count == parameters.max_new_tokens - 1) { - // Break the cycle before masks are extended in update_attention_mask_with_beams. - // If generation is continued, attention_mask length should be equal to KV cache size. - break; - } - - size_t running_batch_size = next_tokens.size(); - // Set pointers - lm.set_tensor("input_ids", ov::Tensor{ov::element::i64, {running_batch_size, 1}, next_tokens.data()}); - lm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {running_batch_size}, next_beams.data()}); - - // Set auxiliary inputs - update_attention_mask_with_beams(lm.get_tensor("attention_mask"), next_beams); - if (position_ids.has_value()) - update_position_ids(lm.get_tensor("position_ids"), lm.get_tensor("attention_mask")); - } - - reset_all_inputs_to_empty_tensors(lm); - - auto scores_comparator = [](Beam& left, Beam& right) { - return (left.score > right.score); - }; - - auto result = finalize(std::move(group_beam_searcher)); - ov::genai::EncodedResults results; - int32_t res_selected_beam_idx = 0; - results.scores.reserve(config.num_return_sequences * result.size()); - results.tokens.reserve(config.num_return_sequences * result.size()); - auto& raw_perf_counters = results.perf_metrics.raw_metrics; - raw_perf_counters.m_new_token_times = new_token_times; - raw_perf_counters.m_batch_sizes = batch_sizes; - - // align output with HF - for (size_t prompt_id = 0; prompt_id < result.size(); prompt_id++) { - auto prompt_group = result.at(prompt_id); - std::vector> plain_beams; - plain_beams.reserve(parameters.n_groups * parameters.group_size); - for (std::vector& group : prompt_group) { - for (Beam& beam : group) { - plain_beams.push_back(beam); - } - } - assert(config.num_return_sequences <= plain_beams.size()); - std::partial_sort( - plain_beams.begin(), - plain_beams.begin() + config.num_return_sequences, - plain_beams.end(), - scores_comparator - ); - res_selected_beam_idx = plain_beams.at(0).get().global_beam_idx; - for ( - auto beam = plain_beams.begin(); - beam != plain_beams.begin() + config.num_return_sequences; - ++beam - ) { - results.scores.push_back(beam->get().score); - results.tokens.push_back(std::move(beam->get().tokens)); - } - } - - return {results, res_selected_beam_idx}; -} - -} // namespace genai -} // namespace ov diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index 6d9aae30fa..33180a9199 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -24,27 +24,23 @@ namespace ov { namespace genai { -std::pair beam_search( - ov::InferRequest& lm, - ov::Tensor prompts, - ov::Tensor attention_mask, - GenerationConfig config, - std::optional position_ids, - std::optional selected_beam_idx -); - class StatefulLLMPipeline final : public LLMPipelineImplBase { public: ov::InferRequest m_model_runner; bool is_chat_conversation = false; bool m_trust_encoded_history = true; - std::optional m_selected_beam = std::nullopt; ChatHistory m_history; std::string m_templated_chat_history = {}; std::vector m_tokenized_chat_history; ov::genai::utils::GenerationChatInputsType 
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; - size_t m_to_remove_from_hist = 0; size_t m_kv_cache_seq_length_axis = 2; + Sampler m_sampler; + // Tail of previous output in chat mode is missing in KV cache, let's keep it + std::optional m_last_disappeared_token = std::nullopt; + // If sequence contains some symbols, which could be ambiguously encoded by tokenizer, we need to trim kv cache + // If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history + // so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history + ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0}; StatefulLLMPipeline( const ov::InferRequest& request, @@ -75,7 +71,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { const std::string& device, const ov::AnyMap& config, const ov::genai::GenerationConfig& generation_config - ) : LLMPipelineImplBase(tokenizer, generation_config) { + ) : LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) { ov::Core core; ov::CompiledModel compiled_model; auto [core_plugin_config, plugin_config] = ov::genai::utils::split_core_compile_config(config); @@ -96,6 +92,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { // If eos_token_id was not provided, take value if (m_generation_config.eos_token_id == -1) m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id()); + + m_sampler.set_seed(m_generation_config.rng_seed); } StatefulLLMPipeline( @@ -151,35 +149,44 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { // some symbols combinations can be encoded by the tokenizer in different ways // if we met sequence with such combination of symbols, we cannot correctly subtract the new history from the old history // so let's check it out, find the trusted part and use it in on the next step - size_t last_same_hist_token = 0; + size_t trusted_history_length = 0; if (!m_tokenized_chat_history.empty()) { std::set stop_tokens = config.stop_token_ids; - last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens); - m_trust_encoded_history = last_same_hist_token == SIZE_MAX; + trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens); + m_trust_encoded_history = trusted_history_length == SIZE_MAX; } if (m_tokenized_chat_history.empty()) { encoded_input = new_chat_tokens; - } else if (last_same_hist_token != SIZE_MAX) { - m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token; + } else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) { + // does_kv_cache_need_to_update will be true here if beam search is activated + // in beam search mode we want to remove all history about last model answer from kv cache and add the best answer directly + // if we have difference in model answer and decoded answer it anyway will be less then entire history, so let's use data from m_kv_history_manager + if (m_kv_history_manager.does_kv_cache_need_to_update()) { + trusted_history_length = m_kv_history_manager.trusted_history_length; + } else { + m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length; + // if prev generation was finished because of max len was reached, kv cache is missed one last token, 
let's keep it + m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0; + } ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(), - {1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token}, - new_chat_tokens.input_ids.data() + last_same_hist_token); + {1, new_chat_tokens.input_ids.get_shape().at(1) - trusted_history_length}, + new_chat_tokens.input_ids.data() + trusted_history_length); ov::Tensor new_attention_mask(ov::element::i64, new_tensor.get_shape()); std::fill_n(new_attention_mask.data(), new_tensor.get_shape()[1], 1); encoded_input.input_ids = ov::Tensor(new_chat_tokens.input_ids.get_element_type(), - {1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token}); + {1, new_chat_tokens.input_ids.get_shape().at(1) - trusted_history_length}); new_tensor.copy_to(encoded_input.input_ids); encoded_input.attention_mask = new_attention_mask; - - m_selected_beam = std::nullopt; + m_last_disappeared_token = std::nullopt; } else { encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens); } m_templated_chat_history = new_templated_chat_history; + m_tokenized_chat_history.clear(); m_tokenized_chat_history.reserve(new_chat_tokens.input_ids.get_size()); std::copy_n(new_chat_tokens.input_ids.data(), new_chat_tokens.input_ids.get_size(), @@ -261,6 +268,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) std::copy(input_ids.data(), input_ids.data() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history)); + // Tail of previous output in chat mode is missing in KV cache. + if (m_last_disappeared_token.has_value()) { + attention_mask = ov::genai::utils::push_front_inputs(attention_mask, 1); + input_ids = ov::genai::utils::push_front_inputs(input_ids, *m_last_disappeared_token); + } + GenerationConfig config = (generation_config.has_value()) ? 
*generation_config : m_generation_config; // If eos_token_id was not provided, take value from default m_generation_config @@ -281,10 +294,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { } auto batch_size = input_ids.get_shape().at(0); - if ((batch_size != 1 || !(config.is_greedy_decoding() || config.is_multinomial())) && streamer_ptr) { - OPENVINO_THROW("Currently streaming is possible only with batch size=1 and " - "only for greedy or multinomial decoding"); - } + OPENVINO_ASSERT(streamer_ptr == nullptr || batch_size == 1 && config.num_return_sequences == 1 && + (config.is_greedy_decoding() || config.is_multinomial()), + "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); auto num_inputs = m_model_runner.get_compiled_model().inputs().size(); OPENVINO_ASSERT(num_inputs == 4 || num_inputs == 3, "Model should have 3 or 4 inputs: " @@ -292,7 +304,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { "(input_ids, attention_mask, position_ids, beam_idx) " "but you have '" + std::to_string(num_inputs) + "' inputs"); - ov::genai::utils::trim_kv_cache(m_model_runner, m_to_remove_from_hist, m_kv_cache_seq_length_axis, m_adapter_controller); + ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache, m_kv_cache_seq_length_axis, m_adapter_controller); size_t kv_cache_len = 0; ov::Tensor concatenated_attention_mask; @@ -302,10 +314,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { // Between subsequent runs attention_mask should not be modified. auto atten_mask_history = m_model_runner.get_tensor("attention_mask"); auto prompt_len = attention_mask.get_shape()[1]; - kv_cache_len = atten_mask_history.get_shape()[1] - m_to_remove_from_hist; + + kv_cache_len = atten_mask_history.get_shape()[1] - m_kv_history_manager.num_tokens_to_remove_from_kv_cache; ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}}; - auto start_atten_hst = atten_mask_history.data() + kv_cache_len * (*m_selected_beam); + auto start_atten_hst = atten_mask_history.data(); + std::copy(start_atten_hst, start_atten_hst + kv_cache_len, new_atten_mask.data()); std::copy(attention_mask.data(), attention_mask.data() + prompt_len, @@ -315,6 +329,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { concatenated_attention_mask = attention_mask; } + size_t prev_attn_mask_size = concatenated_attention_mask.get_shape()[1]; + bool position_ids_available = (num_inputs == 4); std::optional position_ids = std::nullopt; if (position_ids_available) { @@ -328,48 +344,55 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { if (is_chat_conversation && !m_trust_encoded_history) { m_trust_encoded_history = true; - m_to_remove_from_hist = 0; + m_kv_history_manager.reset(); } - ov::genai::EncodedResults result; - if (config.is_beam_search() && is_chat_conversation) { - std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask, - config, position_ids, m_selected_beam); - } else { - std::vector requests; - size_t block_size = 1; - bool enable_prefix_caching = false; - - for (size_t request_id = 0; request_id < batch_size; request_id++) { - SequenceGroup::Ptr sequence_group; - if (is_chat_conversation) { - ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()); - sequence_group = std::make_shared(request_id, 
tokenized_chat_history, config, block_size, enable_prefix_caching); - } else { - size_t seq_len = input_ids.get_shape().at(1); - size_t batch_offset = request_id * seq_len; - const int64_t* prompt_start = input_ids.data() + batch_offset; - std::vector tokenized_prompt(prompt_start, prompt_start + seq_len); + std::vector requests; + size_t block_size = 1; + bool enable_prefix_caching = false; - sequence_group = std::make_shared(request_id, tokenized_prompt, config, block_size, enable_prefix_caching); - } + for (size_t request_id = 0; request_id < batch_size; request_id++) { + SequenceGroup::Ptr sequence_group; + if (is_chat_conversation) { + ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()); + sequence_group = std::make_shared(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching); + } else { + size_t seq_len = input_ids.get_shape().at(1); + size_t batch_offset = request_id * seq_len; + const int64_t* prompt_start = input_ids.data() + batch_offset; + std::vector tokenized_prompt(prompt_start, prompt_start + seq_len); - sequence_group->set_sequence_group_ptr(sequence_group); - requests.push_back(sequence_group); + sequence_group = std::make_shared(request_id, tokenized_prompt, config, block_size, enable_prefix_caching); } - Sampler sampler = Sampler(m_tokenizer); - std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr, - sampler, requests, position_ids, std::nullopt, m_selected_beam); + sequence_group->set_sequence_group_ptr(sequence_group); + requests.push_back(sequence_group); } + if (m_sampler.get_seed() != config.rng_seed) { + m_sampler.set_seed(config.rng_seed); + } + + ov::genai::EncodedResults result; + std::tie(result, m_last_disappeared_token) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, + streamer_ptr, m_sampler, requests, position_ids, std::nullopt); + if (is_chat_conversation) { + // force remove from kv_cache last answer + if (config.is_beam_search() && m_chat_input_type != ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) { + m_kv_history_manager.trusted_history_length = m_tokenized_chat_history.size(); + m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size; + } + std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history)); } else { reset_kv_state(); - m_selected_beam = std::nullopt; + m_last_disappeared_token = std::nullopt; } + if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) + std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history)); + auto stop_time = std::chrono::steady_clock::now(); // If is called without tokenization then that stat will not be reported. 
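// ---------------------------------------------------------------------------
// Editor's note, not part of the patch: the llm_pipeline.cpp hunks above rely
// on ov::genai::utils::HistoryRemoveManager, whose definition lives in
// utils.hpp and is not shown in this diff. The sketch below is a hypothetical,
// self-contained approximation of that helper and of how the chat flow uses it
// after a beam-search answer; only the member and function names visible in
// the patch are taken from the source, everything else is assumed.
// ---------------------------------------------------------------------------
// #include <cstddef>
// #include <cassert>
//
// namespace sketch {
//
// // Tracks how many tokens must be trimmed from the KV cache before the next
// // generate() call and how long a prefix of the tokenized chat history is
// // still trustworthy.
// struct HistoryRemoveManager {
//     size_t num_tokens_to_remove_from_kv_cache = 0;
//     size_t trusted_history_length = 0;
//
//     bool does_kv_cache_need_to_update() const {
//         return num_tokens_to_remove_from_kv_cache > 0;
//     }
//     void reset() {
//         num_tokens_to_remove_from_kv_cache = 0;
//         trusted_history_length = 0;
//     }
// };
//
// // After a beam-search answer in chat mode the patch schedules removal of the
// // whole model answer from the KV cache; the best beam is kept only in the
// // text history and is re-encoded on the next turn.
// void schedule_removal_after_beam_search(HistoryRemoveManager& mgr,
//                                         size_t tokenized_history_len,
//                                         size_t attn_mask_len_after_generate,
//                                         size_t attn_mask_len_before_generate) {
//     mgr.trusted_history_length = tokenized_history_len;
//     mgr.num_tokens_to_remove_from_kv_cache =
//         attn_mask_len_after_generate - attn_mask_len_before_generate;
// }
//
// } // namespace sketch
//
// int main() {
//     sketch::HistoryRemoveManager mgr;
//     // e.g. 30 history tokens before generate(), 42 after -> 12 generated
//     sketch::schedule_removal_after_beam_search(mgr, /*tokenized_history_len=*/30,
//                                                /*after=*/42, /*before=*/30);
//     assert(mgr.does_kv_cache_need_to_update());
//     assert(mgr.num_tokens_to_remove_from_kv_cache == 12);
//     mgr.reset();
//     assert(!mgr.does_kv_cache_need_to_update());
//     return 0;
// }
// ---------------------------------------------------------------------------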
@@ -383,10 +406,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { void start_chat(const std::string& system_message) override { is_chat_conversation = true; - m_selected_beam = std::nullopt; m_trust_encoded_history = true; - m_to_remove_from_hist = 0; + m_kv_history_manager.reset(); m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; + m_last_disappeared_token = std::nullopt; if (!m_tokenized_chat_history.empty()) { reset_kv_state(); m_history = {}; @@ -404,10 +427,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { void finish_chat() override { is_chat_conversation = false; - m_selected_beam = std::nullopt; m_trust_encoded_history = true; - m_to_remove_from_hist = 0; + m_kv_history_manager.reset(); m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; + m_last_disappeared_token = std::nullopt; if (!m_tokenized_chat_history.empty()) { reset_kv_state(); m_history.clear(); @@ -581,9 +604,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase { std::vector plain_replies; std::vector plain_scores; for (GenerationResult& res : generated) { - if (GenerationStatus::FINISHED != res.m_status) { - OPENVINO_THROW("Got unfinished GenerationStatus"); - } + OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus"); std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_replies)); std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores)); } @@ -639,9 +660,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase { std::vector> plain_tokens; std::vector plain_scores; for (EncodedGenerationResult& res : generated) { - if (GenerationStatus::FINISHED != res.m_status) { - OPENVINO_THROW("Got unfinished GenerationStatus"); - } + OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus"); std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_tokens)); std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores)); } diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp index 090aed9650..6f4f124894 100644 --- a/src/cpp/src/llm_pipeline_static.cpp +++ b/src/cpp/src/llm_pipeline_static.cpp @@ -407,7 +407,8 @@ ov::genai::ModelConfigDesc get_modeldesc_from_json(const std::filesystem::path& if (config_data.contains("_name_or_path")) { desc.name_or_path = config_data["_name_or_path"].get(); } - desc.num_key_value_heads = config_data["num_key_value_heads"].get(); + desc.num_key_value_heads = config_data.contains("num_key_value_heads") + ? config_data["num_key_value_heads"].get() : -1; return desc; } @@ -1102,6 +1103,11 @@ EncodedResults StaticLLMPipeline::generate( m_kvcache_request.get_tensor(output_name).copy_to(kvcache_in_slice); } } + + if (streamer_ptr) { + streamer_ptr->end(); + } + auto stop_time = std::chrono::steady_clock::now(); // If is called without tokenization then that stat will not be reported. 
auto& metrics = results.perf_metrics; diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp index 3ab041fa58..17a20dd961 100644 --- a/src/cpp/src/lm_encoding.cpp +++ b/src/cpp/src/lm_encoding.cpp @@ -9,12 +9,11 @@ #include #include +#include "utils.hpp" +#include "debug_utils.hpp" #include "lm_encoding.hpp" #include "openvino/genai/perf_metrics.hpp" -#include "debug_utils.hpp" - -#include "utils.hpp" namespace ov { namespace genai { @@ -51,7 +50,7 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector get_lm_encoded_results( +std::pair> get_lm_encoded_results( ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, @@ -59,41 +58,56 @@ std::pair get_lm_encoded_results( Sampler& sampler, std::vector sequence_groups, std::optional position_ids, - std::optional m_embedding, - std::optional selected_beam_idx + std::optional m_embedding ) { std::vector generations; for (SequenceGroup::Ptr sequence_group : sequence_groups) { generations.push_back(std::make_shared(sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters())); } + auto active_sequence_groups{sequence_groups}; + + auto stream_generated_tokens = [&streamer_ptr, &generations, &active_sequence_groups]() { + GenerationHandle& handle = generations.at(0); + if (streamer_ptr && handle->can_read()) { + std::unordered_map token = handle->back(); + for (const auto& gen_token : token.begin()->second.generated_ids) { + if (streamer_ptr->put(gen_token)) { + handle->drop(); + break; + } + } + } + + // free non running requests + auto removed_it = std::remove_if(active_sequence_groups.begin(), active_sequence_groups.end(), + [](SequenceGroup::Ptr sg) -> bool { + return sg->has_finished() || sg->out_of_memory() || sg->handle_dropped(); + }); + active_sequence_groups.erase(removed_it, active_sequence_groups.end()); + }; + ov::Shape prompts_shape = input_ids.get_shape(); const size_t batch_size = prompts_shape[0]; // Initialize results and performance metrics. + EncodedResults results; auto& raw_perf_counters = results.perf_metrics.raw_metrics; raw_perf_counters.m_inference_durations = {{ MicroSeconds(0.0f) }}; // Initialize inputs - if (m_embedding.has_value()) - m_llm.set_tensor("inputs_embeds", input_ids); - else - m_llm.set_tensor("input_ids", input_ids); - + m_llm.set_tensor(m_embedding.has_value() ? 
"inputs_embeds" : "input_ids", input_ids); m_llm.set_tensor("attention_mask", attention_mask); - if (position_ids.has_value()) m_llm.set_tensor("position_ids", *position_ids); ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size}); - auto beam_data = beam_idx.data(); - if (selected_beam_idx.has_value()) - beam_data[0] = *selected_beam_idx; - else - std::fill_n(beam_data, batch_size, 0); + std::fill_n(beam_idx.data(), batch_size, 0); m_llm.set_tensor("beam_idx", beam_idx); + // "Prompt" phase + const auto infer_start = std::chrono::steady_clock::now(); m_llm.infer(); const auto infer_end = std::chrono::steady_clock::now(); @@ -109,7 +123,6 @@ std::pair get_lm_encoded_results( for (auto& sequence_group : sequence_groups) { sequence_group->update_processed_tokens_num(sequence_group->get_prompt_len() - sequence_len); sequence_group->schedule_tokens(sequence_len); - } std::map beam_offets; @@ -117,27 +130,11 @@ std::pair get_lm_encoded_results( beam_offets.insert({sequence_groups.at(i)->get_request_id(), i}); SamplerOutput sampler_output = sampler.sample(sequence_groups, logits); + stream_generated_tokens(); - auto active_sequence_groups{sequence_groups}; - auto get_active_sequence_groups = [](SequenceGroup::Ptr sg) { return sg->has_finished(); }; - - active_sequence_groups.erase(std::remove_if(active_sequence_groups.begin(), - active_sequence_groups.end(), - get_active_sequence_groups), - active_sequence_groups.end()); - - auto stream_generated_tokens = [&streamer_ptr, &generations]() { - if (streamer_ptr && generations.at(0).get()->can_read()) { - std::unordered_map token = generations.at(0).get()->back(); - for (const auto& gen_token : token.begin()->second.generated_ids) { - if (!streamer_ptr->put(gen_token)) { - break; - } - } - } - }; + // "Generation" phase - while (active_sequence_groups.size() > 0) { + while (!active_sequence_groups.empty()) { size_t total_num_tokens = 0; for (auto& sequence_group : active_sequence_groups) { @@ -172,26 +169,19 @@ std::pair get_lm_encoded_results( // apply strides to shift to a next sequence input_ids_data += num_scheduled_tokens; - // for different sequences iteration of beams started from 0, but we collect it to one input_ids# + // for different sequences iteration of beams started from 0, but we collect it to one input_ids next_beams.push_back(beam_idxs[sequence->get_id()] + beam_offets.at(sequence_group->get_request_id())); } } - for (size_t i = 0; i < sequence_groups.size(); i++) { - if (i == 0) - beam_offets[sequence_groups.at(i)->get_request_id()] = 0; - else { - beam_offets[sequence_groups.at(i)->get_request_id()] = sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i -1]; - } + for (size_t i = 0; i < active_sequence_groups.size(); i++) { + beam_offets[active_sequence_groups.at(i)->get_request_id()] = i == 0 ? 
0 : (active_sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i - 1]); } if (m_embedding.has_value()) { const ov::Tensor& embed_prompt_tensor = (*m_embedding).infer(new_input_ids); - - m_llm.get_tensor("inputs_embeds").set_shape(embed_prompt_tensor.get_shape()); m_llm.set_tensor("inputs_embeds", embed_prompt_tensor); } else { - m_llm.get_tensor("input_ids").set_shape(new_input_ids.get_shape()); m_llm.set_tensor("input_ids", new_input_ids); } @@ -201,7 +191,6 @@ std::pair get_lm_encoded_results( update_position_ids(m_llm.get_tensor("position_ids"), m_llm.get_tensor("attention_mask")); } - m_llm.get_tensor("beam_idx").set_shape({ total_num_tokens }); m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()}); const auto infer_start = std::chrono::steady_clock::now(); @@ -213,42 +202,38 @@ std::pair get_lm_encoded_results( raw_perf_counters.m_new_token_times.emplace_back(infer_end); raw_perf_counters.m_batch_sizes.emplace_back(batch_size); - stream_generated_tokens(); - sampler_output = sampler.sample(active_sequence_groups, m_llm.get_tensor("logits")); - - active_sequence_groups.erase(std::remove_if(active_sequence_groups.begin(), - active_sequence_groups.end(), - get_active_sequence_groups), - active_sequence_groups.end()); + stream_generated_tokens(); } - // to stream last token - stream_generated_tokens(); - if (streamer_ptr) { + if (streamer_ptr) { // push streamer's cache streamer_ptr->end(); } - - size_t next_selected_beam = 0; - for (size_t i = 0; i < sequence_groups.size(); i++) { - auto request = sequence_groups[i]; - auto generation_outputs = generations[i]->read_all(); - - std::sort(generation_outputs.begin(), generation_outputs.end(), [] (const GenerationOutput& r1, const GenerationOutput& r2) { - return r1.score > r2.score; - }); - - auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, generation_outputs.size()); - for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) { - const auto& generation_output = generation_outputs[generation_output_idx]; - results.tokens.push_back(std::move(generation_output.generated_ids)); - results.scores.push_back(generation_output.score); + + for (auto& sequence_group : sequence_groups) { + auto sampling_params = sequence_group->get_sampling_parameters(); + const auto& sequences = sequence_group->get_finished_sequences(); + size_t num_outputs = std::min(sequence_group->get_sampling_parameters().num_return_sequences, sequences.size()); + + for (size_t seq_id = 0; seq_id < num_outputs; ++seq_id) { + const auto & sequence = sequences[seq_id]; + const float score = sampling_params.is_beam_search() ? 
sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_probs(); + + results.tokens.push_back(sequence->get_generated_ids()); + results.scores.push_back(score); } - // next_selected_beam = sampler.last_selected_beam(request); } - return {results, next_selected_beam}; + for (SequenceGroup::Ptr sequence_group : sequence_groups) + sampler.clear_request_info(sequence_group->get_request_id()); + + // it is not saved in KV cache, we need to add it for some cases + std::optional last_token_of_best_sequence = std::nullopt; + if (sequence_groups[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH || sequence_groups[0]->handle_dropped()) + last_token_of_best_sequence = results.tokens[0].back(); + + return {results, last_token_of_best_sequence}; } } // namespace genai -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/cpp/src/lm_encoding.hpp b/src/cpp/src/lm_encoding.hpp index fa6692ede0..c31cffb9bc 100644 --- a/src/cpp/src/lm_encoding.hpp +++ b/src/cpp/src/lm_encoding.hpp @@ -8,13 +8,9 @@ namespace ov { namespace genai { -std::pair get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, - const std::shared_ptr& streamer_ptr, Sampler& sampler, std::vector sequence_groups, - std::optional position_ids, std::optional m_embedding, std::optional selected_beam_idx); - -void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector next_beams); - -void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask); +std::pair> get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, + const std::shared_ptr& streamer_ptr, Sampler& sampler, std::vector sequence_groups, + std::optional position_ids, std::optional m_embedding); } } diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp index f77463d767..9c18dc7721 100644 --- a/src/cpp/src/sampler.cpp +++ b/src/cpp/src/sampler.cpp @@ -85,75 +85,63 @@ std::string clean_wrapped_text(const std::string& wrapped_text, const std::strin return clean_text; } +std::vector encode_and_process_string(const std::string& stop_string, ov::genai::Tokenizer& tokenizer) { + // encode stop_string + std::string stop_string_copy = stop_string; + ov::Tensor ov_encoded_stop_string = tokenizer.encode(stop_string_copy, ov::genai::add_special_tokens(false)).input_ids; + size_t tensor_size = ov_encoded_stop_string.get_size(); + std::vector encoded_stop_string(tensor_size); + std::copy_n(ov_encoded_stop_string.data(), tensor_size, encoded_stop_string.begin()); + return encoded_stop_string; +} + +struct MatchStopStringResult { + size_t to_remove = 0; + // int64_t last_token_id = 0; + // bool is_to_update_last_token = false; + bool is_matched = false; +}; + // Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned. -int match_stop_string(Tokenizer & tokenizer, const TokenIds & generated_tokens, const std::set & stop_strings) { - /* - For catching stop_string hit we run comparisons character-wise to catch cases where stop string - overlaps with part of another token on both sides or is just a part of a single token. - For every stop_string we iterate over generated tokens starting from the last one and going backwards. - Every token is wrapped with prefix tokens to ensure tokenizer doesn't remove prefix whitespace of the actual token. 
- After that all tokens are decoded and prefix is removed from the decoded text, so we end up with decoded token. - Its characters are compared to the stop_string character at a current_position - (position of a character in the stop_string counting from the last one) - at the beginning position is 0. - When characters match we increase current_position and check if we have a full match already, if not we continue. - If we have already matched some characters (current_position > 0) and next character is not matching - before we reach the full match, then we reset current_position to 0. - */ - std::string prefix = "a"; - auto prefix_ov = tokenizer.encode(prefix).input_ids; - std::vector prefix_tokens(prefix_ov.data(), prefix_ov.data() + prefix_ov.get_size()); - std::string suffix = "b"; - auto suffix_ov = tokenizer.encode(suffix).input_ids; - std::vector suffix_tokens(suffix_ov.data(), suffix_ov.data() + suffix_ov.get_size()); - - // Since whitespace can be added at the beginning of the suffix we also try to capture that behavior here - // and get suffix string that will actually be part of the decoded string so we can remove it correctly - auto wrapped_suffix_tokens = suffix_tokens; - wrapped_suffix_tokens.insert(wrapped_suffix_tokens.begin(), prefix_tokens.begin(), prefix_tokens.end()); - std::string wrapped_suffix = tokenizer.decode(wrapped_suffix_tokens); - auto wrapper_pos = wrapped_suffix.find(prefix); - suffix = wrapped_suffix.substr(wrapper_pos + prefix.size()); - - for (auto stop_string: stop_strings) { - int current_position = 0; - int num_matched_tokens = 0; - // Getting reverse iterator to check tokens starting from the last one generated and going backwards - auto generated_tokens_rit = generated_tokens.rbegin(); - std::vector tokens_buffer; - while (generated_tokens_rit != generated_tokens.rend()) { - num_matched_tokens++; - tokens_buffer.insert(tokens_buffer.begin(), *generated_tokens_rit); - - std::vector wrapped_tokens = wrap_tokens(tokens_buffer, prefix_tokens, suffix_tokens); - std::string wrapped_text = tokenizer.decode(wrapped_tokens); - std::string clean_text = clean_wrapped_text(wrapped_text, prefix, suffix); - - if (clean_text == "" || (clean_text.size() >= 3 && (clean_text.compare(clean_text.size() - 3, 3, "�") == 0))) { - generated_tokens_rit++; - continue; - } else { - tokens_buffer.clear(); - } - // Checking clean_text characters starting from the last one - for (auto clean_text_rit = clean_text.rbegin(); clean_text_rit != clean_text.rend(); clean_text_rit++) { - // On character match increment current_position for the next comparisons - if (*clean_text_rit == *(stop_string.rbegin() + current_position)) { - current_position++; - // If this is the last character from the stop_string we have a match - if ((stop_string.rbegin() + current_position) == stop_string.rend()) { - return num_matched_tokens; - } - } else if (current_position) { - // Already found matching characters, but the last one didn't match, so we reset current_position - current_position = 0; - // Looking for the match will start over from this character so we decrement iterator - clean_text_rit--; +MatchStopStringResult match_stop_string(Tokenizer& tokenizer, + const TokenIds& generated_tokens, + const std::pair>& stop_strings, + bool is_include_to_output) { + MatchStopStringResult result; + if (generated_tokens.size() >= stop_strings.first) { + size_t offset = generated_tokens.size() - stop_strings.first; + TokenIds buffer(generated_tokens.begin() + offset, generated_tokens.end()); + std::string 
decoded_buffer = tokenizer.decode(buffer); + for (const auto& stop_string : stop_strings.second) { + auto pos = decoded_buffer.find(stop_string); + if (pos != std::string::npos) { + result.is_matched = true; + + auto stop_string_len = is_include_to_output ? stop_string.length() : 0; + decoded_buffer = decoded_buffer.substr(0, pos + stop_string_len); + // to remove word splitting symbols from tail + while (decoded_buffer.back() == ' ' || decoded_buffer.back() == '\n') { + decoded_buffer.pop_back(); + } + if (decoded_buffer.empty()) { + result.to_remove = buffer.size(); + return result; } + + // find token cnt to be removed from sequence by decoding token by token + std::string decoded_partially_string = ""; + for (size_t i = 0; i < buffer.size(); ++i) { + decoded_partially_string += tokenizer.decode(TokenIds{buffer[i]}); + if (decoded_partially_string.find(decoded_buffer) != std::string::npos) { + result.to_remove = buffer.size() - i - 1; + break; + } + } + return result; } - generated_tokens_rit++; } } - return 0; + return result; } // Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned. @@ -245,7 +233,9 @@ std::map Sampler::GroupBeamSearcher::get_beam_idxs() { return next_beams; } -void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output) { +void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, + SamplerOutput& sampler_output, + const std::pair>& stop_strings) { assert(m_parameters.num_beams % m_parameters.num_beam_groups == 0 && "number of beams should be divisible by number of groups"); size_t group_size = m_parameters.num_beams / m_parameters.num_beam_groups; @@ -392,19 +382,17 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, Sa // There's probably a better way to do that, than copying whole vector... 
std::vector token_ids = candidate.m_sequence->get_generated_ids(); token_ids.push_back(candidate.m_token_id); - int num_last_matched_tokens = match_stop_string(m_tokenizer, token_ids, m_sequence_group->get_sampling_parameters().stop_strings); - if (num_last_matched_tokens) { + auto match_result = match_stop_string(m_tokenizer, token_ids, stop_strings, m_parameters.include_stop_str_in_output); + if (match_result.is_matched) { // If beam_token does not belong to top num_beams tokens, it should not be added if (cand_idx >= group_size) continue; - if(!m_parameters.include_stop_str_in_output) { - // remove tokens that match stop_string from output (last token is not included in candidate.m_sequence at this point) - candidate.m_sequence->remove_last_tokens(num_last_matched_tokens - 1); - } + // remove tokens that match stop_string from output (last token is not included in candidate.m_sequence at this point) + candidate.m_sequence->remove_last_tokens(match_result.to_remove); // try to finish candidate - try_to_finish_candidate(group, candidate, m_parameters.include_stop_str_in_output); + try_to_finish_candidate(group, candidate); continue; } } @@ -576,10 +564,11 @@ std::vector Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen } if (!sampling_params.stop_strings.empty()) { - int num_matched_last_tokens = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), sampling_params.stop_strings); - if (num_matched_last_tokens) { - if (!sampling_params.include_stop_str_in_output) - running_sequence->remove_last_tokens(num_matched_last_tokens); + auto& stop_strings = m_stop_strings.at(sequence_group->get_request_id()); + auto match_result = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), stop_strings, sampling_params.include_stop_str_in_output); + if (match_result.is_matched) { + running_sequence->remove_last_tokens(match_result.to_remove); + running_sequence->set_status(SequenceStatus::FINISHED); running_sequence->set_finish_reason(GenerationFinishReason::STOP); dropped_seq_ids.push_back(running_sequence->get_id()); @@ -741,6 +730,19 @@ float get_p_prime(Sequence::Ptr& running_sequence, return p_prime; } +std::pair> +process_stop_strings(const std::set& stop_strings, Tokenizer& tokenizer) { + std::pair> result; + for (const auto& stop_string : stop_strings) { + auto encoded_stop_string = encode_and_process_string(stop_string, tokenizer); + if (result.first < encoded_stop_string.size()) { + result.first = encoded_stop_string.size(); + } + result.second.insert(stop_string); + } + return result; +} + SamplerOutput Sampler::sample(std::vector & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled) { @@ -764,6 +766,12 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups, if (!m_logit_processors.count(request_id)) { m_logit_processors.insert({request_id, LogitProcessor(sampling_params, sequence_group->get_prompt_ids())}); } + if (!m_stop_strings.count(request_id)) { + auto processed_stop_string = process_stop_strings(sampling_params.stop_strings, m_tokenizer); + m_stop_strings.insert({request_id, processed_stop_string}); + sequence_group->set_stream_window_size(processed_stop_string.first); + } + auto& stop_strings = m_stop_strings.at(request_id); auto& logit_processor = m_logit_processors.at(request_id); const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens; ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void 
*)sequence_group_logits_data); @@ -873,7 +881,7 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups, } // current algorithm already adds new tokens to running sequences and - m_beam_search_info.at(request_id).select_next_tokens(sequence_group_logits, sampler_output); + m_beam_search_info.at(request_id).select_next_tokens(sequence_group_logits, sampler_output, stop_strings); // check max length stop criteria std::vector running_sequences = sequence_group->get_running_sequences(); @@ -886,8 +894,7 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups, // Notify handle after sampling is done. // For non-streaming this is effective only when the generation is finished. OPENVINO_ASSERT(num_tokens_to_process >= max_removed_tokens_per_request); - size_t num_output_token_to_push = num_tokens_to_process - max_removed_tokens_per_request + 1; - sequence_group->notify_handle(num_output_token_to_push); + sequence_group->notify_handle(); } else { // we are in prompt processing phase when prompt is split into chunks and processed step by step } @@ -926,6 +933,7 @@ void Sampler::create_logit_processor(uint64_t request_id, const GenerationConfig void Sampler::clear_request_info(uint64_t request_id) { m_beam_search_info.erase(request_id); m_logit_processors.erase(request_id); + m_stop_strings.erase(request_id); } int64_t Sampler::GroupBeamSearcher::Group::finish(Beam beam, const ov::genai::GenerationConfig& sampling_params) { diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index 0f7876cbf9..981e11560f 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -55,8 +55,11 @@ class Sampler { std::map m_beam_search_info; std::mt19937 rng_engine; + size_t seed = rng_engine.default_seed; // { request_id, logit_processor } std::map m_logit_processors; + // { request_id, { max_encoded_len, { stop_strings }}} + std::map>> m_stop_strings; Tokenizer m_tokenizer; @@ -65,7 +68,11 @@ class Sampler { Sampler(Tokenizer & tokenizer) : m_tokenizer(tokenizer) {}; SamplerOutput sample(std::vector & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false); - void set_seed(size_t seed) { rng_engine.seed(seed); } + void set_seed(size_t new_seed) { + rng_engine.seed(new_seed); + seed = new_seed; + } + size_t get_seed() { return seed; } void clear_request_info(uint64_t request_id); @@ -115,7 +122,7 @@ class Sampler::GroupBeamSearcher { public: explicit GroupBeamSearcher(SequenceGroup::Ptr sequence_group, Tokenizer tokenizer); - void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output); + void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output, const std::pair>& stop_strings); void finalize(SamplerOutput& sampler_output); std::map get_beam_idxs(); }; diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index 6755255fe8..220e93c032 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -126,23 +126,28 @@ class Sequence { } } - GenerationOutput get_last_generation_output(size_t token_cnt = 1) { + GenerationOutput get_last_generation_output(size_t token_cnt = 1, size_t num_token_to_ignore = 0) { GenerationOutput output; - OPENVINO_ASSERT(m_generated_ids.size()); - output.score = get_cumulative_log_probs(); + if (token_cnt > 0) { + OPENVINO_ASSERT(m_generated_ids.size()); + output.score = get_cumulative_log_probs(); - auto generated_token_id = get_generated_ids(); - auto generated_log_probs = get_generated_log_probs(); + auto generated_token_id = get_generated_ids(); + 
auto generated_log_probs = get_generated_log_probs(); - OPENVINO_ASSERT(get_generated_len() >= token_cnt); - auto offset = get_generated_len() - token_cnt; + OPENVINO_ASSERT(get_generated_len() >= token_cnt); + if (get_generated_len() > num_token_to_ignore) { + auto offset = get_generated_len() - token_cnt - num_token_to_ignore; + auto offset_back = get_generated_len() - num_token_to_ignore; - std::vector token_id(generated_token_id.begin() + offset, generated_token_id.end()); - std::vector log_probs(generated_log_probs.begin() + offset, generated_log_probs.end()); + std::vector token_id(generated_token_id.begin() + offset, generated_token_id.begin() + offset_back); + std::vector log_probs(generated_log_probs.begin() + offset, generated_log_probs.begin() + offset_back); - output.generated_ids = token_id; - output.generated_log_probs = log_probs; - output.finish_reason = get_finish_reason(); + output.generated_ids = token_id; + output.generated_log_probs = log_probs; + output.finish_reason = get_finish_reason(); + } + } return output; } @@ -173,8 +178,6 @@ class Sequence { return score; } - - // Each KV block can be uniquely identified by void set_sequence_group_ptr(std::shared_ptr sequence_group) { m_sequence_group = sequence_group; @@ -221,6 +224,8 @@ class SequenceGroup { // flag to enable/disable token generation, e.g. in speculative decoding scenario bool m_is_gen_paused = false; + size_t m_num_streamed_tokens = 0, m_stream_window_size = 0; + SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching) : m_request_id(request_id), @@ -332,14 +337,16 @@ class SequenceGroup { std::vector get_finished_sequences() const { std::vector finished_seqs; for (size_t seq_id = 0; seq_id < m_sequences.size(); ++seq_id) { - if (m_sequences[seq_id]->has_finished() || m_sequences[seq_id]->out_of_memory()) { + if (m_sequences[seq_id]->has_finished() || m_sequences[seq_id]->out_of_memory() || handle_dropped()) { finished_seqs.push_back(m_sequences[seq_id]); } } - // do we need to sort sequences here or sampler can handle it for us? - std::sort(finished_seqs.begin(), finished_seqs.end(), [=] (Sequence::CPtr s1, Sequence::CPtr s2) { - return s1->get_beam_search_score(m_sampling_params) > s2->get_beam_search_score(m_sampling_params); + std::sort(finished_seqs.begin(), finished_seqs.end(), [=] (Sequence::CPtr s1, Sequence::CPtr s2) -> bool { + bool is_beam_search = m_sampling_params.is_beam_search(); + const float score_1 = is_beam_search ? s1->get_beam_search_score(m_sampling_params) : s1->get_cumulative_log_probs(); + const float score_2 = is_beam_search ? 
s2->get_beam_search_score(m_sampling_params) : s2->get_cumulative_log_probs(); + return score_1 > score_2; }); return finished_seqs; @@ -454,6 +461,10 @@ class SequenceGroup { size_t get_num_tokens_to_validate() { return m_num_validation_tokens; } + + void set_stream_window_size(size_t k) { + m_stream_window_size = k; + } size_t get_num_available_tokens_for_batching() const { OPENVINO_ASSERT(!has_finished(), "Internal error: this function cannot be called on finished sequence group"); @@ -571,7 +582,7 @@ class SequenceGroup { m_generation_stream->set_generation_status(status); } - bool handle_dropped() { + bool handle_dropped() const { return m_generation_stream->get_status() == GenerationStatus::DROPPED_BY_HANDLE; } @@ -601,7 +612,7 @@ class SequenceGroup { for (auto& sequence : m_sequences) { // todo: check seq.is_finished() to generate without several // or is it ok to use padding? - auto output = sequence->get_last_generation_output(token_cnt); + auto output = sequence->get_last_generation_output(token_cnt, m_stream_window_size); if (m_sampling_params.echo && !m_has_echoed) { output.generated_ids.insert(output.generated_ids.begin(), m_prompt_ids.begin(), m_prompt_ids.end()); output.generated_log_probs.insert(output.generated_log_probs.begin(), m_prompt_log_probs.begin(), m_prompt_log_probs.end()); @@ -612,24 +623,36 @@ class SequenceGroup { m_generation_stream->push(std::move(outputs)); } - void notify_handle(size_t num_output_token_to_push = 0) { + void notify_handle() { if (out_of_memory()) { set_generation_status(GenerationStatus::IGNORED); } else if (has_finished()) { set_generation_status(GenerationStatus::FINISHED); } // For beam search streaming is not available, so we notify only upon finishing - if(m_sampling_params.is_beam_search()) { + if (m_sampling_params.is_beam_search()) { if (has_finished() || out_of_memory()) { push_outputs(); } } else if (m_sampling_params.is_greedy_decoding() || m_sampling_params.is_multinomial()) { // We can stream only when one sequence is returned and we don't use stop strings that would be excluded from the output // (after stop string is detected its tokens are already sent) - if (num_total_seqs() == 1 && - (m_sampling_params.stop_strings.empty() || m_sampling_params.include_stop_str_in_output)) { - if (num_output_token_to_push) - push_partial_outputs(num_output_token_to_push); + if (num_total_seqs() == 1) { + const auto generated_len = m_sequences.front()->get_generated_len(); + if (has_finished()) { + m_stream_window_size = 0; + } + if (generated_len <= (m_num_streamed_tokens + m_stream_window_size)) { + return; + } + // speculative decoding draft handling + if (generated_len < m_num_streamed_tokens) { + m_num_streamed_tokens = generated_len; + } + OPENVINO_ASSERT(generated_len >= (m_num_streamed_tokens + m_stream_window_size)); + size_t num_output_token_to_push = generated_len - m_num_streamed_tokens - m_stream_window_size; + push_partial_outputs(num_output_token_to_push); + m_num_streamed_tokens += (num_output_token_to_push); } else if (has_finished() || out_of_memory()) { push_outputs(); } diff --git a/src/cpp/src/text_callback_streamer.cpp b/src/cpp/src/text_callback_streamer.cpp index 314a7ffa4d..5938b55f6c 100644 --- a/src/cpp/src/text_callback_streamer.cpp +++ b/src/cpp/src/text_callback_streamer.cpp @@ -52,4 +52,4 @@ void TextCallbackStreamer::end() { ov::genai::StreamerBase::~StreamerBase() = default; } // namespace genai -} // namespace ov +} // namespace ov \ No newline at end of file diff --git 
a/src/cpp/src/text_callback_streamer.hpp b/src/cpp/src/text_callback_streamer.hpp index a03b0deccb..6f0872ad1b 100644 --- a/src/cpp/src/text_callback_streamer.hpp +++ b/src/cpp/src/text_callback_streamer.hpp @@ -25,4 +25,4 @@ class TextCallbackStreamer: public StreamerBase { }; } // namespace genai -} // namespace ov +} // namespace ov \ No newline at end of file diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index 96191387cd..57225e60ff 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -28,6 +28,21 @@ enum class GenerationChatInputsType { ENCODED_INPUTS = 2, // Type of inputs is EncodedInputs }; +struct HistoryRemoveManager +{ + size_t num_tokens_to_remove_from_kv_cache = 0; + size_t trusted_history_length = 0; + + bool does_kv_cache_need_to_update() { + return (trusted_history_length > 0 || num_tokens_to_remove_from_kv_cache > 0); + } + + void reset() { + num_tokens_to_remove_from_kv_cache = 0; + trusted_history_length = 0; + } +}; + Tensor init_attention_mask(const Tensor& position_ids); void print_tensor(const ov::Tensor& tensor); diff --git a/src/cpp/src/visual_language/inputs_embedder.cpp b/src/cpp/src/visual_language/inputs_embedder.cpp index 8175d44b16..e53be4e1cd 100644 --- a/src/cpp/src/visual_language/inputs_embedder.cpp +++ b/src/cpp/src/visual_language/inputs_embedder.cpp @@ -42,11 +42,12 @@ class InputsEmbedder::IInputsEmbedder { std::string m_templated_chat_history; // Tokenized chat history std::vector m_tokenized_history; - // The number of elements, which need to remove from the end of KV cache - // removed elements will be added to inputs_ids - size_t m_to_remove_from_hist = 0; // Tail of previous output for LM in chat mode is missing in KV cache. std::optional m_last_disappeared_token = std::nullopt; + // If sequence contains some symbols, which could be ambiguous encoded by tokenizer, we need to trim kv cache + // If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history + // so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history + ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0}; public: virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector& images, ov::genai::VLMPerfMetrics& metrics) = 0; @@ -63,22 +64,26 @@ class InputsEmbedder::IInputsEmbedder { return m_tokenized_history; } - size_t get_amount_to_remove_from_hist() const { - return m_to_remove_from_hist; + size_t get_num_tokens_to_remove_from_hist() const { + return m_kv_history_manager.num_tokens_to_remove_from_kv_cache; } - void update_tokenized_history(std::vector encoded_result, bool token_will_disappear) { + void update_tokenized_history(const std::vector& encoded_result, std::optional last_disappeared_token, bool is_beam_search, size_t last_answer_len) { + if (is_beam_search) { + m_kv_history_manager.trusted_history_length = m_tokenized_history.size(); + m_kv_history_manager.num_tokens_to_remove_from_kv_cache = last_answer_len; + } else { + m_kv_history_manager.reset(); + } + + m_last_disappeared_token = last_disappeared_token; + std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history)); - m_to_remove_from_hist = 0; - if (token_will_disappear) - m_last_disappeared_token = encoded_result.back(); - else - m_last_disappeared_token = std::nullopt; } virtual void start_chat(const std::string& system_message) { m_is_chat_conversation = true; - m_to_remove_from_hist = 0; + 
m_kv_history_manager.reset(); if (!m_tokenized_history.empty()) { m_history.clear(); m_templated_chat_history.clear(); @@ -101,7 +106,7 @@ class InputsEmbedder::IInputsEmbedder { virtual void finish_chat() { m_is_chat_conversation = false; - m_to_remove_from_hist = 0; + m_kv_history_manager.reset(); m_history.clear(); m_templated_chat_history.clear(); @@ -171,24 +176,32 @@ class InputsEmbedder::IInputsEmbedder { // some symbols combinations can be encoded by the tokenizer in different ways // if we met sequence with such combination of symbols, we cannot correctly subtract the new history from the old history // so let's check it out, find the trusted part and use it in on the next step - size_t last_same_hist_token = 0; + size_t trusted_history_length = 0; if (!m_tokenized_history.empty()) { std::set stop_tokens = {m_tokenizer.get_eos_token_id()}; - last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_history, stop_tokens); + trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_history, stop_tokens); } if (m_tokenized_history.empty()) { encoded_input_ids = new_chat_tokens; - } else if (last_same_hist_token != SIZE_MAX) { - m_to_remove_from_hist = m_tokenized_history.size() - last_same_hist_token; - // if prev generation was finished because of max len was reached, kv cache is missed one last token, let's keep it - m_to_remove_from_hist -= m_last_disappeared_token.has_value() ? 1 : 0; + + } else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) { + // does_kv_cache_need_to_update will be true here if beam search is activated + // in beam search mode we want to remove the last model answer from the kv cache entirely and add the best answer directly + // if the model answer and the decoded answer differ, the difference is still smaller than the entire history, so let's use the data from m_kv_history_manager + if (m_kv_history_manager.does_kv_cache_need_to_update()) { + trusted_history_length = m_kv_history_manager.trusted_history_length; + } else { + m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_history.size() - trusted_history_length; + // if the previous generation finished because max length was reached, the kv cache is missing the last token, so let's keep it + m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ?
1 : 0; + } ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.get_element_type(), - {1, new_chat_tokens.get_shape().at(1) - last_same_hist_token}, - new_chat_tokens.data() + last_same_hist_token); + {1, new_chat_tokens.get_shape().at(1) - trusted_history_length}, + new_chat_tokens.data() + trusted_history_length); encoded_input_ids = ov::Tensor(new_chat_tokens.get_element_type(), - {1, new_chat_tokens.get_shape().at(1) - last_same_hist_token}); + {1, new_chat_tokens.get_shape().at(1) - trusted_history_length}); new_tensor.copy_to(encoded_input_ids); } else { encoded_input_ids = utils::subtract_chat_tokenized_inputs( @@ -1192,12 +1205,12 @@ std::vector InputsEmbedder::get_tokenized_history() const { return m_impl->get_tokenized_history(); } -void InputsEmbedder::update_tokenized_history(std::vector encoded_result, bool token_will_disappear) { - return m_impl->update_tokenized_history(encoded_result, token_will_disappear); +void InputsEmbedder::update_tokenized_history(const std::vector& encoded_result, std::optional last_disappeared_token, bool is_beam_search, size_t last_answer_len) { + return m_impl->update_tokenized_history(encoded_result, last_disappeared_token, is_beam_search, last_answer_len); } -size_t InputsEmbedder::get_amount_to_remove_from_hist() const { - return m_impl->get_amount_to_remove_from_hist(); +size_t InputsEmbedder::get_num_tokens_to_remove_from_hist() const { + return m_impl->get_num_tokens_to_remove_from_hist(); } Tokenizer InputsEmbedder::get_tokenizer() const { diff --git a/src/cpp/src/visual_language/inputs_embedder.hpp b/src/cpp/src/visual_language/inputs_embedder.hpp index 8c84c6ad43..1d72b742ab 100644 --- a/src/cpp/src/visual_language/inputs_embedder.hpp +++ b/src/cpp/src/visual_language/inputs_embedder.hpp @@ -43,11 +43,11 @@ class InputsEmbedder { // returns tokenized chat history std::vector get_tokenized_history() const; - // add new results to tokenized chat history - void update_tokenized_history(std::vector encoded_result, bool token_will_disappear); + // add new results to tokenized history + void update_tokenized_history(const std::vector& encoded_result, std::optional last_disappeared_token, bool is_beam_search, size_t last_answer_len); // returns amount of elements, which need to remove from the end of the KV cache - size_t get_amount_to_remove_from_hist() const; + size_t get_num_tokens_to_remove_from_hist() const; // starts chat and adds optional system_message to chat history void start_chat(const std::string& system_message); diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp index 0d7aebc506..d625485205 100644 --- a/src/cpp/src/visual_language/pipeline.cpp +++ b/src/cpp/src/visual_language/pipeline.cpp @@ -67,6 +67,8 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { float m_load_time_ms = 0; // Axis num in kv cache from m_language model, which contains information about history len size_t m_kv_cache_seq_length_axis = 2; + // Component for applying sampling to lm outputs + Sampler m_sampler; VLMPipelineImpl( const std::filesystem::path& models_dir, @@ -105,6 +107,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { if (m_generation_config.eos_token_id == -1) { m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id()); } + + m_sampler = Sampler(m_tokenizer); + m_sampler.set_seed(m_generation_config.rng_seed); } VLMPipelineImpl( @@ -140,6 +145,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { if (m_generation_config.eos_token_id == -1) { 
m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id()); } + + m_sampler = Sampler(m_tokenizer); + m_sampler.set_seed(m_generation_config.rng_seed); } VLMDecodedResults generate( @@ -161,7 +169,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { ov::Tensor inputs_embeds = m_inputs_embedder->get_inputs_embeds(prompt, rgbs, perf_metrics); auto end_get_inputs_embeds = std::chrono::steady_clock::now(); - auto to_remove_from_hist = m_inputs_embedder->get_amount_to_remove_from_hist(); + auto to_remove_from_hist = m_inputs_embedder->get_num_tokens_to_remove_from_hist(); ov::genai::utils::trim_kv_cache(m_language, to_remove_from_hist, m_kv_cache_seq_length_axis, std::nullopt); std::vector requests; @@ -195,8 +203,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { }, }, streamer); - OPENVINO_ASSERT((generation_config.is_greedy_decoding() || generation_config.is_multinomial() || !streamer_ptr), - "Currently streaming is possible only for greedy or multinomial decoding"); + OPENVINO_ASSERT(streamer_ptr == nullptr || generation_config.num_return_sequences == 1 && + (generation_config.is_greedy_decoding() || generation_config.is_multinomial()), + "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, { 1, history_size + inputs_embeds_size }}; std::fill_n(new_atten_mask.data(), new_atten_mask.get_size(), 1); @@ -204,12 +213,14 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { ov::Tensor position_ids = ov::Tensor{ov::element::i64, { 1, inputs_embeds_size }}; std::iota(position_ids.data(), position_ids.data() + position_ids.get_size(), history_size); - Sampler sampler = Sampler(m_tokenizer); + if (m_sampler.get_seed() != generation_config.rng_seed) { + m_sampler.set_seed(generation_config.rng_seed); + } ov::genai::EncodedResults encoded_result; - int32_t m_selected_beam = 0; - std::tie(encoded_result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, sampler, requests, - position_ids, m_embedding, std::nullopt); + std::optional last_disappeared_token; + std::tie(encoded_result, last_disappeared_token) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests, + position_ids, m_embedding); auto decode_start_time = std::chrono::steady_clock::now(); VLMDecodedResults decoded; @@ -219,6 +230,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { } auto decode_end_time = std::chrono::steady_clock::now(); + m_inputs_embedder->update_tokenized_history(encoded_result.tokens[0], last_disappeared_token, generation_config.is_beam_search(), + m_language.get_tensor("attention_mask").get_shape()[1] - (history_size + inputs_embeds_size)); + std::string decoded_results = decoded.texts.at(0); if (m_is_chat_conversation) { m_inputs_embedder->update_chat_history(decoded_results); @@ -245,8 +259,6 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { decoded.perf_metrics.m_evaluated = false; decoded.perf_metrics.evaluate_statistics(generate_start_time); - m_inputs_embedder->update_tokenized_history(encoded_result.tokens[0], requests[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH); - return decoded; } diff --git a/tests/python_tests/common.py b/tests/python_tests/common.py index 50ee452f5c..163a00192e 100644 --- a/tests/python_tests/common.py +++ b/tests/python_tests/common.py @@ -125,6 +125,34 @@ def get_beam_search_with_multiple_stop_strings_no_match() -> 
GenerationConfig: generation_config.include_stop_str_in_output = True return generation_config +def get_greedy_stop_strings_exclude_from_output() -> GenerationConfig: + generation_config = GenerationConfig() + generation_config.max_new_tokens = 30 + generation_config.stop_strings = { "machines" } + generation_config.include_stop_str_in_output = False + return generation_config + +def get_greedy_stop_strings_include_to_output() -> GenerationConfig: + generation_config = GenerationConfig() + generation_config.max_new_tokens = 30 + generation_config.stop_strings = { "machines" } + generation_config.include_stop_str_in_output = True + return generation_config + +def get_greedy_n_stop_strings_exclude_from_output() -> GenerationConfig: + generation_config = GenerationConfig() + generation_config.max_new_tokens = 30 + generation_config.stop_strings = { "machines", "manage" } + generation_config.include_stop_str_in_output = False + return generation_config + +def get_greedy_n_stop_strings_include_to_output() -> GenerationConfig: + generation_config = GenerationConfig() + generation_config.max_new_tokens = 30 + generation_config.stop_strings = { "machines", "manage" } + generation_config.include_stop_str_in_output = True + return generation_config + def get_multinomial_temperature() -> GenerationConfig: generation_config = GenerationConfig() generation_config.do_sample = True @@ -359,9 +387,14 @@ def compare_results(hf_result: GenerationResult, ov_result: GenerationResult, ge # Note, that for fp32 / fp16 models scores are different less than 0.001 assert abs(hf_score - ov_score) < 0.02 - assert len(hf_result.m_generation_ids) == len(ov_result.m_generation_ids) - for hf_text, ov_text in zip(hf_result.m_generation_ids, ov_result.m_generation_ids): - assert hf_text == ov_text + if not generation_config.include_stop_str_in_output and len(generation_config.stop_strings) > 0: + assert len(hf_result.m_generation_ids) >= len(ov_result.m_generation_ids) + for hf_text, ov_text in zip(hf_result.m_generation_ids, ov_result.m_generation_ids): + assert ov_text in hf_text + else: + assert len(hf_result.m_generation_ids) == len(ov_result.m_generation_ids) + for hf_text, ov_text in zip(hf_result.m_generation_ids, ov_result.m_generation_ids): + assert hf_text == ov_text def save_ov_model_from_optimum(model, hf_tokenizer, models_path: Path): model.save_pretrained(models_path) diff --git a/tests/python_tests/requirements.txt b/tests/python_tests/requirements.txt index 3dac3f8b00..bc5324b211 100644 --- a/tests/python_tests/requirements.txt +++ b/tests/python_tests/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://download.pytorch.org/whl/cpu -optimum-intel @ git+https://github.com/huggingface/optimum-intel.git +optimum-intel @ git+https://github.com/huggingface/optimum-intel.git@420fa87d039425a906b7f755e4562b65947f016a numpy<2.0.0; sys_platform == 'darwin' onnx==1.17.0 pytest diff --git a/tests/python_tests/test_chat_generate_api.py b/tests/python_tests/test_chat_generate_api.py index 9260e671d6..d9661e538b 100644 --- a/tests/python_tests/test_chat_generate_api.py +++ b/tests/python_tests/test_chat_generate_api.py @@ -187,10 +187,13 @@ def test_set_chat_template(): model_descr = get_chat_models_list()[0] model_id, path, tokenizer, model_opt, pipe = read_model((model_descr[0], model_descr[1] / '_test_chat')) pipe.get_tokenizer().set_chat_template("{% for message in messages %}{{ message['content'] }}{% endfor %}") + config = ov_genai.GenerationConfig() + config.max_new_tokens = 1 + config.do_sample = False 
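The new test configs above map directly onto the user-facing stop-string options. A minimal usage sketch, assuming openvino_genai is installed and an exported model lives in a placeholder "model_dir" path:

import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline("model_dir", "CPU")  # model_dir is a placeholder

config = ov_genai.GenerationConfig()
config.max_new_tokens = 30
config.stop_strings = {"machines", "manage"}
# False: generation stops at the first matched stop string and the matching
# tokens are trimmed from the output; True keeps the stop string in the text.
config.include_stop_str_in_output = False

print(pipe.generate("What is OpenVINO?", config))

With include_stop_str_in_output = False the OpenVINO result is expected to be a prefix of the HF reference, which is exactly the relaxation added to compare_results above.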
pipe.start_chat() - generated = pipe.generate("a", max_new_tokens=1) + generated = pipe.generate("a", config) pipe.finish_chat() - reference = pipe.generate("a", max_new_tokens=1) + reference = pipe.generate("a", config) assert generated == reference prompts = [ diff --git a/tests/python_tests/test_sampling.py b/tests/python_tests/test_sampling.py index 9aa6931d85..d5df28bfd6 100644 --- a/tests/python_tests/test_sampling.py +++ b/tests/python_tests/test_sampling.py @@ -21,6 +21,8 @@ get_beam_search, get_beam_search_min_and_max_tokens, get_beam_search_with_single_stop_string, \ get_beam_search_with_multiple_stop_strings, get_beam_search_with_multiple_stop_strings_no_match, get_multinomial_max_and_min_token, \ get_multinomial_temperature_and_frequence_penalty, get_multinomial_temperature_and_presence_penalty, \ + get_greedy_stop_strings_exclude_from_output, get_greedy_stop_strings_include_to_output, \ + get_greedy_n_stop_strings_exclude_from_output, get_greedy_n_stop_strings_include_to_output, \ generate_and_compare_with_hf, get_multinomial_temperature_and_repetition_penalty, get_scheduler_config, \ run_continuous_batching @@ -77,7 +79,9 @@ def test_eos_greedy(tmp_path): @pytest.mark.precommit @pytest.mark.parametrize("generation_config", [get_greedy(), get_greedy_with_min_and_max_tokens(), get_greedy_with_repetition_penalty(), get_greedy_with_single_stop_string(), get_greedy_with_multiple_stop_strings(), get_greedy_with_multiple_stop_strings_no_match(), - get_beam_search(), get_beam_search_min_and_max_tokens(), get_beam_search_with_multiple_stop_strings_no_match(), ], + get_beam_search(), get_beam_search_min_and_max_tokens(), get_beam_search_with_multiple_stop_strings_no_match(), + get_greedy_stop_strings_exclude_from_output(), get_greedy_stop_strings_include_to_output(), + get_greedy_n_stop_strings_exclude_from_output(), get_greedy_n_stop_strings_include_to_output() ], ids=[ "greedy", "greedy_with_min_and_max_tokens", @@ -88,6 +92,10 @@ def test_eos_greedy(tmp_path): "beam", "beam_search_min_and_max_tokens", "beam_search_with_multiple_stop_strings_no_match", + "get_greedy_stop_strings_exclude_from_output", + "get_greedy_stop_strings_include_to_output", + "get_greedy_n_stop_strings_exclude_from_output", + "get_greedy_n_stop_strings_include_to_output" ]) def test_individual_generation_configs_deterministic(tmp_path, generation_config): prompts = [ diff --git a/tools/llm_bench/requirements.txt b/tools/llm_bench/requirements.txt index f5f4a3fdeb..acbc668c52 100644 --- a/tools/llm_bench/requirements.txt +++ b/tools/llm_bench/requirements.txt @@ -10,7 +10,7 @@ torch transformers>=4.40.0 diffusers>=0.22.0 #optimum is in dependency list of optimum-intel -git+https://github.com/huggingface/optimum-intel.git@main#egg=optimum-intel +git+https://github.com/huggingface/optimum-intel.git@420fa87d039425a906b7f755e4562b65947f016a#egg=optimum-intel git+https://github.com/openvinotoolkit/nncf.git@develop#egg=nncf packaging psutil
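Finally, a toy sketch (not the pipeline code) of the streaming behaviour introduced by m_stream_window_size and the reworked notify_handle: the last stream_window_size tokens are held back while generation is running, so a matched stop string can still be trimmed before anything reaches the streamer, and the window is flushed once the sequence finishes. Names below are illustrative; the real logic additionally rolls the counter back for speculative-decoding drafts.

def tokens_to_push(generated_len, num_streamed, stream_window_size, finished):
    # Hold back the last `stream_window_size` tokens until generation has
    # finished, then flush the remainder.
    if finished:
        stream_window_size = 0
    if generated_len <= num_streamed + stream_window_size:
        return 0, num_streamed
    num_to_push = generated_len - num_streamed - stream_window_size
    return num_to_push, num_streamed + num_to_push

streamed = 0
for generated_len in [1, 2, 3, 4, 5]:
    pushed, streamed = tokens_to_push(generated_len, streamed, 2, finished=False)
    print(f"generated {generated_len}: push {pushed}, streamed {streamed}")
pushed, streamed = tokens_to_push(5, streamed, 2, finished=True)
print(f"finished: push remaining {pushed}")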