From 678580d9a1a2729b333baf265a9ca2b5aaa437cf Mon Sep 17 00:00:00 2001
From: Oleg Pipikin
Date: Tue, 26 Nov 2024 20:40:14 +0000
Subject: [PATCH] Add slice before matmul transformation for CB scenario

---
 src/cpp/src/continuous_batching_impl.cpp           |  1 +
 src/cpp/src/model_runner.hpp                        | 31 +++++++++++++++++--
 src/cpp/src/sampler.cpp                             |  4 +--
 src/cpp/src/sequence_group.hpp                      | 10 ++++++
 .../speculative_decoding_impl.cpp                   |  2 ++
 src/cpp/src/utils.cpp                               | 31 +++++++++++--------
 src/cpp/src/utils.hpp                               |  2 ++
 .../utils/paged_attention_transformations.cpp       | 18 +++++++++++
 .../utils/paged_attention_transformations.hpp       |  2 ++
 9 files changed, 83 insertions(+), 18 deletions(-)

diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp
index 73bf4ec083..e6ef03940f 100644
--- a/src/cpp/src/continuous_batching_impl.cpp
+++ b/src/cpp/src/continuous_batching_impl.cpp
@@ -31,6 +31,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
     bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction;
     utils::apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control);
+    utils::apply_gather_before_matmul_transformation(model);
 
     init(model, scheduler_config, compile_properties, device_config, core);
 }
diff --git a/src/cpp/src/model_runner.hpp b/src/cpp/src/model_runner.hpp
index 1b96cdc505..3a161e4239 100644
--- a/src/cpp/src/model_runner.hpp
+++ b/src/cpp/src/model_runner.hpp
@@ -114,13 +114,23 @@ class ModelRunner {
         subsequence_begins_data[0] = 0;
         block_indices_begins_data[0] = 0;
 
+        bool matmul_gathering_is_required = false;
+        int64_t gathering_current_index = 0;
+        std::vector<int64_t> gather_indice_values;
+        try {
+            std::ignore = m_request.get_tensor("gather_indices");
+            matmul_gathering_is_required = true;
+        } catch (const ov::Exception&) {}
+
         for (size_t i = 0; i < num_sequence_groups; ++i) {
             size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
-            SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id];
-            std::vector<Sequence::CPtr> running_sequences = sequence_group->get_running_sequences();
+            SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id];
+            std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
             size_t num_running_sequences = running_sequences.size();
             size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
             size_t group_position_id = sequence_group->get_num_processed_tokens();
+            auto prompt_len = sequence_group->get_prompt_len();
+            size_t tokens_num_to_sample = 0;
 
             // spec: In case of multiple input tokens for current sequence (prompt_len > 1),
             // context_len corresponds to first token within subgroup of scheduled tokens
@@ -129,12 +139,20 @@ class ModelRunner {
             for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id) {
                 Sequence::CPtr sequence = running_sequences[seq_id];
 
-                for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id) {
+                for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) {
                     // compute token for current sequence
                     input_ids_data[token_id] = position_id < sequence_group->get_prompt_len() ?
                        sequence_group->get_prompt_ids()[position_id] : sequence->get_generated_ids()[position_id - sequence_group->get_prompt_len()];
+                    if (matmul_gathering_is_required) {
+                        if (group_position_id + token_id >= prompt_len - 1) {
+                            gather_indice_values.push_back(gathering_current_index);
+                            tokens_num_to_sample++;
+                        }
+                    } else {
+                        tokens_num_to_sample++;
+                    }
 
                     position_ids_data[token_id] = position_id;
                 }
@@ -153,6 +171,7 @@ class ModelRunner {
                 subsequence_begins_data += 1;
                 block_indices_begins_data += 1;
             }
+            sequence_group->set_seq_len_to_sample(tokens_num_to_sample);
         }
 
         // typical LLM parameters
@@ -168,6 +187,12 @@ class ModelRunner {
         m_request.set_tensor("block_indices_begins", block_indices_begins);
         m_request.set_tensor("max_context_len", max_context_len);
 
+        if (matmul_gathering_is_required) {
+            ov::Tensor gather_indices(ov::element::i64, {gather_indice_values.size()});
+            std::memcpy(gather_indices.data(), gather_indice_values.data(), gather_indice_values.size() * sizeof(int64_t));
+            m_request.set_tensor("gather_indices", gather_indices);
+        }
+
         // print_tensor("input_ids", input_ids);
         // print_tensor("position_ids", position_ids);
diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index 3febadf112..c0c90bccaa 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -755,8 +755,8 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
             continue;
 
         size_t num_running_sequences = sequence_group->num_running_seqs();
-        size_t actual_seq_len = sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
-        size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
+        size_t actual_seq_len = sequence_group->get_seq_len_to_sample(); // points to a token which needs to be sampled
+        size_t padded_amount_of_processed_tokens = std::max(sequence_group->get_num_scheduled_tokens(), batch_seq_len);
         const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();
 
         const auto request_id = sequence_group->get_request_id();
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index c5be82f0f2..64f4ed3828 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -220,6 +220,8 @@ class SequenceGroup {
     size_t m_num_validation_tokens = 0;
     // flag to enable/disable token generation, e.g. in speculative decoding scenario
     bool m_is_gen_paused = false;
+    // seq len to sample at current iteration
+    size_t m_seq_len_to_sample = 0;
 
     SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
@@ -390,6 +392,14 @@ class SequenceGroup {
         return m_num_processed_tokens;
     }
 
+    size_t get_seq_len_to_sample() const {
+        return m_seq_len_to_sample;
+    }
+
+    void set_seq_len_to_sample(size_t len) {
+        m_seq_len_to_sample = len;
+    }
+
     /**
      * Registers within the sequence group that a given amount of tokens
      * has been evicted from the underlying KV cache.
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
index 0f43555a5f..abbabd6719 100644
--- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -42,6 +42,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(
     utils::apply_paged_attention_transformations(main_model, main_scheduler_config.use_cache_eviction);
     utils::apply_paged_attention_transformations(draft_model, main_scheduler_config.use_cache_eviction);
+    utils::apply_gather_before_matmul_transformation(main_model);
+    utils::apply_gather_before_matmul_transformation(draft_model);
 
     std::string draft_device = draft_model_desc.device.empty() ? main_device : draft_model_desc.device;
diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp
index 50c2e0c49e..eaabd35897 100644
--- a/src/cpp/src/utils.cpp
+++ b/src/cpp/src/utils.cpp
@@ -4,6 +4,7 @@
 #include "utils.hpp"
 
 #include
+#include
 
 #include "openvino/op/add.hpp"
 #include "openvino/op/divide.hpp"
@@ -234,23 +235,27 @@ ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::Token
     return {new_input_ids, new_attention_mask};
 }
 
-void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model) {
-    auto last_node = model->output(0).get_node()->input_value(0).get_node();
-    ov::Node* matmul = dynamic_cast<ov::op::v0::MatMul*>(last_node);
-    if (matmul) {
-        // we have found matmul, do nothing
-    } else if(auto add = dynamic_cast<ov::op::v1::Add*>(last_node)) {
-        matmul = dynamic_cast<ov::op::v0::MatMul*>(add->input_value(0).get_node());
-    } else if (auto transpose = dynamic_cast<ov::op::v1::Transpose*>(last_node)) {
-        matmul = dynamic_cast<ov::op::v0::MatMul*>(transpose->input_value(0).get_node());
-    } else if (auto multiply = dynamic_cast<ov::op::v1::Multiply*>(last_node)) {
-        if (auto tanh = dynamic_cast<ov::op::v0::Tanh*>(multiply->input_value(0).get_node())) {
-            if (auto divide = dynamic_cast<ov::op::v1::Divide*>(tanh->input_value(0).get_node())) {
-                matmul = dynamic_cast<ov::op::v0::MatMul*>(divide->input_value(0).get_node());
+std::shared_ptr<ov::op::v0::MatMul> find_llm_matmul(const std::shared_ptr<ov::Model>& model) {
+    auto last_node = model->output(0).get_node()->input_value(0).get_node_shared_ptr();
+    std::shared_ptr<ov::op::v0::MatMul> matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(last_node);
+    if (!matmul) {
+        if(auto add = std::dynamic_pointer_cast<ov::op::v1::Add>(last_node)) {
+            matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(add->input_value(0).get_node_shared_ptr());
+        } else if (auto transpose = std::dynamic_pointer_cast<ov::op::v1::Transpose>(last_node)) {
+            matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(transpose->input_value(0).get_node_shared_ptr());
+        } else if (auto multiply = std::dynamic_pointer_cast<ov::op::v1::Multiply>(last_node)) {
+            if (auto tanh = std::dynamic_pointer_cast<ov::op::v0::Tanh>(multiply->input_value(0).get_node_shared_ptr())) {
+                if (auto divide = std::dynamic_pointer_cast<ov::op::v1::Divide>(tanh->input_value(0).get_node_shared_ptr())) {
+                    matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(divide->input_value(0).get_node_shared_ptr());
+                }
             }
         }
     }
+    return matmul;
+}
 
+void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model) {
+    auto matmul = find_llm_matmul(model);
     if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) {
         auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
         auto stop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-2});
diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp
index 3487fccb81..9bbb5f3ac9 100644
--- a/src/cpp/src/utils.hpp
+++ b/src/cpp/src/utils.hpp
@@ -82,6 +82,8 @@ std::pair<ov::AnyMap, ov::AnyMap> split_core_complile_config(const ov::AnyMap& p
 ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& minuend, const
                                                           ov::genai::TokenizedInputs& subtrahend);
 
+std::shared_ptr<ov::op::v0::MatMul> find_llm_matmul(const std::shared_ptr<ov::Model>& model);
+
 void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model);
 
 ov::Core singleton_core();
diff --git a/src/cpp/src/utils/paged_attention_transformations.cpp b/src/cpp/src/utils/paged_attention_transformations.cpp
index 53690f770c..1bdc6615e4 100644
--- a/src/cpp/src/utils/paged_attention_transformations.cpp
+++ b/src/cpp/src/utils/paged_attention_transformations.cpp
@@ -5,6 +5,11 @@
 
 #include "openvino/pass/manager.hpp"
 #include "openvino/pass/sdpa_to_paged_attention.hpp"
+#include "utils.hpp"
+
+#include "openvino/op/constant.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/parameter.hpp"
 
 namespace ov {
 namespace genai {
@@ -78,6 +83,19 @@ void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, Dev
     set_kv_cache_type_and_shape(model, device_config);
 }
 
+void apply_gather_before_matmul_transformation(std::shared_ptr<ov::Model> model) {
+    auto matmul = ov::genai::utils::find_llm_matmul(model);
+    if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) {
+        auto indices = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{-1});
+        indices->set_friendly_name("gather_indices");
+        indices->output(0).get_tensor().set_names({"gather_indices"});
+        auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{0});
+        auto gather = std::make_shared<ov::op::v8::Gather>(matmul->input_value(0), indices, axis);
+        matmul->input(0).replace_source_output(gather);
+        model->add_parameters({indices});
+    }
+}
+
 }  // namespace utils
 }  // namespace genai
 }  // namespace ov
\ No newline at end of file
diff --git a/src/cpp/src/utils/paged_attention_transformations.hpp b/src/cpp/src/utils/paged_attention_transformations.hpp
index 3bc423d7bc..3fe9116cc6 100644
--- a/src/cpp/src/utils/paged_attention_transformations.hpp
+++ b/src/cpp/src/utils/paged_attention_transformations.hpp
@@ -23,6 +23,8 @@ void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, Dev
 void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, bool per_layer_cache_control = false);
 
+void apply_gather_before_matmul_transformation(std::shared_ptr<ov::Model> model);
+
 size_t get_kv_cache_size(const std::shared_ptr<ov::Model> model);
 
 void set_kv_cache_type_and_shape(std::shared_ptr<ov::Model> model, DeviceConfig& device_config);
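
Reviewer note, not part of the patch: the standalone sketch below illustrates the index-collection rule that ModelRunner applies above when the compiled model exposes the new "gather_indices" input. Scheduled tokens are laid out contiguously across all sequence groups; only the last prompt token and every generated token produce logits the sampler will read, so only those flattened positions are gathered before the LM-head MatMul. The ScheduledGroup struct and collect_gather_indices helper are hypothetical names used only for this example, and it simplifies to one running sequence per group.

// Standalone sketch (plain C++17) of the gather-index selection used in model_runner.hpp.
#include <cstdint>
#include <iostream>
#include <vector>

struct ScheduledGroup {
    size_t prompt_len;            // prompt length of the sequence group
    size_t num_processed_tokens;  // tokens already in the KV cache (group_position_id)
    size_t num_scheduled_tokens;  // tokens scheduled for this iteration
};

std::vector<int64_t> collect_gather_indices(const std::vector<ScheduledGroup>& groups) {
    std::vector<int64_t> gather_indices;
    int64_t flat_index = 0;  // position in the flattened input_ids batch
    for (const auto& g : groups) {
        for (size_t token_id = 0; token_id < g.num_scheduled_tokens; ++token_id, ++flat_index) {
            // Same condition as in the patch: group_position_id + token_id >= prompt_len - 1,
            // i.e. keep the last prompt token and every generated token.
            if (g.num_processed_tokens + token_id >= g.prompt_len - 1) {
                gather_indices.push_back(flat_index);
            }
        }
    }
    return gather_indices;
}

int main() {
    // An 8-token prompt scheduled in full, plus a second sequence already in the decode phase.
    std::vector<ScheduledGroup> groups = {{8, 0, 8}, {4, 10, 1}};
    for (int64_t idx : collect_gather_indices(groups))
        std::cout << idx << ' ';  // prints: 7 8
    std::cout << '\n';
    return 0;
}

With gathering enabled, the LM-head MatMul produces logits only for these positions, which is why the sampler now reads get_seq_len_to_sample() instead of the full number of scheduled tokens.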