
Commit

fix
olpipi committed Nov 28, 2024
1 parent e2bc9c5 commit 1523fab
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/cpp/src/model_runner.hpp
@@ -129,8 +129,8 @@ class ModelRunner {
 size_t num_running_sequences = running_sequences.size();
 size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
 size_t group_position_id = sequence_group->get_num_processed_tokens();
-auto prompt_len = sequence_group->get_prompt_len();
-size_t tokens_num_to_sample = 0;
+size_t prompt_len = sequence_group->get_prompt_len();
+size_t seq_len_after_gather = 0;

 // spec: In case of multiple input tokens for current sequence (prompt_len > 1),
 // context_len corresponds to first token within subgroup of scheduled tokens
@@ -148,7 +148,7 @@
 if (matmul_gathering_is_required) {
     if (group_position_id + token_id >= prompt_len - 1) {
         gather_indice_values.push_back(gathering_current_index);
-        tokens_num_to_sample++;
+        seq_len_after_gather++;
     }
 }
 position_ids_data[token_id] = position_id;
@@ -169,7 +169,7 @@
         subsequence_begins_data += 1;
         block_indices_begins_data += 1;
     }
-    sequence_group->set_seq_len_to_sample(tokens_num_to_sample);
+    sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? std::min(seq_len_after_gather, num_scheduled_tokens) : num_scheduled_tokens);
 }

 // typical LLM parameters
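Note on the change above: `tokens_num_to_sample` is renamed to `seq_len_after_gather`, since with matmul gathering it counts the tokens whose logits survive the gather, and the final call now clamps that count to the number of scheduled tokens. A minimal standalone sketch of the fixed logic (the free function `compute_seq_len_to_sample` is illustrative, not part of the repository):

    #include <algorithm>
    #include <cstddef>

    // Illustrative mirror of the fixed ModelRunner logic: when matmul
    // gathering slices the logits, only the gathered tokens can be
    // sampled, capped by the scheduled-token count; otherwise every
    // scheduled token is eligible.
    size_t compute_seq_len_to_sample(bool matmul_gathering_is_required,
                                     size_t seq_len_after_gather,
                                     size_t num_scheduled_tokens) {
        return matmul_gathering_is_required
            ? std::min(seq_len_after_gather, num_scheduled_tokens)
            : num_scheduled_tokens;
    }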
2 changes: 1 addition & 1 deletion src/cpp/src/sampler.cpp
@@ -757,7 +757,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
 size_t num_running_sequences = sequence_group->num_running_seqs();
 size_t actual_seq_len = sequence_group->is_matmul_sliced() ?
     sequence_group->get_seq_len_to_sample() : sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
-size_t padded_amount_of_processed_tokens = std::max(sequence_group->get_num_scheduled_tokens(), batch_seq_len);
+size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
 const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();

 const auto request_id = sequence_group->get_request_id();
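The sampler-side change keeps padding consistent with the sliced case: when `is_matmul_sliced()` is true, `actual_seq_len` can be smaller than the scheduled-token count, so the padded amount is now derived from it rather than from `get_num_scheduled_tokens()`. A small sketch with hypothetical values (the numbers are illustrative only, not taken from the repository):

    #include <algorithm>
    #include <cstddef>
    #include <iostream>

    int main() {
        // Hypothetical values: 8 tokens were scheduled, but after matmul
        // slicing only the last token's logits remain to sample from.
        size_t num_scheduled_tokens = 8;
        size_t actual_seq_len = 1;  // what get_seq_len_to_sample() would return
        size_t batch_seq_len = 4;   // padded sequence length across the batch

        // Before the fix: padding used the scheduled count regardless of slicing.
        std::cout << std::max(num_scheduled_tokens, batch_seq_len) << "\n";  // 8

        // After the fix: padding follows the length actually available to sample.
        std::cout << std::max(actual_seq_len, batch_seq_len) << "\n";        // 4
        return 0;
    }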
