diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index 84f76730eb..f663b27dd9 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -36,13 +36,15 @@ std::pair beam_search( class StatefulLLMPipeline final : public LLMPipelineImplBase { public: ov::InferRequest m_model_runner; - bool is_chat_conversation = false; - bool m_is_cache_empty = true; + bool m_trust_encoded_history = true; std::optional m_selected_beam = std::nullopt; ChatHistory m_history; std::string m_templated_chat_history = {}; - TokenizedInputs m_tokenized_chat_history; + std::vector m_tokenized_chat_history; + ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; + size_t m_to_remove_from_hist = 0; + size_t m_kv_cache_seq_length_axis = 2; StatefulLLMPipeline( const ov::InferRequest& request, @@ -77,6 +79,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { ov::Core core; auto [core_plugin_config, plugin_config] = ov::genai::utils::split_core_compile_config(config); utils::slice_matmul_statefull_model(model); + m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model); if (auto filtered_plugin_config = extract_adapters_from_properties(plugin_config, &m_generation_config.adapters)) { m_generation_config.adapters->set_tensor_name_prefix("base_model.model.model."); @@ -102,8 +105,20 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { OptionalGenerationConfig generation_config, StreamerVariant streamer ) override { + if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::UNDEF) + m_chat_input_type = ov::genai::utils::GenerationChatInputsType::STRING; + + if (is_chat_conversation) + OPENVINO_ASSERT(m_chat_input_type != ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS, + "Chat doesn't support switching between input types. Please, continue using EncodedInputs or restart the chat."); + auto start_time = std::chrono::steady_clock::now(); GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config; + // If eos_token_id was not provided, take value from default m_generation_config + if (config.eos_token_id == -1) + config.set_eos_token_id(m_generation_config.eos_token_id); + config.validate(); + TokenizedInputs encoded_input; if (auto input_vector = std::get_if>(&inputs)) { @@ -127,19 +142,51 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); // Do not add special tokens in chat scenario to be aligned with HF. auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false)); - if (m_is_cache_empty) { + auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false)); + + // some symbols combinations can be encoded by the tokenizer in different ways + // if we met sequence with such combination of symbols, we cannot correctly subtract the new history from the old history + // so let's check it out, find the trusted part and use it in on the next step + size_t last_same_hist_token = 0; + if (!m_tokenized_chat_history.empty()) { + std::set stop_tokens = config.stop_token_ids; + last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens); + m_trust_encoded_history = last_same_hist_token == SIZE_MAX; + } + + if (m_tokenized_chat_history.empty()) { encoded_input = new_chat_tokens; + } else if (last_same_hist_token != SIZE_MAX) { + m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token; + + ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(), + {1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token}, + new_chat_tokens.input_ids.data() + last_same_hist_token); + + ov::Tensor new_attention_mask(ov::element::i64, new_tensor.get_shape()); + std::fill_n(new_attention_mask.data(), new_tensor.get_shape()[1], 1); + + encoded_input.input_ids = ov::Tensor(new_chat_tokens.input_ids.get_element_type(), + {1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token}); + new_tensor.copy_to(encoded_input.input_ids); + encoded_input.attention_mask = new_attention_mask; + + m_selected_beam = std::nullopt; } else { - auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false)); encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens); } m_templated_chat_history = new_templated_chat_history; - m_tokenized_chat_history = new_chat_tokens; + m_tokenized_chat_history.clear(); + m_tokenized_chat_history.reserve(new_chat_tokens.input_ids.get_size()); + std::copy_n(new_chat_tokens.input_ids.data(), new_chat_tokens.input_ids.get_size(), + std::back_inserter(m_tokenized_chat_history)); + // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied } else { encoded_input = m_tokenizer.encode(prompt); } } + auto encode_stop_time = std::chrono::steady_clock::now(); auto encoded_results = generate(encoded_input, config, streamer); @@ -188,6 +235,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { OptionalGenerationConfig generation_config, StreamerVariant streamer ) override { + if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::UNDEF) + m_chat_input_type = ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS; + + if (is_chat_conversation) + // if chat was run in StringInputs mode, but it was called EncodedInputs generate, last m_history entry will be with assistant role + OPENVINO_ASSERT(m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS || m_history.back()["role"] == "user", + "Chat doesn't support switching between input types. Please, continue using StringInputs or restart the chat."); + auto start_time = std::chrono::steady_clock::now(); ov::Tensor input_ids; ov::Tensor attention_mask; @@ -199,6 +254,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { attention_mask = data->attention_mask; } + if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) + std::copy(input_ids.data(), input_ids.data() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history)); + GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config; // If eos_token_id was not provided, take value from default m_generation_config @@ -230,16 +288,17 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { "(input_ids, attention_mask, position_ids, beam_idx) " "but you have '" + std::to_string(num_inputs) + "' inputs"); + ov::genai::utils::trim_kv_cache(m_model_runner, m_to_remove_from_hist, m_kv_cache_seq_length_axis, m_adapter_controller); size_t kv_cache_len = 0; ov::Tensor concatenated_attention_mask; - if (is_chat_conversation && !m_is_cache_empty) { + if (is_chat_conversation && !m_tokenized_chat_history.empty()) { OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1"); // If history is saved in KV cache, concatenate new attention_mask with the already existing. // Between subsequent runs attention_mask should not be modified. auto atten_mask_history = m_model_runner.get_tensor("attention_mask"); auto prompt_len = attention_mask.get_shape()[1]; - kv_cache_len = atten_mask_history.get_shape()[1]; + kv_cache_len = atten_mask_history.get_shape()[1] - m_to_remove_from_hist; ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}}; auto start_atten_hst = atten_mask_history.data() + kv_cache_len * (*m_selected_beam); @@ -263,6 +322,11 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { m_adapter_controller->apply(m_model_runner, config.adapters); } + if (is_chat_conversation && !m_trust_encoded_history) { + m_trust_encoded_history = true; + m_to_remove_from_hist = 0; + } + ov::genai::EncodedResults result; if (config.is_beam_search() && is_chat_conversation) { std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask, @@ -274,8 +338,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { for (size_t request_id = 0; request_id < batch_size; request_id++) { SequenceGroup::Ptr sequence_group; - if (is_chat_conversation && !m_is_cache_empty) { - sequence_group = std::make_shared(request_id, m_tokenized_chat_history.input_ids, config, block_size, enable_prefix_caching); + if (is_chat_conversation) { + ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()); + sequence_group = std::make_shared(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching); } else { size_t seq_len = input_ids.get_shape().at(1); size_t batch_offset = request_id * seq_len; @@ -294,12 +359,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { sampler, requests, position_ids, std::nullopt, m_selected_beam); } - if (!is_chat_conversation) { + if (is_chat_conversation) { + std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history)); + } else { reset_kv_state(); m_selected_beam = std::nullopt; - } else { - m_is_cache_empty = false; } + auto stop_time = std::chrono::steady_clock::now(); // If is called without tokenization then that stat will not be reported. @@ -313,12 +379,15 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { void start_chat(const std::string& system_message) override { is_chat_conversation = true; - m_selected_beam = std::nullopt; - if (!m_is_cache_empty) { + m_selected_beam = std::nullopt; + m_trust_encoded_history = true; + m_to_remove_from_hist = 0; + m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; + if (!m_tokenized_chat_history.empty()) { reset_kv_state(); - m_is_cache_empty = true; m_history = {}; m_templated_chat_history = ""; + m_tokenized_chat_history.clear(); } if (system_message.empty()) return; @@ -332,11 +401,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { void finish_chat() override { is_chat_conversation = false; m_selected_beam = std::nullopt; - if (!m_is_cache_empty) { + m_trust_encoded_history = true; + m_to_remove_from_hist = 0; + m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; + if (!m_tokenized_chat_history.empty()) { reset_kv_state(); - m_is_cache_empty = true; m_history.clear(); m_templated_chat_history.clear(); + m_tokenized_chat_history.clear(); } } }; diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index 337b0ab47e..3690920295 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -13,6 +13,8 @@ #include "openvino/op/tanh.hpp" #include "openvino/op/transpose.hpp" +#include "sampler.hpp" + namespace ov { namespace genai { namespace utils { @@ -306,6 +308,79 @@ ov::Core singleton_core() { return core; } +size_t get_first_history_difference(const ov::Tensor& encoded_history, const std::vector tokenized_history, std::set stop_tokens) { + size_t idx = 0; + auto encoded_history_data = encoded_history.data(); + while(idx < encoded_history.get_size() && idx < tokenized_history.size()) { + if (encoded_history_data[idx] != tokenized_history[idx]) + break; + idx++; + } + + // encoded_history after decode of tokenizer could lose one last token (eos/stop token) + if ((idx == tokenized_history.size() && idx == encoded_history.get_size()) || + (encoded_history.get_size() < tokenized_history.size() && idx == tokenized_history.size() - 1 && stop_tokens.find(tokenized_history.back()) != stop_tokens.end())) + return SIZE_MAX; + else + return idx; +} + +size_t get_seq_len_axis(std::shared_ptr model) { + // sequence length axis in key/values tensors, for most cases [BATCH_SIZE, num_kv_heads, seq_len, head_size], + // therefore usually seq_length_axis = 2 + size_t seq_length_axis = 2; + + // "ReadValue" node is KV cache representation in stateful model + std::string kv_node_type_name = std::string(ov::op::v6::ReadValue::get_type_info_static().name); + + for (const auto op : model->get_ops()) { + // check input size, as in LoRA adapters case it could be 0 + if (op->get_type_name() != kv_node_type_name || op->get_input_size() < 1) { + continue; + } + + // Shape example: [-1,4,0,64] + auto shape = op->get_input_partial_shape(0); + + for (size_t i = 0; i < shape.rank().get_length(); i++) { + // Find axis = 0. This would be sequence length axis. + if (shape[i] == 0) { + seq_length_axis = i; + } + } + break; + } + + return seq_length_axis; +} + +void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional adapter_controller) { + // nothing to trim in this case + if (remove_from_end == 0) + return; + + auto states = request.query_state(); + for (auto& state : states) { + if(adapter_controller && adapter_controller->has_state_name(state.get_name())) + continue; + + ov::Tensor old_tensor = state.get_state(); + // [BATCH_SIZE, num_kv_heads, seq_len, head_size] + auto shape = old_tensor.get_shape(); + shape[seq_length_axis] -= remove_from_end; + + ov::Coordinate new_shape_begin{0, 0, 0, 0}; + ov::Coordinate new_shape_end{shape}; + + auto trimmed_tensor = ov::Tensor(old_tensor, new_shape_begin, new_shape_end); + + ov::Tensor new_tensor(old_tensor.get_element_type(), shape); + trimmed_tensor.copy_to(new_tensor); + + state.set_state(new_tensor); + } +} + } // namespace utils } // namespace genai } // namespace ov diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index 792987d383..57728cd0dc 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -22,6 +22,11 @@ constexpr bool is_container().begin()), decltype(std::declval().end())>> = true; +enum class GenerationChatInputsType { + UNDEF = 0, // Default value, type of inputs is not defined + STRING = 1, // Type of inputs is StringInputs + ENCODED_INPUTS = 2, // Type of inputs is EncodedInputs +}; Tensor init_attention_mask(const Tensor& position_ids); @@ -93,6 +98,12 @@ ov::Core singleton_core(); template void read_rt_info(std::shared_ptr& model, const char* name, T& value); +size_t get_first_history_difference(const ov::Tensor& encoded_history, const std::vector tokenized_history, std::set stop_tokens); + +size_t get_seq_len_axis(std::shared_ptr model); + +void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional adapter_controller); + } // namespace utils } // namespace genai } // namespace ov diff --git a/src/cpp/src/visual_language/inputs_embedder.cpp b/src/cpp/src/visual_language/inputs_embedder.cpp index ced17a2ebd..dfdb1521ef 100644 --- a/src/cpp/src/visual_language/inputs_embedder.cpp +++ b/src/cpp/src/visual_language/inputs_embedder.cpp @@ -39,8 +39,11 @@ class InputsEmbedder::IInputsEmbedder { ChatHistory m_history; // Templated chat history std::string m_templated_chat_history; - // Whether we have computed some inputs already - bool m_is_cache_empty = true; + // Tokenized chat history + std::vector m_tokenized_chat_history; + // The number of elements, which need to remove from the end of KV cache + // removed elements will be added to inputs_ids + size_t m_to_remove_from_hist = 0; public: virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector& images, ov::genai::VLMPerfMetrics& metrics) = 0; @@ -53,12 +56,26 @@ class InputsEmbedder::IInputsEmbedder { return m_tokenizer; } + std::vector get_tokenized_chat_history() const { + return m_tokenized_chat_history; + } + + size_t get_amount_to_remove_from_hist() const { + return m_to_remove_from_hist; + } + + void update_tokenized_chat_history(std::vector encoded_result) { + std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_chat_history)); + m_to_remove_from_hist = 0; + } + virtual void start_chat(const std::string& system_message) { m_is_chat_conversation = true; - if (!m_is_cache_empty) { + m_to_remove_from_hist = 0; + if (!m_tokenized_chat_history.empty()) { m_history.clear(); m_templated_chat_history.clear(); - m_is_cache_empty = true; + m_tokenized_chat_history.clear(); } if (system_message.empty()) { return; @@ -77,10 +94,11 @@ class InputsEmbedder::IInputsEmbedder { virtual void finish_chat() { m_is_chat_conversation = false; - m_is_cache_empty = true; + m_to_remove_from_hist = 0; m_history.clear(); m_templated_chat_history.clear(); + m_tokenized_chat_history.clear(); } protected: @@ -92,7 +110,7 @@ class InputsEmbedder::IInputsEmbedder { m_vlm_config{vlm_config}, m_vision_encoder(model_dir, m_vlm_config.model_type, device, device_config), m_embedding(model_dir, m_vlm_config.scale_emb, device, device_config), - m_tokenizer{model_dir.string(), device_config} { } + m_tokenizer{model_dir, device_config} { } IInputsEmbedder( const VLMConfig& vlm_config, @@ -140,15 +158,28 @@ class InputsEmbedder::IInputsEmbedder { new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt, chat_template_fallback); } auto start_tokenizer_time = std::chrono::steady_clock::now(); - ov::Tensor new_chat_tokens = m_tokenizer.encode(new_templated_chat_history).input_ids; - if (m_is_cache_empty) { + ov::Tensor new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false)).input_ids; + TokenizedInputs prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false)); + + // some symbols combinations can be encoded by the tokenizer in different ways + // if we met sequence with such combination of symbols, we cannot correctly subtract the new history from the old history + // so let's check it out, find the trusted part and use it in on the next step + size_t last_same_hist_token = 0; + if (!m_tokenized_chat_history.empty()) { + std::set stop_tokens = {m_tokenizer.get_eos_token_id()}; + last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens); + } + + if (m_tokenized_chat_history.empty()) { encoded_input_ids = new_chat_tokens; - // after first `get_inputs_embeds` is called, we supposed LLM is inferred and cache is not empty - m_is_cache_empty = false; + } else if (last_same_hist_token != SIZE_MAX) { + m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token; + + ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.get_element_type(), + {1, new_chat_tokens.get_shape().at(1) - last_same_hist_token}, + new_chat_tokens.data() + last_same_hist_token); + encoded_input_ids = new_tensor; } else { - TokenizedInputs prev_chat_tokens = m_tokenizer.encode( - m_templated_chat_history - ); encoded_input_ids = utils::subtract_chat_tokenized_inputs( {new_chat_tokens}, prev_chat_tokens ).input_ids; @@ -156,6 +187,9 @@ class InputsEmbedder::IInputsEmbedder { auto end_tokenizer_time = std::chrono::steady_clock::now(); metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time)); m_templated_chat_history = std::move(new_templated_chat_history); + m_tokenized_chat_history.clear(); + std::copy(new_chat_tokens.data(), new_chat_tokens.data() + new_chat_tokens.get_size(), + std::back_inserter(m_tokenized_chat_history)); } else { auto start_tokenizer_time = std::chrono::steady_clock::now(); encoded_input_ids = m_tokenizer.encode(prompt).input_ids; @@ -639,7 +673,6 @@ class InputsEmbedderLLaVA : public InputsEmbedder::IInputsEmbedder { merged_idx++; } } - return merged_embeds; } }; @@ -1138,6 +1171,18 @@ EmbeddingsModel InputsEmbedder::get_embedding_model() const { return m_impl->get_embedding_model(); } +std::vector InputsEmbedder::get_tokenized_chat_history() const { + return m_impl->get_tokenized_chat_history(); +} + +void InputsEmbedder::update_tokenized_chat_history(std::vector encoded_result) { + return m_impl->update_tokenized_chat_history(encoded_result); +} + +size_t InputsEmbedder::get_amount_to_remove_from_hist() const { + return m_impl->get_amount_to_remove_from_hist(); +} + Tokenizer InputsEmbedder::get_tokenizer() const { return m_impl->get_tokenizer(); } diff --git a/src/cpp/src/visual_language/inputs_embedder.hpp b/src/cpp/src/visual_language/inputs_embedder.hpp index 0e3a3533e2..5c5b9d2b81 100644 --- a/src/cpp/src/visual_language/inputs_embedder.hpp +++ b/src/cpp/src/visual_language/inputs_embedder.hpp @@ -40,6 +40,13 @@ class InputsEmbedder { // returns tokenizer Tokenizer get_tokenizer() const; + // returns tokenized chat history + std::vector get_tokenized_chat_history() const; + // add new results to tokenized chat history + void update_tokenized_chat_history(std::vector encoded_result); + // returns amount of elements, which need to remove from the end of the KV cache + size_t get_amount_to_remove_from_hist() const; + // starts chat and adds optional system_message to chat history void start_chat(const std::string& system_message); // adds currently generated text to chat history diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp index f7508acb35..b8e89a8e04 100644 --- a/src/cpp/src/visual_language/pipeline.cpp +++ b/src/cpp/src/visual_language/pipeline.cpp @@ -64,6 +64,8 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { std::shared_ptr m_inputs_embedder; // Load pipeline time float m_load_time_ms = 0; + // Axis num in kv cache from m_language model, which contains information about history len + size_t m_kv_cache_seq_length_axis = 2; VLMPipelineImpl( const std::filesystem::path& models_dir, @@ -87,9 +89,14 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { m_tokenizer = m_inputs_embedder->get_tokenizer(); m_embedding = m_inputs_embedder->get_embedding_model(); - m_language = utils::singleton_core().compile_model( + auto compiled_language_model = utils::singleton_core().compile_model( models_dir / "openvino_language_model.xml", device, properties - ).create_infer_request(); + ); + + auto language_model = compiled_language_model.get_runtime_model(); + m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(language_model); + + m_language = compiled_language_model.create_infer_request(); m_language.get_tensor("attention_mask").set_shape({1, 0}); @@ -153,14 +160,20 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { ov::Tensor inputs_embeds = m_inputs_embedder->get_inputs_embeds(prompt, rgbs, perf_metrics); auto end_get_inputs_embeds = std::chrono::steady_clock::now(); + auto to_remove_from_hist = m_inputs_embedder->get_amount_to_remove_from_hist(); + ov::genai::utils::trim_kv_cache(m_language, to_remove_from_hist, m_kv_cache_seq_length_axis, std::nullopt); + Sampler sampler = Sampler(m_tokenizer); std::vector requests; size_t request_id = 0; size_t block_size = 1; // not used bool enable_prefix_caching = false; - size_t history_size = m_language.get_tensor("attention_mask").get_shape().at(1); + + auto tokenized_chat_history = m_inputs_embedder->get_tokenized_chat_history(); + size_t history_size = m_language.get_tensor("attention_mask").get_shape().at(1) - to_remove_from_hist; size_t inputs_embeds_size = inputs_embeds.get_shape().at(1); + ov::Tensor prompt_ids(ov::element::i64, { history_size + inputs_embeds_size }); std::fill_n(prompt_ids.data(), prompt_ids.get_size(), 0); @@ -185,10 +198,10 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { OPENVINO_ASSERT((generation_config.is_greedy_decoding() || generation_config.is_multinomial() || !streamer_ptr), "Currently streaming is possible only for greedy or multinomial decoding"); - ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, { 1, history_size + inputs_embeds.get_shape()[1] }}; + ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, { 1, history_size + inputs_embeds_size }}; std::fill_n(new_atten_mask.data(), new_atten_mask.get_size(), 1); - ov::Tensor position_ids = ov::Tensor{ov::element::i64, { 1, inputs_embeds.get_shape()[1] }}; + ov::Tensor position_ids = ov::Tensor{ov::element::i64, { 1, inputs_embeds_size }}; std::iota(position_ids.data(), position_ids.data() + position_ids.get_size(), history_size); ov::genai::EncodedResults encoded_result; @@ -211,6 +224,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { m_language.reset_state(); m_language.get_tensor("attention_mask").set_shape({1, 0}); } + auto generate_end_time = std::chrono::steady_clock::now(); decoded.perf_metrics = encoded_result.perf_metrics; @@ -228,6 +242,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { // Evaluate statistics decoded.perf_metrics.m_evaluated = false; decoded.perf_metrics.evaluate_statistics(generate_start_time); + + m_inputs_embedder->update_tokenized_chat_history(encoded_result.tokens[0]); + return decoded; }