Skip to content

Commit

Permalink
Add slice before matmut transformation for CB scenario
Browse files Browse the repository at this point in the history
  • Loading branch information
olpipi committed Nov 27, 2024
1 parent 3da2aeb commit 678580d
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 18 deletions.
1 change: 1 addition & 0 deletions src/cpp/src/continuous_batching_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(

bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction;
utils::apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control);
utils::apply_gather_before_matmul_transformation(model);

init(model, scheduler_config, compile_properties, device_config, core);
}
Expand Down
31 changes: 28 additions & 3 deletions src/cpp/src/model_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,23 @@ class ModelRunner {
subsequence_begins_data[0] = 0;
block_indices_begins_data[0] = 0;

bool matmul_gathering_is_required = false;
int64_t gathering_current_index = 0;
std::vector<int64_t> gather_indice_values;
try {
std::ignore = m_request.get_tensor("gather_indices");
matmul_gathering_is_required = true;
} catch (const ov::Exception&) {}

for (size_t i = 0; i < num_sequence_groups; ++i) {
size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id];
std::vector<Sequence::CPtr> running_sequences = sequence_group->get_running_sequences();
SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id];
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
size_t num_running_sequences = running_sequences.size();
size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
size_t group_position_id = sequence_group->get_num_processed_tokens();
auto prompt_len = sequence_group->get_prompt_len();
size_t tokens_num_to_sample = 0;

// spec: In case of multiple input tokens for current sequence (prompt_len > 1),
// context_len corresponds to first token within subgroup of scheduled tokens
Expand All @@ -129,12 +139,20 @@ class ModelRunner {
for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id) {
Sequence::CPtr sequence = running_sequences[seq_id];

for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id) {
for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) {
// compute token for current sequence
input_ids_data[token_id] = position_id < sequence_group->get_prompt_len() ?
sequence_group->get_prompt_ids()[position_id] :
sequence->get_generated_ids()[position_id - sequence_group->get_prompt_len()];

if (matmul_gathering_is_required) {
if (group_position_id + token_id >= prompt_len - 1) {
gather_indice_values.push_back(gathering_current_index);
tokens_num_to_sample++;
}
} else {
tokens_num_to_sample++;
}
position_ids_data[token_id] = position_id;
}

Expand All @@ -153,6 +171,7 @@ class ModelRunner {
subsequence_begins_data += 1;
block_indices_begins_data += 1;
}
sequence_group->set_seq_len_to_sample(tokens_num_to_sample);
}

// typical LLM parameters
Expand All @@ -168,6 +187,12 @@ class ModelRunner {
m_request.set_tensor("block_indices_begins", block_indices_begins);
m_request.set_tensor("max_context_len", max_context_len);

if (matmul_gathering_is_required) {
ov::Tensor gather_indices(ov::element::i64, {gather_indice_values.size()});
std::memcpy(gather_indices.data(), gather_indice_values.data(), gather_indice_values.size() * sizeof(int64_t));
m_request.set_tensor("gather_indices", gather_indices);
}

// print_tensor("input_ids", input_ids);
// print_tensor("position_ids", position_ids);

Expand Down
4 changes: 2 additions & 2 deletions src/cpp/src/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -755,8 +755,8 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
continue;

size_t num_running_sequences = sequence_group->num_running_seqs();
size_t actual_seq_len = sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
size_t actual_seq_len = sequence_group->get_seq_len_to_sample(); // points to a token which needs to be sampled
size_t padded_amount_of_processed_tokens = std::max(sequence_group->get_num_scheduled_tokens(), batch_seq_len);
const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();

const auto request_id = sequence_group->get_request_id();
Expand Down
10 changes: 10 additions & 0 deletions src/cpp/src/sequence_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ class SequenceGroup {
size_t m_num_validation_tokens = 0;
// flag to enable/disable token generation, e.g. in speculative decoding scenario
bool m_is_gen_paused = false;
// seq len to sample at current iteration
size_t m_seq_len_to_sample = 0;


SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
Expand Down Expand Up @@ -390,6 +392,14 @@ class SequenceGroup {
return m_num_processed_tokens;
}

size_t get_seq_len_to_sample() const {
return m_seq_len_to_sample;
}

void set_seq_len_to_sample(size_t len) {
m_seq_len_to_sample = len;
}

/**
* Registers within the sequence group that a given amount of tokens
* has been evicted from the underlying KV cache.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(

utils::apply_paged_attention_transformations(main_model, main_scheduler_config.use_cache_eviction);
utils::apply_paged_attention_transformations(draft_model, main_scheduler_config.use_cache_eviction);
utils::apply_gather_before_matmul_transformation(main_model);
utils::apply_gather_before_matmul_transformation(draft_model);

std::string draft_device = draft_model_desc.device.empty() ? main_device : draft_model_desc.device;

Expand Down
31 changes: 18 additions & 13 deletions src/cpp/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "utils.hpp"

#include <fstream>
#include <memory>

#include "openvino/op/add.hpp"
#include "openvino/op/divide.hpp"
Expand Down Expand Up @@ -234,23 +235,27 @@ ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::Token
return {new_input_ids, new_attention_mask};
}

void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model) {
auto last_node = model->output(0).get_node()->input_value(0).get_node();
ov::Node* matmul = dynamic_cast<ov::op::v0::MatMul*>(last_node);
if (matmul) {
// we have found matmul, do nothing
} else if(auto add = dynamic_cast<ov::op::v1::Add*>(last_node)) {
matmul = dynamic_cast<ov::op::v0::MatMul*>(add->input_value(0).get_node());
} else if (auto transpose = dynamic_cast<ov::op::v1::Transpose*>(last_node)) {
matmul = dynamic_cast<ov::op::v0::MatMul*>(transpose->input_value(0).get_node());
} else if (auto multiply = dynamic_cast<ov::op::v1::Multiply*>(last_node)) {
if (auto tanh = dynamic_cast<ov::op::v0::Tanh*>(multiply->input_value(0).get_node())) {
if (auto divide = dynamic_cast<ov::op::v1::Divide*>(tanh->input_value(0).get_node())) {
matmul = dynamic_cast<ov::op::v0::MatMul*>(divide->input_value(0).get_node());
std::shared_ptr<ov::Node> find_llm_matmul(const std::shared_ptr<ov::Model>& model) {
auto last_node = model->output(0).get_node()->input_value(0).get_node_shared_ptr();
std::shared_ptr<ov::Node> matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(last_node);
if (!matmul) {
if(auto add = std::dynamic_pointer_cast<ov::op::v1::Add>(last_node)) {
matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(add->input_value(0).get_node_shared_ptr());
} else if (auto transpose = std::dynamic_pointer_cast<ov::op::v1::Transpose>(last_node)) {
matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(transpose->input_value(0).get_node_shared_ptr());
} else if (auto multiply = std::dynamic_pointer_cast<ov::op::v1::Multiply>(last_node)) {
if (auto tanh = std::dynamic_pointer_cast<ov::op::v0::Tanh>(multiply->input_value(0).get_node_shared_ptr())) {
if (auto divide = std::dynamic_pointer_cast<ov::op::v1::Divide>(tanh->input_value(0).get_node_shared_ptr())) {
matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(divide->input_value(0).get_node_shared_ptr());
}
}
}
}
return matmul;
}

void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model) {
auto matmul = find_llm_matmul(model);
if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) {
auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
auto stop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-2});
Expand Down
2 changes: 2 additions & 0 deletions src/cpp/src/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ std::pair<ov::AnyMap, ov::AnyMap> split_core_complile_config(const ov::AnyMap& p

ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& minuend, const ov::genai::TokenizedInputs& subtrahend);

std::shared_ptr<ov::Node> find_llm_matmul(const std::shared_ptr<ov::Model>& model);

void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model);

ov::Core singleton_core();
Expand Down
18 changes: 18 additions & 0 deletions src/cpp/src/utils/paged_attention_transformations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@

#include "openvino/pass/manager.hpp"
#include "openvino/pass/sdpa_to_paged_attention.hpp"
#include "utils.hpp"

#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/parameter.hpp"

namespace ov {
namespace genai {
Expand Down Expand Up @@ -78,6 +83,19 @@ void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, Dev
set_kv_cache_type_and_shape(model, device_config);
}

void apply_gather_before_matmul_transformation(std::shared_ptr<ov::Model> model) {
auto matmul = ov::genai::utils::find_llm_matmul(model);
if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) {
auto indices = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{-1});
indices->set_friendly_name("gather_indices");
indices->output(0).get_tensor().set_names({"gather_indices"});
auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{0});
auto gather = std::make_shared<ov::op::v8::Gather>(matmul->input_value(0), indices, axis);
matmul->input(0).replace_source_output(gather);
model->add_parameters({indices});
}
}

} // namespace utils
} // namespace genai
} // namespace ov
2 changes: 2 additions & 0 deletions src/cpp/src/utils/paged_attention_transformations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, Dev

void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, bool per_layer_cache_control = false);

void apply_gather_before_matmul_transformation(std::shared_ptr<ov::Model> model);

size_t get_kv_cache_size(const std::shared_ptr<ov::Model> model);

void set_kv_cache_type_and_shape(std::shared_ptr<ov::Model> model, DeviceConfig& device_config);
Expand Down

0 comments on commit 678580d

Please sign in to comment.