Use whole history in case of undetermined tokenization of sequence (#1254)

Task: [CVS-157295](https://jira.devtools.intel.com/browse/CVS-157295)

- first commit is a cherry-pick from #1268 and #1361
- next commit includes applying review comments from #1268 and adding usage of the KV cache for LLM (the idea is sketched below)
sbalandi authored Dec 16, 2024
1 parent 8ce5eb3 commit 9e9b409
Showing 6 changed files with 264 additions and 37 deletions.
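
The idea behind the change: instead of subtracting the previous turn's tokens from the newly templated history, the whole templated chat history is re-encoded and compared against the token IDs recorded for earlier turns; only the matching prefix is trusted, and everything after it is recomputed, with the KV cache trimmed back to the trusted part. A minimal stand-alone sketch of that comparison, with simplified names that are not the code from this commit:

#include <cstddef>
#include <cstdint>
#include <vector>

// Return the index of the first token where the re-encoded chat history diverges
// from the token IDs stored for previous turns, or SIZE_MAX if nothing diverged.
// (The real get_first_history_difference() additionally treats a history that only
// lost a trailing eos/stop token as fully matching.)
size_t first_divergence(const std::vector<int64_t>& re_encoded,
                        const std::vector<int64_t>& stored) {
    size_t idx = 0;
    while (idx < re_encoded.size() && idx < stored.size() &&
           re_encoded[idx] == stored[idx]) {
        ++idx;
    }
    if (idx == re_encoded.size() && idx == stored.size())
        return SIZE_MAX;  // histories match completely, the KV cache can be reused as is
    return idx;           // length of the trusted prefix; recompute from here
}
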
108 changes: 90 additions & 18 deletions src/cpp/src/llm_pipeline.cpp
@@ -36,13 +36,15 @@ std::pair<EncodedResults, int32_t> beam_search(
class StatefulLLMPipeline final : public LLMPipelineImplBase {
public:
ov::InferRequest m_model_runner;

bool is_chat_conversation = false;
bool m_is_cache_empty = true;
bool m_trust_encoded_history = true;
std::optional<int32_t> m_selected_beam = std::nullopt;
ChatHistory m_history;
std::string m_templated_chat_history = {};
TokenizedInputs m_tokenized_chat_history;
std::vector<int64_t> m_tokenized_chat_history;
ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
size_t m_to_remove_from_hist = 0;
size_t m_kv_cache_seq_length_axis = 2;

StatefulLLMPipeline(
const ov::InferRequest& request,
@@ -77,6 +79,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
ov::Core core;
auto [core_plugin_config, plugin_config] = ov::genai::utils::split_core_compile_config(config);
utils::slice_matmul_statefull_model(model);
m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);

if (auto filtered_plugin_config = extract_adapters_from_properties(plugin_config, &m_generation_config.adapters)) {
m_generation_config.adapters->set_tensor_name_prefix("base_model.model.model.");
@@ -102,8 +105,20 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
OptionalGenerationConfig generation_config,
StreamerVariant streamer
) override {
if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::UNDEF)
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::STRING;

if (is_chat_conversation)
OPENVINO_ASSERT(m_chat_input_type != ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS,
"Chat doesn't support switching between input types. Please, continue using EncodedInputs or restart the chat.");

auto start_time = std::chrono::steady_clock::now();
GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;
// If eos_token_id was not provided, take value from default m_generation_config
if (config.eos_token_id == -1)
config.set_eos_token_id(m_generation_config.eos_token_id);
config.validate();

TokenizedInputs encoded_input;

if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
@@ -127,19 +142,51 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
// Do not add special tokens in chat scenario to be aligned with HF.
auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false));
if (m_is_cache_empty) {
auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));

// some symbol combinations can be encoded by the tokenizer in different ways
// if the sequence contains such a combination, the new tokens cannot be reliably obtained by subtracting the old history from the new one
// so check for this, find the trusted part of the history and use it on the next step
size_t last_same_hist_token = 0;
if (!m_tokenized_chat_history.empty()) {
std::set<int64_t> stop_tokens = config.stop_token_ids;
last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
m_trust_encoded_history = last_same_hist_token == SIZE_MAX;
}

if (m_tokenized_chat_history.empty()) {
encoded_input = new_chat_tokens;
} else if (last_same_hist_token != SIZE_MAX) {
m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token;

ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
{1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token},
new_chat_tokens.input_ids.data<int64_t>() + last_same_hist_token);

ov::Tensor new_attention_mask(ov::element::i64, new_tensor.get_shape());
std::fill_n(new_attention_mask.data<int64_t>(), new_tensor.get_shape()[1], 1);

encoded_input.input_ids = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
{1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token});
new_tensor.copy_to(encoded_input.input_ids);
encoded_input.attention_mask = new_attention_mask;

m_selected_beam = std::nullopt;
} else {
auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
}
m_templated_chat_history = new_templated_chat_history;
m_tokenized_chat_history = new_chat_tokens;
m_tokenized_chat_history.clear();
m_tokenized_chat_history.reserve(new_chat_tokens.input_ids.get_size());
std::copy_n(new_chat_tokens.input_ids.data<int64_t>(), new_chat_tokens.input_ids.get_size(),
std::back_inserter(m_tokenized_chat_history));

// TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
} else {
encoded_input = m_tokenizer.encode(prompt);
}
}

auto encode_stop_time = std::chrono::steady_clock::now();
auto encoded_results = generate(encoded_input, config, streamer);

@@ -188,6 +235,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
OptionalGenerationConfig generation_config,
StreamerVariant streamer
) override {
if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::UNDEF)
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS;

if (is_chat_conversation)
// if the chat was run with StringInputs but generate() is now called with EncodedInputs, the last m_history entry will have the assistant role
OPENVINO_ASSERT(m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS || m_history.back()["role"] == "user",
"Chat doesn't support switching between input types. Please, continue using StringInputs or restart the chat.");

auto start_time = std::chrono::steady_clock::now();
ov::Tensor input_ids;
ov::Tensor attention_mask;
@@ -199,6 +254,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
attention_mask = data->attention_mask;
}

if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));

GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;

// If eos_token_id was not provided, take value from default m_generation_config
@@ -230,16 +288,17 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
"(input_ids, attention_mask, position_ids, beam_idx) "
"but you have '" + std::to_string(num_inputs) + "' inputs");

ov::genai::utils::trim_kv_cache(m_model_runner, m_to_remove_from_hist, m_kv_cache_seq_length_axis, m_adapter_controller);

size_t kv_cache_len = 0;
ov::Tensor concatenated_attention_mask;
if (is_chat_conversation && !m_is_cache_empty) {
if (is_chat_conversation && !m_tokenized_chat_history.empty()) {
OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1");
// If the history is saved in the KV cache, concatenate the new attention_mask with the existing one.
// Between subsequent runs the attention_mask should not be modified.
auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
auto prompt_len = attention_mask.get_shape()[1];
kv_cache_len = atten_mask_history.get_shape()[1];
kv_cache_len = atten_mask_history.get_shape()[1] - m_to_remove_from_hist;

ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
auto start_atten_hst = atten_mask_history.data<int64_t>() + kv_cache_len * (*m_selected_beam);
@@ -263,6 +322,11 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
m_adapter_controller->apply(m_model_runner, config.adapters);
}

if (is_chat_conversation && !m_trust_encoded_history) {
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
}

ov::genai::EncodedResults result;
if (config.is_beam_search() && is_chat_conversation) {
std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask,
@@ -274,8 +338,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation && !m_is_cache_empty) {
sequence_group = std::make_shared<SequenceGroup>(request_id, m_tokenized_chat_history.input_ids, config, block_size, enable_prefix_caching);
if (is_chat_conversation) {
ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
@@ -294,12 +359,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
sampler, requests, position_ids, std::nullopt, m_selected_beam);
}

if (!is_chat_conversation) {
if (is_chat_conversation) {
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
} else {
reset_kv_state();
m_selected_beam = std::nullopt;
} else {
m_is_cache_empty = false;
}

auto stop_time = std::chrono::steady_clock::now();

// If generate() was called without tokenization, this stat will not be reported.
@@ -313,12 +379,15 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

void start_chat(const std::string& system_message) override {
is_chat_conversation = true;
m_selected_beam = std::nullopt;
if (!m_is_cache_empty) {
m_selected_beam = std::nullopt;
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
if (!m_tokenized_chat_history.empty()) {
reset_kv_state();
m_is_cache_empty = true;
m_history = {};
m_templated_chat_history = "";
m_tokenized_chat_history.clear();
}
if (system_message.empty())
return;
@@ -332,11 +401,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
void finish_chat() override {
is_chat_conversation = false;
m_selected_beam = std::nullopt;
if (!m_is_cache_empty) {
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
if (!m_tokenized_chat_history.empty()) {
reset_kv_state();
m_is_cache_empty = true;
m_history.clear();
m_templated_chat_history.clear();
m_tokenized_chat_history.clear();
}
}
};
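
For reference, the prefix-trimming step in the string-input generate() above can be read as the following simplified helper. It uses the same ov::Tensor calls as the diff, but the function and variable names are illustrative only:

#include <algorithm>
#include <cstdint>
#include <utility>
#include "openvino/runtime/tensor.hpp"

// Given the freshly re-encoded history `new_ids` of shape {1, n} and the length of
// the trusted prefix `kept`, build the input_ids/attention_mask pair for the tail
// that still has to be fed to the model.
std::pair<ov::Tensor, ov::Tensor> tail_inputs(ov::Tensor new_ids, size_t kept) {
    const size_t remaining = new_ids.get_shape().at(1) - kept;

    // View over the tail of the re-encoded ids, then copy it into an owning tensor.
    ov::Tensor tail_view(new_ids.get_element_type(), {1, remaining},
                         new_ids.data<int64_t>() + kept);
    ov::Tensor input_ids(new_ids.get_element_type(), {1, remaining});
    tail_view.copy_to(input_ids);

    // Attention mask of ones covering only the tail that is passed to the model.
    ov::Tensor attention_mask(ov::element::i64, {1, remaining});
    std::fill_n(attention_mask.data<int64_t>(), remaining, 1);

    return {input_ids, attention_mask};
}
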
75 changes: 75 additions & 0 deletions src/cpp/src/utils.cpp
@@ -13,6 +13,8 @@
#include "openvino/op/tanh.hpp"
#include "openvino/op/transpose.hpp"

#include "sampler.hpp"

namespace ov {
namespace genai {
namespace utils {
@@ -306,6 +308,79 @@ ov::Core singleton_core() {
return core;
}

size_t get_first_history_difference(const ov::Tensor& encoded_history, const std::vector<int64_t> tokenized_history, std::set<int64_t> stop_tokens) {
size_t idx = 0;
auto encoded_history_data = encoded_history.data<int64_t>();
while(idx < encoded_history.get_size() && idx < tokenized_history.size()) {
if (encoded_history_data[idx] != tokenized_history[idx])
break;
idx++;
}

// encoded_history, re-encoded from the decoded text, may have lost the last token (an eos/stop token)
if ((idx == tokenized_history.size() && idx == encoded_history.get_size()) ||
(encoded_history.get_size() < tokenized_history.size() && idx == tokenized_history.size() - 1 && stop_tokens.find(tokenized_history.back()) != stop_tokens.end()))
return SIZE_MAX;
else
return idx;
}

size_t get_seq_len_axis(std::shared_ptr<const ov::Model> model) {
// sequence length axis in key/value tensors; in most cases the layout is [BATCH_SIZE, num_kv_heads, seq_len, head_size],
// therefore usually seq_length_axis = 2
size_t seq_length_axis = 2;

// "ReadValue" node is KV cache representation in stateful model
std::string kv_node_type_name = std::string(ov::op::v6::ReadValue::get_type_info_static().name);

for (const auto op : model->get_ops()) {
// check the input size, as it can be 0 in the LoRA adapters case
if (op->get_type_name() != kv_node_type_name || op->get_input_size() < 1) {
continue;
}

// Shape example: [-1,4,0,64]
auto shape = op->get_input_partial_shape(0);

for (size_t i = 0; i < shape.rank().get_length(); i++) {
// Find the axis whose dimension is 0; this is the sequence length axis.
if (shape[i] == 0) {
seq_length_axis = i;
}
}
break;
}

return seq_length_axis;
}

void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller) {
// nothing to trim in this case
if (remove_from_end == 0)
return;

auto states = request.query_state();
for (auto& state : states) {
if(adapter_controller && adapter_controller->has_state_name(state.get_name()))
continue;

ov::Tensor old_tensor = state.get_state();
// [BATCH_SIZE, num_kv_heads, seq_len, head_size]
auto shape = old_tensor.get_shape();
shape[seq_length_axis] -= remove_from_end;

ov::Coordinate new_shape_begin{0, 0, 0, 0};
ov::Coordinate new_shape_end{shape};

auto trimmed_tensor = ov::Tensor(old_tensor, new_shape_begin, new_shape_end);

ov::Tensor new_tensor(old_tensor.get_element_type(), shape);
trimmed_tensor.copy_to(new_tensor);

state.set_state(new_tensor);
}
}

} // namespace utils
} // namespace genai
} // namespace ov
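
A hypothetical caller of the two new helpers, mirroring what the pipeline does above; the wrapper function and its arguments are placeholders and are not part of the commit:

#include <cstdint>
#include <memory>
#include <optional>
#include "openvino/openvino.hpp"
#include "utils.hpp"  // the header extended by this commit

// Drop the last `tokens_to_drop` positions from the KV cache of a stateful LLM,
// e.g. after detecting that the tail of the recorded history is no longer trusted.
void drop_untrusted_tail(ov::InferRequest& request,
                         const std::shared_ptr<const ov::Model>& model,
                         uint64_t tokens_to_drop) {
    // KV tensors are usually laid out as [batch, num_kv_heads, seq_len, head_size],
    // but the sequence-length axis is read from the model instead of hard-coding 2.
    const size_t seq_len_axis = ov::genai::utils::get_seq_len_axis(model);
    ov::genai::utils::trim_kv_cache(request, tokens_to_drop, seq_len_axis,
                                    std::nullopt /* no LoRA adapter state to skip */);
}
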
11 changes: 11 additions & 0 deletions src/cpp/src/utils.hpp
@@ -22,6 +22,11 @@ constexpr bool is_container<T,
std::void_t<decltype(std::declval<T>().begin()),
decltype(std::declval<T>().end())>> = true;

enum class GenerationChatInputsType {
UNDEF = 0, // Default value, type of inputs is not defined
STRING = 1, // Type of inputs is StringInputs
ENCODED_INPUTS = 2, // Type of inputs is EncodedInputs
};

Tensor init_attention_mask(const Tensor& position_ids);

@@ -93,6 +98,12 @@ ov::Core singleton_core();
template <typename T>
void read_rt_info(std::shared_ptr<ov::Model>& model, const char* name, T& value);

size_t get_first_history_difference(const ov::Tensor& encoded_history, const std::vector<int64_t> tokenized_history, std::set<int64_t> stop_tokens);

size_t get_seq_len_axis(std::shared_ptr<const ov::Model> model);

void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller);

} // namespace utils
} // namespace genai
} // namespace ov
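
The new GenerationChatInputsType enum acts as a guard against mixing string and encoded inputs within one chat. A simplified sketch of that pattern; the struct below is illustrative and not part of the commit:

#include "openvino/core/except.hpp"
#include "utils.hpp"  // brings in ov::genai::utils::GenerationChatInputsType

// Remember the first input type used in a chat session and reject later switches,
// mirroring the OPENVINO_ASSERT checks in StatefulLLMPipeline::generate() above.
struct ChatInputTypeGuard {
    ov::genai::utils::GenerationChatInputsType seen =
        ov::genai::utils::GenerationChatInputsType::UNDEF;

    void check(ov::genai::utils::GenerationChatInputsType current) {
        if (seen == ov::genai::utils::GenerationChatInputsType::UNDEF)
            seen = current;  // the first call in the chat fixes the input type
        OPENVINO_ASSERT(seen == current,
                        "Chat doesn't support switching between input types. Restart the chat.");
    }
};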