Skip to content

Commit

Permalink
Fixed sorting in get_finished_sequences()
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov committed Dec 20, 2024
1 parent 5e44ca8 commit 94f283e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
9 changes: 5 additions & 4 deletions src/cpp/src/continuous_batching_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,6 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
}
auto all_requests = m_awaiting_requests; // we need to store all requests to get results from them once generation has finished

- std::vector<EncodedGenerationResult> results;
- results.reserve(all_requests.size());
-
bool continue_generation = true;
while (has_non_finished_requests() && continue_generation) {
try {
Expand Down Expand Up @@ -313,6 +310,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
OPENVINO_ASSERT(m_requests.empty(), "Internal error: current request is supposed to be dropped within step() function as completed");
}

+ std::vector<EncodedGenerationResult> results;
+ results.reserve(all_requests.size());
+
for (size_t request_id = 0; request_id < all_requests.size(); ++request_id) {
const auto& request = all_requests[request_id];
auto sampling_params = request->get_sampling_parameters();
Expand All @@ -322,6 +322,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
EncodedGenerationResult result;
result.m_request_id = request_id;
result.m_generation_ids.resize(num_outputs);
+ result.m_scores.resize(num_outputs);

for (size_t i = 0; i < num_outputs; ++i) {
const auto & sequence = sequences[i];
Expand All @@ -331,7 +332,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
if (sampling_params.echo)
result.m_generation_ids[i] = request->get_prompt_ids();
std::copy(generated_ids.begin(), generated_ids.end(), std::back_inserter(result.m_generation_ids[i]));
- result.m_scores.push_back(score);
+ result.m_scores[i] = score;
}

result.m_status = generations[request_id]->get_status();
Expand Down
6 changes: 3 additions & 3 deletions src/cpp/src/sequence_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,10 @@ class SequenceGroup {
}
}

- std::sort(finished_seqs.begin(), finished_seqs.end(), [=] (Sequence::CPtr s1, Sequence::CPtr s2) {
+ std::sort(finished_seqs.begin(), finished_seqs.end(), [=] (Sequence::CPtr s1, Sequence::CPtr s2) -> bool {
bool is_beam_search = m_sampling_params.is_beam_search();
- const float score_1 = is_beam_search ? s1->get_cumulative_log_probs() : s1->get_beam_search_score(m_sampling_params);
- const float score_2 = is_beam_search ? s2->get_cumulative_log_probs() : s2->get_beam_search_score(m_sampling_params);
+ const float score_1 = is_beam_search ? s1->get_beam_search_score(m_sampling_params) : s1->get_cumulative_log_probs();
+ const float score_2 = is_beam_search ? s2->get_beam_search_score(m_sampling_params) : s2->get_cumulative_log_probs();
return score_1 > score_2;
});

Expand Down

0 comments on commit 94f283e

Please sign in to comment.