From c82d6d16f7d38db0a64e0da2327d49c91cc08734 Mon Sep 17 00:00:00 2001
From: Oleg Pipikin
Date: Thu, 12 Dec 2024 11:46:12 +0000
Subject: [PATCH] Apply comments

---
 src/cpp/src/model_runner.hpp   | 39 +++++++++++++++++++++++-----------
 src/cpp/src/sampler.cpp        |  3 +--
 src/cpp/src/sequence_group.hpp | 10 ++-------
 3 files changed, 30 insertions(+), 22 deletions(-)

diff --git a/src/cpp/src/model_runner.hpp b/src/cpp/src/model_runner.hpp
index e6c8e03e91..2a05d49e9a 100644
--- a/src/cpp/src/model_runner.hpp
+++ b/src/cpp/src/model_runner.hpp
@@ -115,7 +115,7 @@ class ModelRunner {
         block_indices_begins_data[0] = 0;
 
         bool matmul_gathering_is_required = false;
-        int64_t gathering_current_index = 0;
+        int64_t total_tokens_to_schedule = 0;
         std::vector<int64_t> gather_indice_values;
         try {
             std::ignore = m_request.get_tensor("sampled_tokens_indices");
@@ -130,25 +130,41 @@ class ModelRunner {
             size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
             size_t group_position_id = sequence_group->get_num_processed_tokens();
             size_t prompt_len = sequence_group->get_prompt_len();
-            size_t seq_len_after_gather = 0;
+            size_t actual_seq_len = 0;
             bool echo_output = sequence_group->get_sampling_parameters().echo;
-            bool sampling_is_required = sequence_group->requires_sampling();
+
+            if (matmul_gathering_is_required){
+                if (sequence_group->requires_sampling() || echo_output) {
+                    size_t tokens_to_sample_per_sequence = 1 + sequence_group->get_num_tokens_to_validate();
+                    size_t tokens_to_sample_per_group = tokens_to_sample_per_sequence * num_running_sequences;
+                    actual_seq_len = tokens_to_sample_per_group;
+
+                    size_t initial_size = gather_indice_values.size();
+                    gather_indice_values.resize(initial_size + tokens_to_sample_per_group);
+                    auto it = gather_indice_values.begin() + initial_size;
+
+                    for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id, it += tokens_to_sample_per_sequence, total_tokens_to_schedule += num_scheduled_tokens) {
+                        std::iota(it, it + tokens_to_sample_per_sequence, total_tokens_to_schedule);
+                    }
+                } else {
+                    total_tokens_to_schedule += num_scheduled_tokens * num_running_sequences;
+                    actual_seq_len = num_scheduled_tokens;
+                }
+            } else {
+                actual_seq_len = num_scheduled_tokens;
+            }
+
+            sequence_group->set_seq_len_to_sample(actual_seq_len);
+
 
             for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id) {
                 Sequence::CPtr sequence = running_sequences[seq_id];
-
-                for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) {
+                for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id) {
                     // compute token for current sequence
                     input_ids_data[token_id] = position_id < sequence_group->get_prompt_len() ?
                         sequence_group->get_prompt_ids()[position_id] :
                         sequence->get_generated_ids()[position_id - sequence_group->get_prompt_len()];
 
-                    if (matmul_gathering_is_required && sampling_is_required) {
-                        if (group_position_id + token_id >= prompt_len - 1 || echo_output) {
-                            gather_indice_values.push_back(gathering_current_index);
-                            seq_len_after_gather++;
-                        }
-                    }
                     position_ids_data[token_id] = position_id;
                 }
 
@@ -167,7 +183,6 @@ class ModelRunner {
                 subsequence_begins_data += 1;
                 block_indices_begins_data += 1;
             }
-            sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? std::min(seq_len_after_gather, num_scheduled_tokens) : num_scheduled_tokens);
         }
 
         // typical LLM parameters
diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index b9ffaa77f7..6b3578d90c 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -756,8 +756,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
             continue;
 
         size_t num_running_sequences = sequence_group->num_running_seqs();
-        size_t actual_seq_len = sequence_group->is_matmul_sliced() ?
-            sequence_group->get_seq_len_to_sample() : sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
+        size_t actual_seq_len = sequence_group->get_seq_len_to_sample();
         const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();
         const auto request_id = sequence_group->get_request_id();
 
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index 8ec6306d4a..194a04e2dd 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -222,8 +222,6 @@ class SequenceGroup {
     bool m_is_gen_paused = false;
     // seq len to sample at current iteration
    size_t m_seq_len_to_sample = 0;
-    // flag shows wheather last matmul was sliced
-    bool m_sliced_matmul = false;
 
     SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
         : m_request_id(request_id),
@@ -399,11 +397,6 @@ class SequenceGroup {
 
     void set_seq_len_to_sample(size_t len) {
         m_seq_len_to_sample = len;
-        m_sliced_matmul = true;
-    }
-
-    bool is_matmul_sliced() const {
-        return m_sliced_matmul;
     }
 
     /**
@@ -450,13 +443,14 @@ class SequenceGroup {
 
     void schedule_tokens(size_t num_tokens) {
         m_num_scheduled_tokens = num_tokens;
+        // Unless otherwise specified, the sampler will process all scheduled tokens.
+        m_seq_len_to_sample = num_tokens;
     }
 
     void clear_scheduled_tokens() {
         m_num_scheduled_tokens = 0;
         m_num_validation_tokens = 0;
         m_seq_len_to_sample = 0;
-        m_sliced_matmul = false;
     }
 
     bool is_scheduled() const {