
Commit

update
sbalandi committed Dec 20, 2024
1 parent fa36a48 commit e550db6
Showing 6 changed files with 38 additions and 34 deletions.
13 changes: 5 additions & 8 deletions src/cpp/src/llm_pipeline.cpp
@@ -372,8 +372,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         if (m_sampler.get_seed() != config.rng_seed) {
             m_sampler.set_seed(config.rng_seed);
         }
-        ov::genai::EncodedResults result = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask,
-                                                                             streamer_ptr, m_sampler, requests, position_ids, std::nullopt);
+
+        ov::genai::EncodedResults result;
+        std::tie(result, m_last_disappeared_token) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask,
+                                                                                        streamer_ptr, m_sampler, requests, position_ids, std::nullopt);
 
         if (is_chat_conversation) {
             // force remove from kv_cache last answer
@@ -382,15 +384,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                 m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
             }
 
-            // There's only one request in chat mode
-            if (requests[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH || requests[0]->handle_dropped())
-                m_last_disappeared_token = result.tokens[0].back();
-            else
-                m_last_disappeared_token = std::nullopt;
-
             std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
         } else {
             reset_kv_state();
+            m_last_disappeared_token = std::nullopt;
         }
 
         if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
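
Note on the call-site pattern above: get_lm_encoded_results now returns a pair, and the second element is assigned straight into the pipeline member m_last_disappeared_token, which is why the patch uses std::tie rather than a structured binding. Below is a minimal, self-contained sketch of the two unpacking styles; EncodedResults and generate_stub here are hypothetical stand-ins, not the real OpenVINO GenAI API.

#include <cstdint>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical stand-in for ov::genai::EncodedResults.
struct EncodedResults {
    std::vector<std::vector<int64_t>> tokens;
    std::vector<float> scores;
};

// Hypothetical stand-in for the new pair-returning generation call.
std::pair<EncodedResults, std::optional<int64_t>> generate_stub() {
    return {EncodedResults{{{1, 2, 3}}, {0.0f}}, std::nullopt};
}

struct Pipeline {
    std::optional<int64_t> m_last_disappeared_token;

    void run() {
        // std::tie assigns the pair's second element into an existing member,
        // matching how the patched pipelines consume the new return value.
        EncodedResults result;
        std::tie(result, m_last_disappeared_token) = generate_stub();
    }
};

void run_with_locals() {
    // With C++17 structured bindings both outputs can also be fresh locals.
    auto [result, last_disappeared_token] = generate_stub();
    (void)result;
    (void)last_disappeared_token;
}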
28 changes: 18 additions & 10 deletions src/cpp/src/lm_encoding.cpp
@@ -50,7 +50,7 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector<i
 }
 
 
-EncodedResults get_lm_encoded_results(
+std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
     ov::InferRequest& m_llm,
     const ov::Tensor& input_ids,
     const ov::Tensor& attention_mask,
@@ -107,6 +107,7 @@ EncodedResults get_lm_encoded_results(
     m_llm.set_tensor("beam_idx", beam_idx);
 
     // "Prompt" phase
+
     const auto infer_start = std::chrono::steady_clock::now();
     m_llm.infer();
     const auto infer_end = std::chrono::steady_clock::now();
@@ -210,21 +211,28 @@
     }
 
     for (auto& sequence_group : sequence_groups) {
-        // sequences is sorted by cumulative_log_prob with length_penalty
-        auto outputs = sequence_group->get_finished_sequences();
-
-        auto num_outputs = std::min(sequence_group->get_sampling_parameters().num_return_sequences, outputs.size());
-        for (size_t output_idx = 0; output_idx < num_outputs; ++output_idx) {
-            const auto& output = outputs[output_idx];
-            results.tokens.push_back(output->get_generated_ids());
-            results.scores.push_back(output->get_beam_search_score(sequence_group->get_sampling_parameters()));
+        auto sampling_params = sequence_group->get_sampling_parameters();
+        const auto& sequences = sequence_group->get_finished_sequences();
+        size_t num_outputs = std::min(sequence_group->get_sampling_parameters().num_return_sequences, sequences.size());
+
+        for (size_t seq_id = 0; seq_id < num_outputs; ++seq_id) {
+            const auto & sequence = sequences[seq_id];
+            const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_probs();
+
+            results.tokens.push_back(sequence->get_generated_ids());
+            results.scores.push_back(score);
         }
     }
 
     for (SequenceGroup::Ptr sequence_group : sequence_groups)
         sampler.clear_request_info(sequence_group->get_request_id());
 
-    return results;
+    // it is not saved in KV cache, we need to add it for some cases
+    std::optional<int64_t> last_token_of_best_sequence = std::nullopt;
+    if (sequence_groups[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH || sequence_groups[0]->handle_dropped())
+        last_token_of_best_sequence = results.tokens[0].back();
+
+    return {results, last_token_of_best_sequence};
 }
 
 } // namespace genai
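
One behavioral detail in the rewritten loop above is easy to miss: the reported score now depends on the decoding mode. Beam search keeps the length-penalised beam score, while greedy and multinomial sampling report the cumulative log probability. A minimal sketch of that selection rule, using hypothetical simplified stand-ins rather than the real Sequence/SequenceGroup classes:

#include <cstddef>

// Hypothetical, trimmed-down stand-ins for the sampling parameters and a
// finished sequence; the real classes live in the GenAI sampler headers.
struct SamplingParams {
    bool beam_search = false;
    bool is_beam_search() const { return beam_search; }
};

struct Sequence {
    float beam_score = 0.0f;          // length-penalised beam-search score
    float cumulative_log_prob = 0.0f; // sum of per-token log probabilities

    float get_beam_search_score(const SamplingParams&) const { return beam_score; }
    float get_cumulative_log_probs() const { return cumulative_log_prob; }
};

// Mirrors the selection rule in the patched loop: the reported score is the
// beam score only when beam search was actually used.
float sequence_score(const Sequence& seq, const SamplingParams& params) {
    return params.is_beam_search() ? seq.get_beam_search_score(params)
                                   : seq.get_cumulative_log_probs();
}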
6 changes: 3 additions & 3 deletions src/cpp/src/lm_encoding.hpp
@@ -8,9 +8,9 @@
 namespace ov {
 namespace genai {
 
-EncodedResults get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask,
-                                      const std::shared_ptr<StreamerBase>& streamer_ptr, Sampler& sampler, std::vector<SequenceGroup::Ptr> sequence_groups,
-                                      std::optional<ov::Tensor> position_ids, std::optional<EmbeddingsModel> m_embedding);
+std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask,
+                                                                         const std::shared_ptr<StreamerBase>& streamer_ptr, Sampler& sampler, std::vector<SequenceGroup::Ptr> sequence_groups,
+                                                                         std::optional<ov::Tensor> position_ids, std::optional<EmbeddingsModel> m_embedding);
 
 }
 }
13 changes: 5 additions & 8 deletions src/cpp/src/visual_language/inputs_embedder.cpp
@@ -68,19 +68,16 @@ class InputsEmbedder::IInputsEmbedder {
         return m_kv_history_manager.num_tokens_to_remove_from_kv_cache;
     }
 
-    void update_tokenized_history(const std::vector<int64_t>& encoded_result, bool is_last_token_disappear, bool is_beam_search, size_t last_answer_len) {
+    void update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len) {
         if (is_beam_search) {
             m_kv_history_manager.trusted_history_length = m_tokenized_history.size();
             m_kv_history_manager.num_tokens_to_remove_from_kv_cache = last_answer_len;
         } else {
             m_kv_history_manager.reset();
         }
 
-        if (is_last_token_disappear)
-            m_last_disappeared_token = encoded_result.back();
-        else
-            m_last_disappeared_token = std::nullopt;
-
+        m_last_disappeared_token = last_disappeared_token;
+
         std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history));
     }
 
@@ -1208,8 +1205,8 @@ std::vector<int64_t> InputsEmbedder::get_tokenized_history() const {
     return m_impl->get_tokenized_history();
 }
 
-void InputsEmbedder::update_tokenized_history(const std::vector<int64_t>& encoded_result, bool is_last_token_disappear, bool is_beam_search, size_t last_answer_len) {
-    return m_impl->update_tokenized_history(encoded_result, is_last_token_disappear, is_beam_search, last_answer_len);
+void InputsEmbedder::update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len) {
+    return m_impl->update_tokenized_history(encoded_result, last_disappeared_token, is_beam_search, last_answer_len);
 }
 
 size_t InputsEmbedder::get_num_tokens_to_remove_from_hist() const {
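
The net effect of the embedder change above is that the embedder no longer decides which token "disappeared" from the KV cache; the caller passes it in as std::optional<int64_t> and the embedder simply stores it, while the beam-search bookkeeping stays as before. A condensed, self-contained sketch under those assumptions (HistoryState and KVHistoryManager are simplified stand-ins, not the real InputsEmbedder internals):

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical simplified version of the KV-history bookkeeping.
struct KVHistoryManager {
    size_t trusted_history_length = 0;
    size_t num_tokens_to_remove_from_kv_cache = 0;
    void reset() { *this = {}; }
};

struct HistoryState {
    std::vector<int64_t> tokenized_history;
    std::optional<int64_t> last_disappeared_token;
    KVHistoryManager kv_history_manager;

    void update_tokenized_history(const std::vector<int64_t>& encoded_result,
                                  std::optional<int64_t> disappeared_token,
                                  bool is_beam_search,
                                  size_t last_answer_len) {
        if (is_beam_search) {
            // Beam search may rewrite the tail of the cache, so remember how
            // much of the history is still trusted and how much to drop.
            kv_history_manager.trusted_history_length = tokenized_history.size();
            kv_history_manager.num_tokens_to_remove_from_kv_cache = last_answer_len;
        } else {
            kv_history_manager.reset();
        }
        // The caller now supplies the token directly instead of a bool flag.
        last_disappeared_token = disappeared_token;
        tokenized_history.insert(tokenized_history.end(),
                                 encoded_result.begin(), encoded_result.end());
    }
};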
2 changes: 1 addition & 1 deletion src/cpp/src/visual_language/inputs_embedder.hpp
@@ -44,7 +44,7 @@ class InputsEmbedder {
     std::vector<int64_t> get_tokenized_history() const;
 
     // add new results to tokenized history
-    void update_tokenized_history(const std::vector<int64_t>& encoded_result, bool is_last_token_disappear, bool is_beam_search, size_t last_answer_len);
+    void update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len);
 
     // returns amount of elements, which need to remove from the end of the KV cache
     size_t get_num_tokens_to_remove_from_hist() const;
10 changes: 6 additions & 4 deletions src/cpp/src/visual_language/pipeline.cpp
@@ -217,8 +217,10 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
             m_sampler.set_seed(generation_config.rng_seed);
         }
 
-        ov::genai::EncodedResults encoded_result = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests,
-                                                                                     position_ids, m_embedding);
+        ov::genai::EncodedResults encoded_result;
+        std::optional<int64_t> last_disappeared_token;
+        std::tie(encoded_result, last_disappeared_token) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests,
+                                                                                              position_ids, m_embedding);
 
         auto decode_start_time = std::chrono::steady_clock::now();
         VLMDecodedResults decoded;
@@ -228,8 +230,8 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         }
         auto decode_end_time = std::chrono::steady_clock::now();
 
-        m_inputs_embedder->update_tokenized_history(encoded_result.tokens[0], requests[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH,
-                                                    generation_config.is_beam_search(), m_language.get_tensor("attention_mask").get_shape()[1] - (history_size + inputs_embeds_size));
+        m_inputs_embedder->update_tokenized_history(encoded_result.tokens[0], last_disappeared_token, generation_config.is_beam_search(),
+                                                    m_language.get_tensor("attention_mask").get_shape()[1] - (history_size + inputs_embeds_size));
 
         std::string decoded_results = decoded.texts.at(0);
         if (m_is_chat_conversation) {
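
Taken together, the VLM pipeline change above closes the loop: the token produced on the final decoding step is not written into the KV cache when generation stops on the length limit or the request handle is dropped, so get_lm_encoded_results reports it back and the pipeline forwards it into the tokenized history for the next chat turn. A minimal sketch of that rule, with hypothetical names (FinishedRequest and FinishReason are stand-ins for the single chat request's finished sequence):

#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-ins for the request/sequence state inspected here.
enum class FinishReason { STOP, LENGTH };

struct FinishedRequest {
    FinishReason finish_reason = FinishReason::STOP;
    bool dropped = false;                 // generation handle was dropped
    std::vector<int64_t> generated_ids;   // best sequence's tokens
};

// Mirrors the rule moved into get_lm_encoded_results(): the last generated
// token never reached the KV cache if generation was cut off by the length
// limit or by a dropped handle, so report it back to the caller.
std::optional<int64_t> last_disappeared_token(const FinishedRequest& req) {
    if (req.generated_ids.empty())
        return std::nullopt;
    if (req.finish_reason == FinishReason::LENGTH || req.dropped)
        return req.generated_ids.back();
    return std::nullopt;
}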
