diff --git a/src/cpp/src/model_runner.hpp b/src/cpp/src/model_runner.hpp
index 699e5bfccc..3462f2566a 100644
--- a/src/cpp/src/model_runner.hpp
+++ b/src/cpp/src/model_runner.hpp
@@ -129,8 +129,8 @@ class ModelRunner {
             size_t num_running_sequences = running_sequences.size();
             size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
             size_t group_position_id = sequence_group->get_num_processed_tokens();
-            auto prompt_len = sequence_group->get_prompt_len();
-            size_t tokens_num_to_sample = 0;
+            size_t prompt_len = sequence_group->get_prompt_len();
+            size_t seq_len_after_gather = 0;
 
             // spec: In case of multiple input tokens for current sequence (prompt_len > 1),
             // context_len corresponds to first token within subgroup of scheduled tokens
@@ -148,7 +148,7 @@ class ModelRunner {
                     if (matmul_gathering_is_required) {
                         if (group_position_id + token_id >= prompt_len - 1) {
                             gather_indice_values.push_back(gathering_current_index);
-                            tokens_num_to_sample++;
+                            seq_len_after_gather++;
                         }
                     }
                     position_ids_data[token_id] = position_id;
@@ -169,7 +169,7 @@ class ModelRunner {
                 subsequence_begins_data += 1;
                 block_indices_begins_data += 1;
             }
-            sequence_group->set_seq_len_to_sample(tokens_num_to_sample);
+            sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? std::min(seq_len_after_gather, num_scheduled_tokens) : num_scheduled_tokens);
         }
 
         // typical LLM parameters
diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index a6cbab44f5..526dc6ace6 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -755,9 +755,8 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
             continue;
 
         size_t num_running_sequences = sequence_group->num_running_seqs();
-        size_t actual_seq_len = sequence_group->is_matmul_sliced() ?
-            sequence_group->get_seq_len_to_sample() : sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
-        size_t padded_amount_of_processed_tokens = std::max(sequence_group->get_num_scheduled_tokens(), batch_seq_len);
+        size_t actual_seq_len = sequence_group->get_seq_len_to_sample();
+        size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
 
         const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();
         const auto request_id = sequence_group->get_request_id();
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index 29e53da4b2..9670b72bdc 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -222,8 +222,6 @@ class SequenceGroup {
     bool m_is_gen_paused = false;
     // seq len to sample at current iteration
    size_t m_seq_len_to_sample = 0;
-    // flag shows wheather last matmul was sliced
-    bool m_sliced_matmul = false;
 
     SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
         : m_request_id(request_id),
@@ -399,11 +397,6 @@ class SequenceGroup {
 
     void set_seq_len_to_sample(size_t len) {
         m_seq_len_to_sample = len;
-        m_sliced_matmul = true;
-    }
-
-    bool is_matmul_sliced() const {
-        return m_sliced_matmul;
     }
 
     /**
@@ -456,7 +449,6 @@ class SequenceGroup {
         m_num_scheduled_tokens = 0;
         m_num_validation_tokens = 0;
         m_seq_len_to_sample = 0;
-        m_sliced_matmul = false;
     }
 
     bool is_scheduled() const {
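Note (not part of the patch): with this change, `set_seq_len_to_sample` always receives the number of logit rows that will actually reach the sampler. When matmul gathering slices the logits, only scheduled tokens whose global position is at least `prompt_len - 1` keep a logits row; without gathering, every scheduled token does. The sampler can then read `get_seq_len_to_sample()` unconditionally instead of branching on the removed `is_matmul_sliced()` flag. Below is a minimal, self-contained sketch of that counting logic; the names `ScheduledChunk` and `compute_seq_len_to_sample` are illustrative and do not exist in the codebase.

```cpp
// Sketch only: mirrors the counting done in the patched ModelRunner loop,
// not the actual ModelRunner code.
#include <algorithm>
#include <cstddef>
#include <iostream>

struct ScheduledChunk {
    std::size_t prompt_len;            // total prompt length of the sequence group
    std::size_t num_processed_tokens;  // tokens already run through the model (group_position_id)
    std::size_t num_scheduled_tokens;  // tokens scheduled for this iteration
};

// With matmul gathering enabled, only tokens whose global position is
// >= prompt_len - 1 produce logits that are gathered, so only those positions
// can be sampled. Without gathering, every scheduled token keeps its logits row.
std::size_t compute_seq_len_to_sample(const ScheduledChunk& c, bool matmul_gathering_is_required) {
    if (!matmul_gathering_is_required)
        return c.num_scheduled_tokens;

    std::size_t seq_len_after_gather = 0;
    for (std::size_t token_id = 0; token_id < c.num_scheduled_tokens; ++token_id) {
        if (c.num_processed_tokens + token_id >= c.prompt_len - 1)
            ++seq_len_after_gather;
    }
    return std::min(seq_len_after_gather, c.num_scheduled_tokens);
}

int main() {
    // Chunked prefill of a 10-token prompt in two chunks of 5: the first chunk
    // yields nothing to sample, the second yields exactly one position.
    ScheduledChunk first{10, 0, 5}, second{10, 5, 5};
    std::cout << compute_seq_len_to_sample(first, true)   << "\n"; // 0
    std::cout << compute_seq_len_to_sample(second, true)  << "\n"; // 1
    std::cout << compute_seq_len_to_sample(second, false) << "\n"; // 5
    return 0;
}
```

In the decode phase (or with gathering disabled) the value simply equals the number of scheduled tokens, which is what the `matmul_gathering_is_required ? ... : num_scheduled_tokens` ternary in the patch expresses.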