Skip to content

Commit

Permalink
[ CB ] Return only N results from read_all (#678)
Browse files Browse the repository at this point in the history
CVS-146801
  • Loading branch information
iefode authored Nov 8, 2024
1 parent 747c5d2 commit 72ce6de
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
8 changes: 1 addition & 7 deletions src/cpp/src/continuous_batching_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,13 +314,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
EncodedGenerationResult result;
result.m_request_id = 1;
std::vector<GenerationOutput> generation_outputs = generation->read_all();
std::sort(generation_outputs.begin(), generation_outputs.end(), [=] (GenerationOutput& r1, GenerationOutput& r2) {
return r1.score > r2.score;
});

auto num_outputs = std::min(sampling_params[generation_idx].num_return_sequences, generation_outputs.size());
for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) {
const auto& generation_output = generation_outputs[generation_output_idx];
for (const auto& generation_output : generation_outputs) {
result.m_generation_ids.push_back(std::move(generation_output.generated_ids));
result.m_scores.push_back(generation_output.score);
}
Expand Down
4 changes: 3 additions & 1 deletion src/cpp/src/generation_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ std::vector<GenerationOutput> GenerationHandleImpl::read_all() {
add_partial_result(partial_results, iteration_results);
}

for (auto& partial_result: partial_results) {
for (auto& partial_result : partial_results) {
results.push_back(partial_result.second);
}
std::sort(results.begin(), results.end(), [](const GenerationOutput& lhs, const GenerationOutput& rhs) { return lhs.score > rhs.score; });
results.resize(std::min(m_sampling_params.num_return_sequences, results.size()));
return results;
}

0 comments on commit 72ce6de

Please sign in to comment.