fill prompt for sampler analysis with real tokens in VLM pipeline (#1247)

+ add the missed token if the previous generation finished because max length was reached
sbalandi authored Dec 19, 2024
1 parent 7a02d2b commit 17f4eb3
Showing 5 changed files with 59 additions and 27 deletions.
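
For context (not part of the commit): a minimal sketch of why the last token can "disappear" from the KV cache when generation stops on the length limit. Each decoding step feeds the previously sampled token through the model (caching its KV entries) and then samples the next token; the token sampled on the final step is returned in the result but never fed back, so it has no KV-cache entry and has to be re-supplied on the next chat turn. `run_model_step` and all token values below are made up.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

int64_t run_model_step(int64_t /*input_token*/) {   // stand-in for infer + sample
    static int64_t next = 100;
    return next++;
}

int main() {
    std::vector<int64_t> generated;
    std::optional<int64_t> last_disappeared_token;

    int64_t next_input = 1;               // e.g. the last prompt token
    const size_t max_new_tokens = 4;
    for (size_t i = 0; i < max_new_tokens; ++i) {
        int64_t sampled = run_model_step(next_input);  // KV is cached only for next_input
        generated.push_back(sampled);
        next_input = sampled;
    }
    // The loop stopped on the length limit: generated.back() was sampled but never
    // fed through the model, so it has no KV-cache entry (the "disappeared" token).
    last_disappeared_token = generated.back();
    (void)last_disappeared_token;
    return 0;
}
```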
8 changes: 8 additions & 0 deletions src/cpp/src/utils.cpp
@@ -381,6 +381,14 @@ void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t se
}
}

ov::Tensor push_front_inputs(const ov::Tensor& base_tensor, int64_t add_to_front) {
ov::Tensor new_tensor = ov::Tensor{ov::element::i64, {base_tensor.get_shape().at(0), base_tensor.get_shape().at(1) + 1}};
auto new_tensor_data = new_tensor.data<int64_t>();
new_tensor_data[0] = add_to_front;
std::copy_n(base_tensor.data<int64_t>(), base_tensor.get_size(), new_tensor_data + 1);
return new_tensor;
}

void print_compiled_model_properties(ov::CompiledModel& compiled_Model, const char* model_title) {
// Specify the name of the environment variable
const char* env_var_name = "OPENVINO_LOG_LEVEL";
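
A minimal usage sketch for the new helper (it assumes only the OpenVINO runtime; the inlined lines mirror what `push_front_inputs` does, and the commented call shows how it would be invoked from this codebase):

```cpp
#include <openvino/openvino.hpp>
#include <algorithm>
#include <iostream>

int main() {
    // A {1, 3} i64 tensor standing in for previously tokenized input ids.
    ov::Tensor ids{ov::element::i64, {1, 3}};
    int64_t* data = ids.data<int64_t>();
    data[0] = 10; data[1] = 11; data[2] = 12;

    // With the helper added above this would be:
    //   ov::Tensor extended = ov::genai::utils::push_front_inputs(ids, 9);
    // Inlined equivalent:
    ov::Tensor extended{ov::element::i64, {ids.get_shape().at(0), ids.get_shape().at(1) + 1}};
    extended.data<int64_t>()[0] = 9;
    std::copy_n(ids.data<int64_t>(), ids.get_size(), extended.data<int64_t>() + 1);

    for (size_t i = 0; i < extended.get_size(); ++i)
        std::cout << extended.data<int64_t>()[i] << ' ';   // prints: 9 10 11 12
    std::cout << '\n';
    return 0;
}
```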
2 changes: 2 additions & 0 deletions src/cpp/src/utils.hpp
@@ -104,6 +104,8 @@ size_t get_seq_len_axis(std::shared_ptr<const ov::Model> model);

void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller);

ov::Tensor push_front_inputs(const ov::Tensor& base_tensor, int64_t add_to_front);

void print_compiled_model_properties(ov::CompiledModel& compiled_Model, const char* model_title);

} // namespace utils
56 changes: 36 additions & 20 deletions src/cpp/src/visual_language/inputs_embedder.cpp
@@ -10,6 +10,7 @@

#include "utils.hpp"


namespace {

constexpr size_t BATCH_SIZE = 1;
@@ -40,10 +41,12 @@ class InputsEmbedder::IInputsEmbedder {
// Templated chat history
std::string m_templated_chat_history;
// Tokenized chat history
std::vector<int64_t> m_tokenized_chat_history;
std::vector<int64_t> m_tokenized_history;
// The number of elements that need to be removed from the end of the KV cache
// removed elements will be added to inputs_ids
size_t m_to_remove_from_hist = 0;
// Tail of previous output for LM in chat mode is missing in KV cache.
std::optional<int64_t> m_last_disappeared_token = std::nullopt;

public:
virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) = 0;
@@ -56,26 +59,30 @@ class InputsEmbedder::IInputsEmbedder {
return m_tokenizer;
}

std::vector<int64_t> get_tokenized_chat_history() const {
return m_tokenized_chat_history;
std::vector<int64_t> get_tokenized_history() const {
return m_tokenized_history;
}

size_t get_amount_to_remove_from_hist() const {
return m_to_remove_from_hist;
}

void update_tokenized_chat_history(std::vector<int64_t> encoded_result) {
std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_chat_history));
void update_tokenized_history(std::vector<int64_t> encoded_result, bool token_will_disappear) {
std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history));
m_to_remove_from_hist = 0;
if (token_will_disappear)
m_last_disappeared_token = encoded_result.back();
else
m_last_disappeared_token = std::nullopt;
}

virtual void start_chat(const std::string& system_message) {
m_is_chat_conversation = true;
m_to_remove_from_hist = 0;
if (!m_tokenized_chat_history.empty()) {
if (!m_tokenized_history.empty()) {
m_history.clear();
m_templated_chat_history.clear();
m_tokenized_chat_history.clear();
m_tokenized_history.clear();
}
if (system_message.empty()) {
return;
@@ -98,7 +105,7 @@ class InputsEmbedder::IInputsEmbedder {

m_history.clear();
m_templated_chat_history.clear();
m_tokenized_chat_history.clear();
m_tokenized_history.clear();
}

protected:
@@ -165,37 +172,46 @@ class InputsEmbedder::IInputsEmbedder {
// if we met a sequence with such a combination of symbols, we cannot correctly subtract the new history from the old history
// so let's check it, find the trusted part, and use it on the next step
size_t last_same_hist_token = 0;
if (!m_tokenized_chat_history.empty()) {
if (!m_tokenized_history.empty()) {
std::set<int64_t> stop_tokens = {m_tokenizer.get_eos_token_id()};
last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_history, stop_tokens);
}

if (m_tokenized_chat_history.empty()) {
if (m_tokenized_history.empty()) {
encoded_input_ids = new_chat_tokens;
} else if (last_same_hist_token != SIZE_MAX) {
m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token;
m_to_remove_from_hist = m_tokenized_history.size() - last_same_hist_token;
// if the previous generation finished because max length was reached, the KV cache is missing the last token, so keep it
m_to_remove_from_hist -= m_last_disappeared_token.has_value() ? 1 : 0;

ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.get_element_type(),
{1, new_chat_tokens.get_shape().at(1) - last_same_hist_token},
new_chat_tokens.data<int64_t>() + last_same_hist_token);
encoded_input_ids = new_tensor;
encoded_input_ids = ov::Tensor(new_chat_tokens.get_element_type(),
{1, new_chat_tokens.get_shape().at(1) - last_same_hist_token});
new_tensor.copy_to(encoded_input_ids);
} else {
encoded_input_ids = utils::subtract_chat_tokenized_inputs(
{new_chat_tokens}, prev_chat_tokens
).input_ids;

if (m_last_disappeared_token.has_value())
encoded_input_ids = ov::genai::utils::push_front_inputs(encoded_input_ids, *m_last_disappeared_token);
}
auto end_tokenizer_time = std::chrono::steady_clock::now();
metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time));
m_templated_chat_history = std::move(new_templated_chat_history);
m_tokenized_chat_history.clear();
std::copy(new_chat_tokens.data<int64_t>(), new_chat_tokens.data<int64_t>() + new_chat_tokens.get_size(),
std::back_inserter(m_tokenized_chat_history));
m_tokenized_history.clear();
std::copy_n(new_chat_tokens.data<int64_t>(), new_chat_tokens.get_size(), std::back_inserter(m_tokenized_history));
} else {
auto start_tokenizer_time = std::chrono::steady_clock::now();
encoded_input_ids = m_tokenizer.encode(prompt).input_ids;
auto end_tokenizer_time = std::chrono::steady_clock::now();
metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time));
m_tokenized_history.clear();
std::copy_n(encoded_input_ids.data<int64_t>(), encoded_input_ids.get_size(), std::back_inserter(m_tokenized_history));
}

return encoded_input_ids;
}

@@ -1172,12 +1188,12 @@ EmbeddingsModel InputsEmbedder::get_embedding_model() const {
return m_impl->get_embedding_model();
}

std::vector<int64_t> InputsEmbedder::get_tokenized_chat_history() const {
return m_impl->get_tokenized_chat_history();
std::vector<int64_t> InputsEmbedder::get_tokenized_history() const {
return m_impl->get_tokenized_history();
}

void InputsEmbedder::update_tokenized_chat_history(std::vector<int64_t> encoded_result) {
return m_impl->update_tokenized_chat_history(encoded_result);
void InputsEmbedder::update_tokenized_history(std::vector<int64_t> encoded_result, bool token_will_disappear) {
return m_impl->update_tokenized_history(encoded_result, token_will_disappear);
}

size_t InputsEmbedder::get_amount_to_remove_from_hist() const {
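
A worked example of the trimming arithmetic in get_encoded_input_ids above, with invented numbers: the previous turn left 20 tokens in m_tokenized_history, it stopped on the length limit (so the 20th token was never fed to the model and has no KV-cache entry), and re-templating the chat agrees with the stored history only up to position 15.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <optional>

int main() {
    const size_t history_size = 20;            // m_tokenized_history.size()
    const size_t last_same_hist_token = 15;    // first position where old/new histories diverge
    const std::optional<int64_t> last_disappeared_token = 99;  // set because the finish reason was LENGTH

    // Tokens whose KV entries are no longer valid for the re-templated history:
    size_t to_remove_from_hist = history_size - last_same_hist_token;      // 5
    // The disappeared token was never cached, so one fewer cache entry exists to trim:
    to_remove_from_hist -= last_disappeared_token.has_value() ? 1 : 0;     // 4

    // The KV cache held history_size - 1 = 19 entries; after trimming 4 it holds 15,
    // matching last_same_hist_token, and the new input ids restart from position 15.
    std::cout << "entries to trim: " << to_remove_from_hist << '\n';
    return 0;
}
```

When the stored and re-templated histories match exactly, the code instead takes the simple difference of the two tokenizations and re-attaches the disappeared token via push_front_inputs, since the new input would not otherwise contain it.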
8 changes: 6 additions & 2 deletions src/cpp/src/visual_language/inputs_embedder.hpp
@@ -41,16 +41,20 @@ class InputsEmbedder {
Tokenizer get_tokenizer() const;

// returns tokenized chat history
std::vector<int64_t> get_tokenized_chat_history() const;
std::vector<int64_t> get_tokenized_history() const;

// add new results to tokenized chat history
void update_tokenized_chat_history(std::vector<int64_t> encoded_result);
void update_tokenized_history(std::vector<int64_t> encoded_result, bool token_will_disappear);

// returns the number of elements that need to be removed from the end of the KV cache
size_t get_amount_to_remove_from_hist() const;

// starts chat and adds optional system_message to chat history
void start_chat(const std::string& system_message);

// adds currently generated text to chat history
void update_chat_history(const std::string& decoded_results);

// finishes chat and clears the chat history
void finish_chat();
private:
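
A sketch of the call order the VLM pipeline follows around this interface (see pipeline.cpp below); the stub types here are stand-ins so the example compiles on its own, not the real classes.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

struct InputsEmbedderStub {
    std::vector<int64_t> get_tokenized_history() const { return {}; }
    size_t get_amount_to_remove_from_hist() const { return 0; }
    void update_tokenized_history(std::vector<int64_t>, bool) {}
};

void trim_kv_cache_stub(size_t /*remove_from_end*/) {}
std::vector<int64_t> generate_stub() { return {7, 8, 9}; }
bool finished_because_of_length_stub() { return true; }

int main() {
    InputsEmbedderStub embedder;

    // 1. Ask how many stale entries the previous turn left in the KV cache and trim them.
    trim_kv_cache_stub(embedder.get_amount_to_remove_from_hist());

    // 2. Build prompt_ids from the tokenized history plus the new turn and run generation
    //    (both omitted here).
    std::vector<int64_t> result_tokens = generate_stub();

    // 3. Record the result; pass true when generation stopped on the length limit so the
    //    last token is remembered as missing from the KV cache for the next turn.
    embedder.update_tokenized_history(result_tokens, finished_because_of_length_stub());
    return 0;
}
```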
12 changes: 7 additions & 5 deletions src/cpp/src/visual_language/pipeline.cpp
@@ -17,6 +17,7 @@
#include "utils.hpp"
#include "lm_encoding.hpp"


using namespace ov::genai;

namespace {
@@ -163,19 +164,18 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
auto to_remove_from_hist = m_inputs_embedder->get_amount_to_remove_from_hist();
ov::genai::utils::trim_kv_cache(m_language, to_remove_from_hist, m_kv_cache_seq_length_axis, std::nullopt);

Sampler sampler = Sampler(m_tokenizer);

std::vector<SequenceGroup::Ptr> requests;
size_t request_id = 0;
size_t block_size = 1; // not used
bool enable_prefix_caching = false;

auto tokenized_chat_history = m_inputs_embedder->get_tokenized_chat_history();
size_t history_size = m_language.get_tensor("attention_mask").get_shape().at(1) - to_remove_from_hist;
size_t inputs_embeds_size = inputs_embeds.get_shape().at(1);

auto tokenized_history = m_inputs_embedder->get_tokenized_history();
ov::Tensor prompt_ids(ov::element::i64, { history_size + inputs_embeds_size });
std::fill_n(prompt_ids.data<int64_t>(), prompt_ids.get_size(), 0);
std::fill_n(prompt_ids.data<int64_t>(), prompt_ids.get_size(), m_tokenizer.get_pad_token_id());
std::copy(tokenized_history.begin(), tokenized_history.end(), prompt_ids.data<int64_t>());

SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(request_id, prompt_ids, generation_config, block_size, enable_prefix_caching);
sequence_group->set_sequence_group_ptr(sequence_group);
@@ -204,6 +204,8 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
ov::Tensor position_ids = ov::Tensor{ov::element::i64, { 1, inputs_embeds_size }};
std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), history_size);

Sampler sampler = Sampler(m_tokenizer);

ov::genai::EncodedResults encoded_result;
int32_t m_selected_beam = 0;
std::tie(encoded_result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, sampler, requests,
@@ -243,7 +245,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
decoded.perf_metrics.m_evaluated = false;
decoded.perf_metrics.evaluate_statistics(generate_start_time);

m_inputs_embedder->update_tokenized_chat_history(encoded_result.tokens[0]);
m_inputs_embedder->update_tokenized_history(encoded_result.tokens[0], requests[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH);

return decoded;
}
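
A minimal sketch of how the prompt handed to the Sampler is now assembled (mirroring the pipeline.cpp hunk above); the sizes and token values are invented, and it assumes only the OpenVINO runtime.

```cpp
#include <openvino/openvino.hpp>
#include <algorithm>
#include <vector>

int main() {
    const std::vector<int64_t> tokenized_history = {5, 6, 7};  // from the inputs embedder
    const size_t history_size = 3;        // attention-mask length minus trimmed entries (toy value)
    const size_t inputs_embeds_size = 2;  // length of the new turn's embeddings (toy value)
    const int64_t pad_token_id = 0;       // m_tokenizer.get_pad_token_id() in the pipeline

    // Previously the whole prompt was zero-filled; now it is padded with the real pad id
    // and starts with the actual history tokens, so the sampler analyses a real sequence.
    ov::Tensor prompt_ids{ov::element::i64, {history_size + inputs_embeds_size}};
    std::fill_n(prompt_ids.data<int64_t>(), prompt_ids.get_size(), pad_token_id);
    std::copy(tokenized_history.begin(), tokenized_history.end(), prompt_ids.data<int64_t>());
    return 0;
}
```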
