diff --git a/src/cpp/src/model_runner.hpp b/src/cpp/src/model_runner.hpp
index f9a052d978..e226de29a1 100644
--- a/src/cpp/src/model_runner.hpp
+++ b/src/cpp/src/model_runner.hpp
@@ -122,6 +122,8 @@ class ModelRunner {
             matmul_gathering_is_required = true;
         } catch (const ov::Exception&) {}
 
+        size_t gathering_current_index = 0;
+
         for (size_t i = 0; i < num_sequence_groups; ++i) {
             size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
             SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id];
@@ -130,42 +132,31 @@ class ModelRunner {
             size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
             size_t group_position_id = sequence_group->get_num_processed_tokens();
             size_t prompt_len = sequence_group->get_prompt_len();
+
+            // Next variables are only for sliced matmul case
             size_t actual_seq_len = 0;
             bool echo_output = sequence_group->get_sampling_parameters().echo;
-
-            if (matmul_gathering_is_required){
-                if (sequence_group->requires_sampling() || echo_output) {
-                    size_t tokens_to_sample_per_sequence = 1 + sequence_group->get_num_tokens_to_validate();
-                    size_t tokens_to_sample_per_group = tokens_to_sample_per_sequence * num_running_sequences;
-                    actual_seq_len = std::min(tokens_to_sample_per_group, num_scheduled_tokens);
-
-                    size_t initial_size = gather_indice_values.size();
-                    gather_indice_values.resize(initial_size + tokens_to_sample_per_group);
-                    auto it = gather_indice_values.begin() + initial_size;
-
-                    for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id, it += tokens_to_sample_per_sequence, total_tokens_to_schedule += num_scheduled_tokens) {
-                        std::iota(it, it + tokens_to_sample_per_sequence, total_tokens_to_schedule + num_scheduled_tokens - 1);
-                    }
-                } else {
-                    total_tokens_to_schedule += num_scheduled_tokens * num_running_sequences;
-                    actual_seq_len = num_scheduled_tokens;
-                }
-            } else {
-                actual_seq_len = num_scheduled_tokens;
-            }
-
-            sequence_group->set_seq_len_to_sample(actual_seq_len);
-
+            bool sampling_is_required = sequence_group->requires_sampling();
+            size_t tokens_to_sample_per_sequence = 1 + sequence_group->get_num_tokens_to_validate();
             for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id) {
                 Sequence::CPtr sequence = running_sequences[seq_id];
 
-                for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id) {
+                for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) {
                     // compute token for current sequence
-                    input_ids_data[token_id] = position_id < sequence_group->get_prompt_len() ?
+                    input_ids_data[token_id] = position_id < prompt_len ?
                         sequence_group->get_prompt_ids()[position_id] :
-                        sequence->get_generated_ids()[position_id - sequence_group->get_prompt_len()];
+                        sequence->get_generated_ids()[position_id - prompt_len];
                     position_ids_data[token_id] = position_id;
+
+                    if (matmul_gathering_is_required && sampling_is_required) {
+                        // Gather a logit either for echo mode (all tokens), or for a token that is
+                        // past the last prompt position AND within the tail of scheduled tokens to
+                        // sample. Parentheses make the ||/&& grouping explicit (-Wparentheses), and
+                        // the comparisons use addition instead of size_t subtraction so that
+                        // `tokens_to_sample_per_sequence > num_scheduled_tokens` cannot underflow
+                        // (the old code clamped with std::min for the same reason).
+                        if (echo_output ||
+                            (group_position_id + token_id + 1 >= prompt_len &&
+                             group_position_id + token_id + tokens_to_sample_per_sequence >= num_scheduled_tokens)) {
+                            gather_indice_values.push_back(gathering_current_index);
+                            actual_seq_len++;
+                        }
+                    }
                 }
 
                 size_t expected_kv_cache_size = sequence_group->get_num_processed_tokens() - sequence_group->get_num_evicted_tokens();
@@ -183,6 +174,7 @@ class ModelRunner {
                 subsequence_begins_data += 1;
                 block_indices_begins_data += 1;
             }
+            sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? std::min(actual_seq_len, num_scheduled_tokens) : num_scheduled_tokens);
         }
 
         // typical LLM parameters