Skip to content

Commit

Permalink
Rework gather input calculating
Browse files Browse the repository at this point in the history
  • Loading branch information
olpipi committed Dec 18, 2024
1 parent 2931677 commit 3a6066a
Showing 1 changed file with 20 additions and 29 deletions.
49 changes: 20 additions & 29 deletions src/cpp/src/model_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,14 @@ class ModelRunner {
block_indices_begins_data[0] = 0;

bool matmul_gathering_is_required = false;
int64_t total_tokens_to_schedule = 0;
size_t gathering_current_index = 0;
std::vector<int64_t> gather_indice_values;
try {
std::ignore = m_request.get_tensor("sampled_tokens_indices");
matmul_gathering_is_required = true;
} catch (const ov::Exception&) {}


for (size_t i = 0; i < num_sequence_groups; ++i) {
size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id];
Expand All @@ -130,42 +131,31 @@ class ModelRunner {
size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
size_t group_position_id = sequence_group->get_num_processed_tokens();
size_t prompt_len = sequence_group->get_prompt_len();
size_t actual_seq_len = 0;
bool echo_output = sequence_group->get_sampling_parameters().echo;

if (matmul_gathering_is_required){
if (sequence_group->requires_sampling() || echo_output) {
size_t tokens_to_sample_per_sequence = 1 + sequence_group->get_num_tokens_to_validate();
size_t tokens_to_sample_per_group = tokens_to_sample_per_sequence * num_running_sequences;
actual_seq_len = std::min(tokens_to_sample_per_group, num_scheduled_tokens);

size_t initial_size = gather_indice_values.size();
gather_indice_values.resize(initial_size + tokens_to_sample_per_group);
auto it = gather_indice_values.begin() + initial_size;

for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id, it += tokens_to_sample_per_sequence, total_tokens_to_schedule += num_scheduled_tokens) {
std::iota(it, it + tokens_to_sample_per_sequence, total_tokens_to_schedule + num_scheduled_tokens - 1);
}
} else {
total_tokens_to_schedule += num_scheduled_tokens * num_running_sequences;
actual_seq_len = num_scheduled_tokens;
}
} else {
actual_seq_len = num_scheduled_tokens;
}

sequence_group->set_seq_len_to_sample(actual_seq_len);

// Next variables are only for sliced matmul case
size_t actual_seq_len = 0;
const bool echo_output = sequence_group->get_sampling_parameters().echo;
const bool sampling_is_required = sequence_group->requires_sampling();
const size_t tokens_to_sample_per_sequence = 1 + sequence_group->get_num_tokens_to_validate();

for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id) {
Sequence::CPtr sequence = running_sequences[seq_id];
for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id) {
for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) {
// compute token for current sequence
input_ids_data[token_id] = position_id < sequence_group->get_prompt_len() ?
input_ids_data[token_id] = position_id < prompt_len ?
sequence_group->get_prompt_ids()[position_id] :
sequence->get_generated_ids()[position_id - sequence_group->get_prompt_len()];
sequence->get_generated_ids()[position_id - prompt_len];

position_ids_data[token_id] = position_id;

if (matmul_gathering_is_required && sampling_is_required) {
if (echo_output ||
group_position_id + token_id >= prompt_len - 1 &&
group_position_id + token_id >= num_scheduled_tokens - tokens_to_sample_per_sequence) {
gather_indice_values.push_back(gathering_current_index);
actual_seq_len++;
}
}
}

size_t expected_kv_cache_size = sequence_group->get_num_processed_tokens() - sequence_group->get_num_evicted_tokens();
Expand All @@ -183,6 +173,7 @@ class ModelRunner {
subsequence_begins_data += 1;
block_indices_begins_data += 1;
}
sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? std::min(actual_seq_len, num_scheduled_tokens) : num_scheduled_tokens);
}

// typical LLM parameters
Expand Down

0 comments on commit 3a6066a

Please sign in to comment.