From fa36a48e5361d9d6b15b7096b56f3aafba48aba6 Mon Sep 17 00:00:00 2001 From: sbalandi Date: Mon, 16 Dec 2024 17:58:12 +0000 Subject: [PATCH] Move beam search in case of chat scenario to sampler.cpp --- src/cpp/src/group_beam_searcher.cpp | 455 ------------------ src/cpp/src/llm_pipeline.cpp | 137 +++--- src/cpp/src/lm_encoding.cpp | 47 +- src/cpp/src/lm_encoding.hpp | 8 +- src/cpp/src/utils.hpp | 15 + .../src/visual_language/inputs_embedder.cpp | 64 ++- .../src/visual_language/inputs_embedder.hpp | 6 +- src/cpp/src/visual_language/pipeline.cpp | 13 +- 8 files changed, 163 insertions(+), 582 deletions(-) delete mode 100644 src/cpp/src/group_beam_searcher.cpp diff --git a/src/cpp/src/group_beam_searcher.cpp b/src/cpp/src/group_beam_searcher.cpp deleted file mode 100644 index a0262c0dc8..0000000000 --- a/src/cpp/src/group_beam_searcher.cpp +++ /dev/null @@ -1,455 +0,0 @@ -// Copyright (C) 2023-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include - -#include "openvino/genai/llm_pipeline.hpp" -#include "utils.hpp" -#include "lm_encoding.hpp" - -namespace { - -// Modified Knuth–Morris–Pratt algorithm which returns tokens following after every needle occurrence in haystack -std::vector kmp_search(const std::vector& haystack, const std::vector& needle) { - if (needle.empty()) { // no_repeat_ngram_size == 1, ban every token - return {haystack.begin(), haystack.end()}; - } - std::vector partial_match_table(needle.size() + 1, -1); - int cnd = 0; - for (size_t pos = 1; pos < needle.size(); ++pos) { - if (needle.at(pos) == needle.at(size_t(cnd))) { - partial_match_table.at(pos) = partial_match_table.at(size_t(cnd)); - } else { - partial_match_table.at(pos) = cnd; - while (cnd >= 0 && needle.at(pos) != needle.at(size_t(cnd))) { - cnd = partial_match_table.at(size_t(cnd)); - } - } - ++cnd; - } - partial_match_table.back() = cnd; - std::vector res; - size_t haystack_id = 0; - int needle_id = 0; - while (haystack_id < haystack.size() - 1) { - if (needle.at(size_t(needle_id)) == haystack.at(haystack_id)) { - ++haystack_id; - ++needle_id; - if (needle_id == int(needle.size())) { - res.push_back(haystack.at(haystack_id)); - needle_id = partial_match_table.at(size_t(needle_id)); - } - } else { - needle_id = partial_match_table.at(size_t(needle_id)); - if (needle_id < 0) { - ++haystack_id; - ++needle_id; - } - } - } - return res; -} - -struct Token { - float log_prob; - int64_t idx; -}; - -std::vector log_softmax(const ov::Tensor& logits, const size_t batch_idx) { - if (logits.get_shape().at(0) <= batch_idx) { - throw std::runtime_error("logits batch size doesn't match the number of beams"); - } - size_t vocab_size = logits.get_shape().back(); - size_t batch_offset = batch_idx * logits.get_shape().at(1) * vocab_size; - size_t sequence_offset = (logits.get_shape().at(1) - 1) * vocab_size; - const float* beam_logits = logits.data() + batch_offset + sequence_offset; - float max_logit = *std::max_element(beam_logits, beam_logits + vocab_size); - float log_sum = std::log( - std::accumulate(beam_logits, beam_logits + vocab_size, 0.0f, [max_logit](float accumulated, float to_add) { - return accumulated + std::exp(to_add - max_logit); - })); - std::vector tokens; - tokens.reserve(vocab_size); - for (size_t idx = 0; idx < vocab_size; ++idx) { - tokens.push_back({beam_logits[idx] - max_logit - log_sum, int64_t(idx)}); - } - return tokens; -} - -struct Beam { - float score = -std::numeric_limits::infinity(); // The bigger, the better - std::vector tokens; - size_t global_beam_idx = 0; -}; - -bool greater(const Beam& left, const Beam& right) { - return left.score > right.score; -} - -struct Parameters { - std::vector> prompts; - int64_t eos_token_id; - size_t n_groups = 3; - size_t group_size = 5; - float diversity_penalty = 1.0; - size_t max_new_tokens = 20; - ov::genai::StopCriteria stop_criteria = ov::genai::StopCriteria::HEURISTIC; - float length_penalty = 1.0; - size_t no_repeat_ngram_size = std::numeric_limits::max(); - - std::function early_finish = [](const Beam&) { - return false; - }; -}; - -struct Group { - std::vector ongoing; // Best beams in front - std::vector min_heap; // The worst of the best completed beams is the first - bool done = false; - - void finish(Beam&& beam, const Parameters& parameters) { - beam.score /= std::pow(float(beam.tokens.size()), parameters.length_penalty); - - min_heap.push_back(std::move(beam)); - std::push_heap(min_heap.begin(), min_heap.end(), greater); - if (min_heap.size() > parameters.group_size) { - std::pop_heap(min_heap.begin(), min_heap.end(), greater); - min_heap.pop_back(); - } - } - void is_done(const Parameters& parameters) { - if (min_heap.size() < parameters.group_size) { - return; - } - size_t cur_len = ongoing.front().tokens.size(); - float best_sum_logprobs = ongoing.front().score; - float worst_score = min_heap.front().score; - switch (parameters.stop_criteria) { - case ov::genai::StopCriteria::EARLY: - done = true; - return; - case ov::genai::StopCriteria::HEURISTIC: { - float highest_attainable_score = best_sum_logprobs / std::pow(float(cur_len), parameters.length_penalty); - done = worst_score >= highest_attainable_score; - return; - } - case ov::genai::StopCriteria::NEVER: { - size_t length = parameters.length_penalty > 0.0 ? parameters.max_new_tokens : cur_len; - float highest_attainable_score = best_sum_logprobs / std::pow(float(length), parameters.length_penalty); - done = worst_score >= highest_attainable_score; - return; - } - default: - throw std::runtime_error("Never reached"); - } - } -}; - -// GroupBeamSearcher processes logits prduced by a language model and accumulates beams using group beam search -// algorithm. select_next_tokens() returns token ids selected by the algorithm and corresponding beam ids. These values -// are used for next inference. select_next_tokens() returns empty, if all groups are completed -struct GroupBeamSearcher { - Parameters parameters; - std::vector> prompts_groups; - - GroupBeamSearcher(Parameters parameters) : parameters{parameters}, prompts_groups{parameters.prompts.size()} { - if (parameters.no_repeat_ngram_size == 0) { - throw std::runtime_error("no_repeat_ngram_size must be positive"); - } - for (std::vector& prompts_groups : prompts_groups) { - prompts_groups.resize(parameters.n_groups); - for (Group& group : prompts_groups) { - group.ongoing.resize(parameters.group_size); - group.ongoing.front().score = 0.0; - } - } - } - - std::pair, std::vector> select_next_tokens(const ov::Tensor& logits) { - std::vector next_tokens; - std::vector next_beams; - - const size_t promts_size = parameters.prompts.size(); - - next_tokens.reserve(promts_size * parameters.n_groups * parameters.group_size); - next_beams.reserve(promts_size * parameters.n_groups * parameters.group_size); - - size_t beam_count = 0; - size_t prompt_id = 0; - for (std::vector& groups : prompts_groups) { - for (Group& group : groups) { - if (group.done) { - continue; - } - for (Beam& beam : group.ongoing) { - // beam.tokens.empty() holds for the first select_next_tokens() call. - // Every beam is constructed from the single batch at first call - if (beam.tokens.empty()) { - beam.global_beam_idx = prompt_id; - } else { - beam.global_beam_idx = beam_count; - ++beam_count; - } - } - } - - prompt_id += 1; - } - - for (int prompt_id = 0; prompt_id < promts_size; prompt_id++) { - const std::vector prompt = parameters.prompts[prompt_id]; - std::vector& groups = prompts_groups[prompt_id]; - auto [prompt_next_tokens, prompt_next_beams] = select_prompt_next_tokens(logits, prompt, groups); - - next_tokens.insert(next_tokens.end(), prompt_next_tokens.begin(), prompt_next_tokens.end()); - next_beams.insert(next_beams.end(), prompt_next_beams.begin(), prompt_next_beams.end()); - } - - return {next_tokens, next_beams}; - } - - std::pair, std::vector> select_prompt_next_tokens(const ov::Tensor& logits, - const std::vector& prompt, - std::vector& groups) { - std::vector next_tokens; - std::vector next_beams; - next_tokens.reserve(parameters.n_groups * parameters.group_size); - next_beams.reserve(parameters.n_groups * parameters.group_size); - - for (auto group = groups.begin(); group != groups.end(); ++group) { - if (group->done) { - continue; - } - std::vector candidates; - candidates.reserve(parameters.group_size * 2 * parameters.group_size); - for (const Beam& beam : group->ongoing) { - std::vector tokens = log_softmax(logits, beam.global_beam_idx); - for (auto prev_group = groups.cbegin(); prev_group != group; ++prev_group) { - for (const Beam& prev_beam : prev_group->ongoing) { - if (prev_beam.tokens.size() > beam.tokens.size()) { - tokens.at(size_t(prev_beam.tokens.back())).log_prob -= parameters.diversity_penalty; - } - } - } - std::vector full_text{prompt}; - full_text.insert(full_text.end(), beam.tokens.begin(), beam.tokens.end()); - if (full_text.size() > 1 && full_text.size() >= parameters.no_repeat_ngram_size) { - auto tail_start = full_text.end() - ptrdiff_t(parameters.no_repeat_ngram_size) + 1; - for (int64_t banned_token : kmp_search(full_text, {tail_start, full_text.end()})) { - tokens.at(size_t(banned_token)).log_prob = -std::numeric_limits::infinity(); - } - } - std::sort(tokens.begin(), tokens.end(), [](Token left, Token right) { - return left.log_prob > right.log_prob; // Most probable tokens in front - }); - size_t add_count = 0; - for (Token token : tokens) { - Beam new_candidate = beam; - new_candidate.score += token.log_prob; - new_candidate.tokens.push_back(token.idx); - if (parameters.early_finish(new_candidate)) { - group->finish(std::move(new_candidate), parameters); - } else { - candidates.push_back(std::move(new_candidate)); - ++add_count; - if (add_count == 2 * parameters.group_size) { - break; - } - } - } - } - // Sample 2 * group_size highest score tokens to get at least 1 non EOS token per beam - if (candidates.size() < 2 * parameters.group_size) { - throw std::runtime_error("No beams left to search"); - } - auto to_sort = candidates.begin() + ptrdiff_t(2 * parameters.group_size); - std::partial_sort(candidates.begin(), to_sort, candidates.end(), greater); - group->ongoing.clear(); - for (size_t cand_idx = 0; cand_idx < candidates.size(); ++cand_idx) { - if (parameters.eos_token_id == candidates.at(cand_idx).tokens.back()) { - // If beam_token does not belong to top num_beams tokens, it should not be added - if (cand_idx >= parameters.group_size) { - continue; - } - group->finish(std::move(candidates.at(cand_idx)), parameters); - } else { - group->ongoing.push_back(std::move(candidates.at(cand_idx))); - if (group->ongoing.size() == parameters.group_size) { - break; - } - } - } - group->is_done(parameters); - if (!group->done) { - for (const Beam& beam : group->ongoing) { - next_tokens.push_back(beam.tokens.back()); - next_beams.push_back(int32_t(beam.global_beam_idx)); - } - } - } - return {next_tokens, next_beams}; - } -}; - -// Consume group_beam_searcher because beams are consumed -std::vector>> finalize(GroupBeamSearcher&& group_beam_searcher) { - std::vector>> finalized; - finalized.resize(group_beam_searcher.prompts_groups.size()); - - for (size_t prompt_id = 0; prompt_id < group_beam_searcher.prompts_groups.size(); prompt_id++) { - std::vector& groups = group_beam_searcher.prompts_groups.at(prompt_id); - finalized.at(prompt_id).reserve(groups.size()); - - for (Group& group : groups) { - if (!group.done) { - for (Beam& beam : group.ongoing) { - group.finish(std::move(beam), group_beam_searcher.parameters); - } - } - finalized.at(prompt_id).push_back(std::move(group.min_heap)); - } - } - - return finalized; -} - -void reset_all_inputs_to_empty_tensors(ov::InferRequest& request) { - request.set_tensor("input_ids", ov::Tensor(ov::element::i64, {0, 0})); - request.set_tensor("beam_idx", ov::Tensor(ov::element::i32, {0})); - if (request.get_compiled_model().inputs().size() == 4) - request.set_tensor("position_ids", ov::Tensor(ov::element::i64, {0, 0})); -} -} // namespace - -namespace ov { -namespace genai { - -std::pair beam_search(ov::InferRequest& lm, - ov::Tensor input_ids, - ov::Tensor attention_mask, - GenerationConfig config, - std::optional position_ids, - std::optional selected_beam_idx) { - OPENVINO_ASSERT(config.num_beams % config.num_beam_groups == 0, - "number of beams should be divisible by number of groups"); - - auto batch_size = input_ids.get_shape().at(0); - auto sequence_length = input_ids.get_shape().at(1); - - // Initialize beam search. - const int64_t* prompt_data = input_ids.data(); - std::vector> prompts; - prompts.reserve(batch_size); - for (size_t batch = 0; batch < batch_size; batch++) { - size_t batch_offset = batch * sequence_length; - const int64_t* prompt_start = prompt_data + batch_offset; - prompts.push_back(std::vector{prompt_start, prompt_start + sequence_length}); - } - - lm.set_tensor("input_ids", input_ids); - lm.set_tensor("attention_mask", attention_mask); - if (position_ids.has_value()) - lm.set_tensor("position_ids", *position_ids); - - ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size}); - auto beam_data = beam_idx.data(); - if (selected_beam_idx.has_value()) - beam_data[0] = *selected_beam_idx; - else - std::fill_n(beam_data, batch_size, 0); - lm.set_tensor("beam_idx", beam_idx); - - Parameters parameters{std::move(prompts)}; - parameters.max_new_tokens = config.get_max_new_tokens(sequence_length); - parameters.eos_token_id = config.eos_token_id; - parameters.n_groups = config.num_beam_groups; - parameters.group_size = config.num_beams / config.num_beam_groups; - parameters.diversity_penalty = config.diversity_penalty; - parameters.length_penalty = config.length_penalty; - parameters.stop_criteria = config.stop_criteria; - parameters.no_repeat_ngram_size = config.no_repeat_ngram_size; - GroupBeamSearcher group_beam_searcher{parameters}; - - std::vector next_tokens; - std::vector next_beams; - - // Reserve for performance counters. - std::vector new_token_times; - std::vector batch_sizes; - new_token_times.reserve(parameters.max_new_tokens); - batch_sizes.reserve(parameters.max_new_tokens); - - for (size_t length_count = 0; ; ++length_count) { - lm.infer(); - - std::tie(next_tokens, next_beams) = group_beam_searcher.select_next_tokens(lm.get_tensor("logits")); - new_token_times.emplace_back(std::chrono::steady_clock::now()); - batch_sizes.emplace_back(batch_size); - - if (next_tokens.empty() || length_count == parameters.max_new_tokens - 1) { - // Break the cycle before masks are extended in update_attention_mask_with_beams. - // If generation is continued, attention_mask length should be equal to KV cache size. - break; - } - - size_t running_batch_size = next_tokens.size(); - // Set pointers - lm.set_tensor("input_ids", ov::Tensor{ov::element::i64, {running_batch_size, 1}, next_tokens.data()}); - lm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {running_batch_size}, next_beams.data()}); - - // Set auxiliary inputs - update_attention_mask_with_beams(lm.get_tensor("attention_mask"), next_beams); - if (position_ids.has_value()) - update_position_ids(lm.get_tensor("position_ids"), lm.get_tensor("attention_mask")); - } - - reset_all_inputs_to_empty_tensors(lm); - - auto scores_comparator = [](Beam& left, Beam& right) { - return (left.score > right.score); - }; - - auto result = finalize(std::move(group_beam_searcher)); - ov::genai::EncodedResults results; - int32_t res_selected_beam_idx = 0; - results.scores.reserve(config.num_return_sequences * result.size()); - results.tokens.reserve(config.num_return_sequences * result.size()); - auto& raw_perf_counters = results.perf_metrics.raw_metrics; - raw_perf_counters.m_new_token_times = new_token_times; - raw_perf_counters.m_batch_sizes = batch_sizes; - - // align output with HF - for (size_t prompt_id = 0; prompt_id < result.size(); prompt_id++) { - auto prompt_group = result.at(prompt_id); - std::vector> plain_beams; - plain_beams.reserve(parameters.n_groups * parameters.group_size); - for (std::vector& group : prompt_group) { - for (Beam& beam : group) { - plain_beams.push_back(beam); - } - } - assert(config.num_return_sequences <= plain_beams.size()); - std::partial_sort( - plain_beams.begin(), - plain_beams.begin() + config.num_return_sequences, - plain_beams.end(), - scores_comparator - ); - res_selected_beam_idx = plain_beams.at(0).get().global_beam_idx; - for ( - auto beam = plain_beams.begin(); - beam != plain_beams.begin() + config.num_return_sequences; - ++beam - ) { - results.scores.push_back(beam->get().score); - results.tokens.push_back(std::move(beam->get().tokens)); - } - } - - return {results, res_selected_beam_idx}; -} - -} // namespace genai -} // namespace ov diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index 623333e349..f57f31baa3 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -24,28 +24,23 @@ namespace ov { namespace genai { -std::pair beam_search( - ov::InferRequest& lm, - ov::Tensor prompts, - ov::Tensor attention_mask, - GenerationConfig config, - std::optional position_ids, - std::optional selected_beam_idx -); - class StatefulLLMPipeline final : public LLMPipelineImplBase { public: ov::InferRequest m_model_runner; bool is_chat_conversation = false; bool m_trust_encoded_history = true; - std::optional m_selected_beam = std::nullopt; ChatHistory m_history; std::string m_templated_chat_history = {}; std::vector m_tokenized_chat_history; ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; - size_t m_to_remove_from_hist = 0; size_t m_kv_cache_seq_length_axis = 2; Sampler m_sampler; + // Tail of previous output in chat mode is missing in KV cache, let's keep it + std::optional m_last_disappeared_token = std::nullopt; + // If sequence contains some symbols, which could be ambiguously encoded by tokenizer, we need to trim kv cache + // If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history + // so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history + ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0}; StatefulLLMPipeline( const ov::InferRequest& request, @@ -154,35 +149,44 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { // some symbols combinations can be encoded by the tokenizer in different ways // if we met sequence with such combination of symbols, we cannot correctly subtract the new history from the old history // so let's check it out, find the trusted part and use it in on the next step - size_t last_same_hist_token = 0; + size_t trusted_history_length = 0; if (!m_tokenized_chat_history.empty()) { std::set stop_tokens = config.stop_token_ids; - last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens); - m_trust_encoded_history = last_same_hist_token == SIZE_MAX; + trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens); + m_trust_encoded_history = trusted_history_length == SIZE_MAX; } if (m_tokenized_chat_history.empty()) { encoded_input = new_chat_tokens; - } else if (last_same_hist_token != SIZE_MAX) { - m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token; + } else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) { + // does_kv_cache_need_to_update will be true here if beam search is activated + // in beam search mode we want to remove all history about last model answer from kv cache and add the best answer directly + // if we have difference in model answer and decoded answer it anyway will be less then entire history, so let's use data from m_kv_history_manager + if (m_kv_history_manager.does_kv_cache_need_to_update()) { + trusted_history_length = m_kv_history_manager.trusted_history_length; + } else { + m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length; + // if prev generation was finished because of max len was reached, kv cache is missed one last token, let's keep it + m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0; + } ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(), - {1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token}, - new_chat_tokens.input_ids.data() + last_same_hist_token); + {1, new_chat_tokens.input_ids.get_shape().at(1) - trusted_history_length}, + new_chat_tokens.input_ids.data() + trusted_history_length); ov::Tensor new_attention_mask(ov::element::i64, new_tensor.get_shape()); std::fill_n(new_attention_mask.data(), new_tensor.get_shape()[1], 1); encoded_input.input_ids = ov::Tensor(new_chat_tokens.input_ids.get_element_type(), - {1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token}); + {1, new_chat_tokens.input_ids.get_shape().at(1) - trusted_history_length}); new_tensor.copy_to(encoded_input.input_ids); encoded_input.attention_mask = new_attention_mask; - - m_selected_beam = std::nullopt; + m_last_disappeared_token = std::nullopt; } else { encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens); } m_templated_chat_history = new_templated_chat_history; + m_tokenized_chat_history.clear(); m_tokenized_chat_history.reserve(new_chat_tokens.input_ids.get_size()); std::copy_n(new_chat_tokens.input_ids.data(), new_chat_tokens.input_ids.get_size(), @@ -264,6 +268,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) std::copy(input_ids.data(), input_ids.data() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history)); + // Tail of previous output in chat mode is missing in KV cache. + if (m_last_disappeared_token.has_value()) { + attention_mask = ov::genai::utils::push_front_inputs(attention_mask, 1); + input_ids = ov::genai::utils::push_front_inputs(input_ids, *m_last_disappeared_token); + } + GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config; // If eos_token_id was not provided, take value from default m_generation_config @@ -294,7 +304,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { "(input_ids, attention_mask, position_ids, beam_idx) " "but you have '" + std::to_string(num_inputs) + "' inputs"); - ov::genai::utils::trim_kv_cache(m_model_runner, m_to_remove_from_hist, m_kv_cache_seq_length_axis, m_adapter_controller); + ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache, m_kv_cache_seq_length_axis, m_adapter_controller); size_t kv_cache_len = 0; ov::Tensor concatenated_attention_mask; @@ -304,10 +314,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { // Between subsequent runs attention_mask should not be modified. auto atten_mask_history = m_model_runner.get_tensor("attention_mask"); auto prompt_len = attention_mask.get_shape()[1]; - kv_cache_len = atten_mask_history.get_shape()[1] - m_to_remove_from_hist; + + kv_cache_len = atten_mask_history.get_shape()[1] - m_kv_history_manager.num_tokens_to_remove_from_kv_cache; ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}}; - auto start_atten_hst = atten_mask_history.data() + kv_cache_len * (*m_selected_beam); + auto start_atten_hst = atten_mask_history.data(); + std::copy(start_atten_hst, start_atten_hst + kv_cache_len, new_atten_mask.data()); std::copy(attention_mask.data(), attention_mask.data() + prompt_len, @@ -317,6 +329,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { concatenated_attention_mask = attention_mask; } + size_t prev_attn_mask_size = concatenated_attention_mask.get_shape()[1]; + bool position_ids_available = (num_inputs == 4); std::optional position_ids = std::nullopt; if (position_ids_available) { @@ -330,51 +344,58 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { if (is_chat_conversation && !m_trust_encoded_history) { m_trust_encoded_history = true; - m_to_remove_from_hist = 0; + m_kv_history_manager.reset(); } - ov::genai::EncodedResults result; - if (config.is_beam_search() && is_chat_conversation) { - std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask, - config, position_ids, m_selected_beam); - } else { - std::vector requests; - size_t block_size = 1; - bool enable_prefix_caching = false; - - for (size_t request_id = 0; request_id < batch_size; request_id++) { - SequenceGroup::Ptr sequence_group; - if (is_chat_conversation) { - ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()); - sequence_group = std::make_shared(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching); - } else { - size_t seq_len = input_ids.get_shape().at(1); - size_t batch_offset = request_id * seq_len; - const int64_t* prompt_start = input_ids.data() + batch_offset; - std::vector tokenized_prompt(prompt_start, prompt_start + seq_len); + std::vector requests; + size_t block_size = 1; + bool enable_prefix_caching = false; - sequence_group = std::make_shared(request_id, tokenized_prompt, config, block_size, enable_prefix_caching); - } + for (size_t request_id = 0; request_id < batch_size; request_id++) { + SequenceGroup::Ptr sequence_group; + if (is_chat_conversation) { + ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()); + sequence_group = std::make_shared(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching); + } else { + size_t seq_len = input_ids.get_shape().at(1); + size_t batch_offset = request_id * seq_len; + const int64_t* prompt_start = input_ids.data() + batch_offset; + std::vector tokenized_prompt(prompt_start, prompt_start + seq_len); - sequence_group->set_sequence_group_ptr(sequence_group); - requests.push_back(sequence_group); + sequence_group = std::make_shared(request_id, tokenized_prompt, config, block_size, enable_prefix_caching); } - if (m_sampler.get_seed() != config.rng_seed) { - m_sampler.set_seed(config.rng_seed); - } + sequence_group->set_sequence_group_ptr(sequence_group); + requests.push_back(sequence_group); + } - std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr, - m_sampler, requests, position_ids, std::nullopt, m_selected_beam); + if (m_sampler.get_seed() != config.rng_seed) { + m_sampler.set_seed(config.rng_seed); } + ov::genai::EncodedResults result = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, + streamer_ptr, m_sampler, requests, position_ids, std::nullopt); if (is_chat_conversation) { + // force remove from kv_cache last answer + if (config.is_beam_search() && m_chat_input_type != ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) { + m_kv_history_manager.trusted_history_length = m_tokenized_chat_history.size(); + m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size; + } + + // There's only one request in chat mode + if (requests[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH || requests[0]->handle_dropped()) + m_last_disappeared_token = result.tokens[0].back(); + else + m_last_disappeared_token = std::nullopt; + std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history)); } else { reset_kv_state(); - m_selected_beam = std::nullopt; } + if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) + std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history)); + auto stop_time = std::chrono::steady_clock::now(); // If is called without tokenization then that stat will not be reported. @@ -388,10 +409,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { void start_chat(const std::string& system_message) override { is_chat_conversation = true; - m_selected_beam = std::nullopt; m_trust_encoded_history = true; - m_to_remove_from_hist = 0; + m_kv_history_manager.reset(); m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; + m_last_disappeared_token = std::nullopt; if (!m_tokenized_chat_history.empty()) { reset_kv_state(); m_history = {}; @@ -409,10 +430,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { void finish_chat() override { is_chat_conversation = false; - m_selected_beam = std::nullopt; m_trust_encoded_history = true; - m_to_remove_from_hist = 0; + m_kv_history_manager.reset(); m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; + m_last_disappeared_token = std::nullopt; if (!m_tokenized_chat_history.empty()) { reset_kv_state(); m_history.clear(); diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp index 8ef993e09f..9afdb6b3a0 100644 --- a/src/cpp/src/lm_encoding.cpp +++ b/src/cpp/src/lm_encoding.cpp @@ -9,12 +9,11 @@ #include #include +#include "utils.hpp" +#include "debug_utils.hpp" #include "lm_encoding.hpp" #include "openvino/genai/perf_metrics.hpp" -#include "debug_utils.hpp" - -#include "utils.hpp" namespace ov { namespace genai { @@ -51,7 +50,7 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector get_lm_encoded_results( +EncodedResults get_lm_encoded_results( ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, @@ -59,8 +58,7 @@ std::pair get_lm_encoded_results( Sampler& sampler, std::vector sequence_groups, std::optional position_ids, - std::optional m_embedding, - std::optional selected_beam_idx + std::optional m_embedding ) { std::vector generations; for (SequenceGroup::Ptr sequence_group : sequence_groups) { @@ -105,11 +103,10 @@ std::pair get_lm_encoded_results( m_llm.set_tensor("position_ids", *position_ids); ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size}); - std::fill_n(beam_idx.data(), batch_size, selected_beam_idx.has_value() ? *selected_beam_idx : 0); + std::fill_n(beam_idx.data(), batch_size, 0); m_llm.set_tensor("beam_idx", beam_idx); // "Prompt" phase - const auto infer_start = std::chrono::steady_clock::now(); m_llm.infer(); const auto infer_end = std::chrono::steady_clock::now(); @@ -171,13 +168,13 @@ std::pair get_lm_encoded_results( // apply strides to shift to a next sequence input_ids_data += num_scheduled_tokens; - // for different sequences iteration of beams started from 0, but we collect it to one input_ids# + // for different sequences iteration of beams started from 0, but we collect it to one input_ids next_beams.push_back(beam_idxs[sequence->get_id()] + beam_offets.at(sequence_group->get_request_id())); } } - for (size_t i = 0; i < sequence_groups.size(); i++) { - beam_offets[sequence_groups.at(i)->get_request_id()] = i == 0 ? 0 : (sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i - 1]); + for (size_t i = 0; i < active_sequence_groups.size(); i++) { + beam_offets[active_sequence_groups.at(i)->get_request_id()] = i == 0 ? 0 : (active_sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i - 1]); } if (m_embedding.has_value()) { @@ -212,30 +209,22 @@ std::pair get_lm_encoded_results( streamer_ptr->end(); } - // Collect results - - size_t next_selected_beam = 0; - for (size_t i = 0; i < sequence_groups.size(); i++) { - auto request = sequence_groups[i]; - std::vector generation_outputs; - auto sampling_params = request->get_sampling_parameters(); - const auto& sequences = request->get_finished_sequences(); - size_t num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, sequences.size()); - - for (size_t seq_id = 0; seq_id < num_outputs; ++seq_id) { - const auto & sequence = sequences[seq_id]; - const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_probs(); - - results.tokens.push_back(sequence->get_generated_ids()); - results.scores.push_back(score); + for (auto& sequence_group : sequence_groups) { + // sequences is sorted by cumulative_log_prob with length_penalty + auto outputs = sequence_group->get_finished_sequences(); + + auto num_outputs = std::min(sequence_group->get_sampling_parameters().num_return_sequences, outputs.size()); + for (size_t output_idx = 0; output_idx < num_outputs; ++output_idx) { + const auto& output = outputs[output_idx]; + results.tokens.push_back(output->get_generated_ids()); + results.scores.push_back(output->get_beam_search_score(sequence_group->get_sampling_parameters())); } - // next_selected_beam = sampler.last_selected_beam(request); } for (SequenceGroup::Ptr sequence_group : sequence_groups) sampler.clear_request_info(sequence_group->get_request_id()); - return {results, next_selected_beam}; + return results; } } // namespace genai diff --git a/src/cpp/src/lm_encoding.hpp b/src/cpp/src/lm_encoding.hpp index fa6692ede0..0a342f0a37 100644 --- a/src/cpp/src/lm_encoding.hpp +++ b/src/cpp/src/lm_encoding.hpp @@ -8,13 +8,9 @@ namespace ov { namespace genai { -std::pair get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, +EncodedResults get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, const std::shared_ptr& streamer_ptr, Sampler& sampler, std::vector sequence_groups, - std::optional position_ids, std::optional m_embedding, std::optional selected_beam_idx); - -void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector next_beams); - -void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask); + std::optional position_ids, std::optional m_embedding); } } diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index 96191387cd..57225e60ff 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -28,6 +28,21 @@ enum class GenerationChatInputsType { ENCODED_INPUTS = 2, // Type of inputs is EncodedInputs }; +struct HistoryRemoveManager +{ + size_t num_tokens_to_remove_from_kv_cache = 0; + size_t trusted_history_length = 0; + + bool does_kv_cache_need_to_update() { + return (trusted_history_length > 0 || num_tokens_to_remove_from_kv_cache > 0); + } + + void reset() { + num_tokens_to_remove_from_kv_cache = 0; + trusted_history_length = 0; + } +}; + Tensor init_attention_mask(const Tensor& position_ids); void print_tensor(const ov::Tensor& tensor); diff --git a/src/cpp/src/visual_language/inputs_embedder.cpp b/src/cpp/src/visual_language/inputs_embedder.cpp index 8175d44b16..68cda50cdb 100644 --- a/src/cpp/src/visual_language/inputs_embedder.cpp +++ b/src/cpp/src/visual_language/inputs_embedder.cpp @@ -42,11 +42,12 @@ class InputsEmbedder::IInputsEmbedder { std::string m_templated_chat_history; // Tokenized chat history std::vector m_tokenized_history; - // The number of elements, which need to remove from the end of KV cache - // removed elements will be added to inputs_ids - size_t m_to_remove_from_hist = 0; // Tail of previous output for LM in chat mode is missing in KV cache. std::optional m_last_disappeared_token = std::nullopt; + // If sequence contains some symbols, which could be ambiguous encoded by tokenizer, we need to trim kv cache + // If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history + // so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history + ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0}; public: virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector& images, ov::genai::VLMPerfMetrics& metrics) = 0; @@ -63,22 +64,29 @@ class InputsEmbedder::IInputsEmbedder { return m_tokenized_history; } - size_t get_amount_to_remove_from_hist() const { - return m_to_remove_from_hist; + size_t get_num_tokens_to_remove_from_hist() const { + return m_kv_history_manager.num_tokens_to_remove_from_kv_cache; } - void update_tokenized_history(std::vector encoded_result, bool token_will_disappear) { - std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history)); - m_to_remove_from_hist = 0; - if (token_will_disappear) + void update_tokenized_history(const std::vector& encoded_result, bool is_last_token_disappear, bool is_beam_search, size_t last_answer_len) { + if (is_beam_search) { + m_kv_history_manager.trusted_history_length = m_tokenized_history.size(); + m_kv_history_manager.num_tokens_to_remove_from_kv_cache = last_answer_len; + } else { + m_kv_history_manager.reset(); + } + + if (is_last_token_disappear) m_last_disappeared_token = encoded_result.back(); else m_last_disappeared_token = std::nullopt; + + std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history)); } virtual void start_chat(const std::string& system_message) { m_is_chat_conversation = true; - m_to_remove_from_hist = 0; + m_kv_history_manager.reset(); if (!m_tokenized_history.empty()) { m_history.clear(); m_templated_chat_history.clear(); @@ -101,7 +109,7 @@ class InputsEmbedder::IInputsEmbedder { virtual void finish_chat() { m_is_chat_conversation = false; - m_to_remove_from_hist = 0; + m_kv_history_manager.reset(); m_history.clear(); m_templated_chat_history.clear(); @@ -171,24 +179,32 @@ class InputsEmbedder::IInputsEmbedder { // some symbols combinations can be encoded by the tokenizer in different ways // if we met sequence with such combination of symbols, we cannot correctly subtract the new history from the old history // so let's check it out, find the trusted part and use it in on the next step - size_t last_same_hist_token = 0; + size_t trusted_history_length = 0; if (!m_tokenized_history.empty()) { std::set stop_tokens = {m_tokenizer.get_eos_token_id()}; - last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_history, stop_tokens); + trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_history, stop_tokens); } if (m_tokenized_history.empty()) { encoded_input_ids = new_chat_tokens; - } else if (last_same_hist_token != SIZE_MAX) { - m_to_remove_from_hist = m_tokenized_history.size() - last_same_hist_token; - // if prev generation was finished because of max len was reached, kv cache is missed one last token, let's keep it - m_to_remove_from_hist -= m_last_disappeared_token.has_value() ? 1 : 0; + + } else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) { + // does_kv_cache_need_to_update will be true here if beam search is activated + // in beam search mode we want to remove all history about last model answer from kv cache and add the best answer directly + // if we have difference in model answer and decoded answer it anyway will be less then entire history, so let's use data from m_kv_history_manager + if (m_kv_history_manager.does_kv_cache_need_to_update()) { + trusted_history_length = m_kv_history_manager.trusted_history_length; + } else { + m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_history.size() - trusted_history_length; + // if prev generation was finished because of max len was reached, kv cache is missed one last token, let's keep it + m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0; + } ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.get_element_type(), - {1, new_chat_tokens.get_shape().at(1) - last_same_hist_token}, - new_chat_tokens.data() + last_same_hist_token); + {1, new_chat_tokens.get_shape().at(1) - trusted_history_length}, + new_chat_tokens.data() + trusted_history_length); encoded_input_ids = ov::Tensor(new_chat_tokens.get_element_type(), - {1, new_chat_tokens.get_shape().at(1) - last_same_hist_token}); + {1, new_chat_tokens.get_shape().at(1) - trusted_history_length}); new_tensor.copy_to(encoded_input_ids); } else { encoded_input_ids = utils::subtract_chat_tokenized_inputs( @@ -1192,12 +1208,12 @@ std::vector InputsEmbedder::get_tokenized_history() const { return m_impl->get_tokenized_history(); } -void InputsEmbedder::update_tokenized_history(std::vector encoded_result, bool token_will_disappear) { - return m_impl->update_tokenized_history(encoded_result, token_will_disappear); +void InputsEmbedder::update_tokenized_history(const std::vector& encoded_result, bool is_last_token_disappear, bool is_beam_search, size_t last_answer_len) { + return m_impl->update_tokenized_history(encoded_result, is_last_token_disappear, is_beam_search, last_answer_len); } -size_t InputsEmbedder::get_amount_to_remove_from_hist() const { - return m_impl->get_amount_to_remove_from_hist(); +size_t InputsEmbedder::get_num_tokens_to_remove_from_hist() const { + return m_impl->get_num_tokens_to_remove_from_hist(); } Tokenizer InputsEmbedder::get_tokenizer() const { diff --git a/src/cpp/src/visual_language/inputs_embedder.hpp b/src/cpp/src/visual_language/inputs_embedder.hpp index 8c84c6ad43..967c01ff4f 100644 --- a/src/cpp/src/visual_language/inputs_embedder.hpp +++ b/src/cpp/src/visual_language/inputs_embedder.hpp @@ -43,11 +43,11 @@ class InputsEmbedder { // returns tokenized chat history std::vector get_tokenized_history() const; - // add new results to tokenized chat history - void update_tokenized_history(std::vector encoded_result, bool token_will_disappear); + // add new results to tokenized history + void update_tokenized_history(const std::vector& encoded_result, bool is_last_token_disappear, bool is_beam_search, size_t last_answer_len); // returns amount of elements, which need to remove from the end of the KV cache - size_t get_amount_to_remove_from_hist() const; + size_t get_num_tokens_to_remove_from_hist() const; // starts chat and adds optional system_message to chat history void start_chat(const std::string& system_message); diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp index ad4529e22f..2731eb20ff 100644 --- a/src/cpp/src/visual_language/pipeline.cpp +++ b/src/cpp/src/visual_language/pipeline.cpp @@ -169,7 +169,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { ov::Tensor inputs_embeds = m_inputs_embedder->get_inputs_embeds(prompt, rgbs, perf_metrics); auto end_get_inputs_embeds = std::chrono::steady_clock::now(); - auto to_remove_from_hist = m_inputs_embedder->get_amount_to_remove_from_hist(); + auto to_remove_from_hist = m_inputs_embedder->get_num_tokens_to_remove_from_hist(); ov::genai::utils::trim_kv_cache(m_language, to_remove_from_hist, m_kv_cache_seq_length_axis, std::nullopt); std::vector requests; @@ -217,10 +217,8 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { m_sampler.set_seed(generation_config.rng_seed); } - ov::genai::EncodedResults encoded_result; - int32_t m_selected_beam = 0; - std::tie(encoded_result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests, - position_ids, m_embedding, std::nullopt); + ov::genai::EncodedResults encoded_result = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests, + position_ids, m_embedding); auto decode_start_time = std::chrono::steady_clock::now(); VLMDecodedResults decoded; @@ -230,6 +228,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { } auto decode_end_time = std::chrono::steady_clock::now(); + m_inputs_embedder->update_tokenized_history(encoded_result.tokens[0], requests[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH, + generation_config.is_beam_search(), m_language.get_tensor("attention_mask").get_shape()[1] - (history_size + inputs_embeds_size)); + std::string decoded_results = decoded.texts.at(0); if (m_is_chat_conversation) { m_inputs_embedder->update_chat_history(decoded_results); @@ -256,8 +257,6 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { decoded.perf_metrics.m_evaluated = false; decoded.perf_metrics.evaluate_statistics(generate_start_time); - m_inputs_embedder->update_tokenized_history(encoded_result.tokens[0], requests[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH); - return decoded; }