Commit

Apply comments
olpipi committed Dec 12, 2024
1 parent f4f28c2 commit c82d6d1
Showing 3 changed files with 30 additions and 22 deletions.
39 changes: 27 additions & 12 deletions src/cpp/src/model_runner.hpp
@@ -115,7 +115,7 @@ class ModelRunner {
block_indices_begins_data[0] = 0;

bool matmul_gathering_is_required = false;
int64_t gathering_current_index = 0;
int64_t total_tokens_to_schedule = 0;
std::vector<int64_t> gather_indice_values;
try {
std::ignore = m_request.get_tensor("sampled_tokens_indices");
@@ -130,25 +130,41 @@ class ModelRunner {
size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
size_t group_position_id = sequence_group->get_num_processed_tokens();
size_t prompt_len = sequence_group->get_prompt_len();
size_t seq_len_after_gather = 0;
size_t actual_seq_len = 0;
bool echo_output = sequence_group->get_sampling_parameters().echo;
bool sampling_is_required = sequence_group->requires_sampling();

if (matmul_gathering_is_required){
if (sequence_group->requires_sampling() || echo_output) {
size_t tokens_to_sample_per_sequence = 1 + sequence_group->get_num_tokens_to_validate();
size_t tokens_to_sample_per_group = tokens_to_sample_per_sequence * num_running_sequences;
actual_seq_len = tokens_to_sample_per_group;

size_t initial_size = gather_indice_values.size();
gather_indice_values.resize(initial_size + tokens_to_sample_per_group);
auto it = gather_indice_values.begin() + initial_size;

for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id, it += tokens_to_sample_per_sequence, total_tokens_to_schedule += num_scheduled_tokens) {
std::iota(it, it + tokens_to_sample_per_sequence, total_tokens_to_schedule);
}
} else {
total_tokens_to_schedule += num_scheduled_tokens * num_running_sequences;
actual_seq_len = num_scheduled_tokens;
}
} else {
actual_seq_len = num_scheduled_tokens;
}

sequence_group->set_seq_len_to_sample(actual_seq_len);


for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id) {
Sequence::CPtr sequence = running_sequences[seq_id];

for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) {
for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id) {
// compute token for current sequence
input_ids_data[token_id] = position_id < sequence_group->get_prompt_len() ?
sequence_group->get_prompt_ids()[position_id] :
sequence->get_generated_ids()[position_id - sequence_group->get_prompt_len()];

if (matmul_gathering_is_required && sampling_is_required) {
if (group_position_id + token_id >= prompt_len - 1 || echo_output) {
gather_indice_values.push_back(gathering_current_index);
seq_len_after_gather++;
}
}
position_ids_data[token_id] = position_id;
}

@@ -167,7 +183,6 @@ class ModelRunner {
subsequence_begins_data += 1;
block_indices_begins_data += 1;
}
sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? std::min(seq_len_after_gather, num_scheduled_tokens) : num_scheduled_tokens);
}

// typical LLM parameters
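
For readers skimming the diff: the commit replaces the per-token push_back of gather indices with one contiguous std::iota run per running sequence, offset by total_tokens_to_schedule. Below is a minimal standalone sketch of that construction; the concrete numbers (2 sequences, 3 scheduled tokens, 2 validation tokens) are hypothetical and chosen only so the output is easy to check.

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
        // Hypothetical numbers: one group with 2 running sequences, 3 scheduled
        // tokens each, and 3 tokens to sample per sequence (1 + 2 validation
        // tokens), mirroring the loop added to ModelRunner.
        const size_t num_running_sequences = 2;
        const size_t num_scheduled_tokens = 3;
        const size_t tokens_to_sample_per_sequence = 3;  // 1 + num_tokens_to_validate

        std::vector<int64_t> gather_indice_values;
        int64_t total_tokens_to_schedule = 0;

        const size_t tokens_to_sample_per_group =
            tokens_to_sample_per_sequence * num_running_sequences;
        const size_t initial_size = gather_indice_values.size();
        gather_indice_values.resize(initial_size + tokens_to_sample_per_group);
        auto it = gather_indice_values.begin() + initial_size;

        // Each running sequence contributes one contiguous run of indices that
        // starts at its offset into the flattened batch of scheduled tokens.
        for (size_t seq_id = 0; seq_id < num_running_sequences;
             ++seq_id, it += tokens_to_sample_per_sequence,
                       total_tokens_to_schedule += num_scheduled_tokens) {
            std::iota(it, it + tokens_to_sample_per_sequence, total_tokens_to_schedule);
        }

        for (int64_t idx : gather_indice_values)
            std::cout << idx << ' ';  // prints: 0 1 2 3 4 5
        std::cout << '\n';
        return 0;
    }
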
3 changes: 1 addition & 2 deletions src/cpp/src/sampler.cpp
@@ -756,8 +756,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
continue;

size_t num_running_sequences = sequence_group->num_running_seqs();
size_t actual_seq_len = sequence_group->is_matmul_sliced() ?
sequence_group->get_seq_len_to_sample() : sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
size_t actual_seq_len = sequence_group->get_seq_len_to_sample();
const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();

const auto request_id = sequence_group->get_request_id();
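
The sampler side of the change drops the is_matmul_sliced() branch, since get_seq_len_to_sample() is now always populated (see the schedule_tokens() change further down). The sketch below only illustrates the kind of per-group offset arithmetic that becomes uniform once every group reports a single sample length; the flat [tokens, vocab] logits layout and the helper are assumptions of the sketch, not code from this commit.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Simplified stand-in for the two values the sampler reads per group;
    // the real interface lives in sequence_group.hpp.
    struct GroupInfo {
        size_t num_running_sequences;
        size_t seq_len_to_sample;  // get_seq_len_to_sample()
    };

    // Starting row of each group in a flat [tokens, vocab] logits buffer,
    // assuming groups are laid out back to back (an assumption of this sketch).
    std::vector<size_t> group_row_offsets(const std::vector<GroupInfo>& groups) {
        std::vector<size_t> offsets;
        size_t rows = 0;
        for (const GroupInfo& g : groups) {
            offsets.push_back(rows);
            rows += g.num_running_sequences * g.seq_len_to_sample;
        }
        return offsets;
    }

    int main() {
        // Group 0: gathering overrode the length to 2; group 1: default, all 4
        // scheduled tokens. Both now go through the same code path.
        std::vector<GroupInfo> groups = {{1, 2}, {2, 4}};
        for (size_t offset : group_row_offsets(groups))
            std::cout << offset << ' ';  // prints: 0 2
        std::cout << '\n';
        return 0;
    }
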
10 changes: 2 additions & 8 deletions src/cpp/src/sequence_group.hpp
@@ -222,8 +222,6 @@ class SequenceGroup {
bool m_is_gen_paused = false;
// seq len to sample at current iteration
size_t m_seq_len_to_sample = 0;
// flag shows wheather last matmul was sliced
bool m_sliced_matmul = false;

SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
: m_request_id(request_id),
@@ -399,11 +397,6 @@ class SequenceGroup {

void set_seq_len_to_sample(size_t len) {
m_seq_len_to_sample = len;
m_sliced_matmul = true;
}

bool is_matmul_sliced() const {
return m_sliced_matmul;
}

/**
@@ -450,13 +443,14 @@ class SequenceGroup {

void schedule_tokens(size_t num_tokens) {
m_num_scheduled_tokens = num_tokens;
// Unless otherwise specified, the sampler will process all scheduled tokens.
m_seq_len_to_sample = num_tokens;
}

void clear_scheduled_tokens() {
m_num_scheduled_tokens = 0;
m_num_validation_tokens = 0;
m_seq_len_to_sample = 0;
m_sliced_matmul = false;
}

bool is_scheduled() const {
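
With the m_sliced_matmul flag gone, the lifecycle of m_seq_len_to_sample is: schedule_tokens() seeds it with the scheduled-token count, ModelRunner may overwrite it when it gathers the matmul output, and clear_scheduled_tokens() resets it. A self-contained stub of that lifecycle, using a simplified stand-in rather than the real SequenceGroup:

    #include <cassert>
    #include <cstddef>

    // Simplified stand-in for SequenceGroup, reduced to the members touched by
    // this commit; the real class carries much more state.
    class SequenceGroupStub {
    public:
        void schedule_tokens(size_t num_tokens) {
            m_num_scheduled_tokens = num_tokens;
            // Unless otherwise specified, the sampler will process all scheduled tokens.
            m_seq_len_to_sample = num_tokens;
        }

        void set_seq_len_to_sample(size_t len) { m_seq_len_to_sample = len; }

        void clear_scheduled_tokens() {
            m_num_scheduled_tokens = 0;
            m_seq_len_to_sample = 0;
        }

        size_t get_num_scheduled_tokens() const { return m_num_scheduled_tokens; }
        size_t get_seq_len_to_sample() const { return m_seq_len_to_sample; }

    private:
        size_t m_num_scheduled_tokens = 0;
        size_t m_seq_len_to_sample = 0;
    };

    int main() {
        SequenceGroupStub group;

        // No gathering: the default equals the scheduled-token count, which is
        // what the sampler previously fell back to when is_matmul_sliced() was false.
        group.schedule_tokens(8);
        assert(group.get_seq_len_to_sample() == group.get_num_scheduled_tokens());

        // Gathering: the runner overrides the value with the gathered length.
        group.set_seq_len_to_sample(2);
        assert(group.get_seq_len_to_sample() == 2);

        group.clear_scheduled_tokens();
        assert(group.get_seq_len_to_sample() == 0);
        return 0;
    }
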
