
Commit 8ce5eb3
Update streaming in LM Encoding & CB (#1377)
iefode authored Dec 16, 2024
1 parent 4a7374b commit 8ce5eb3
Showing 3 changed files with 20 additions and 15 deletions.
8 changes: 5 additions & 3 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -285,9 +285,11 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
         }
         if (streamer_ptr && generations.at(0)->can_read()) {
             std::unordered_map<uint64_t, GenerationOutput> token = generations.at(0).get()->back();
-            OPENVINO_ASSERT(1 == token.size());
-            OPENVINO_ASSERT(1 == token.begin()->second.generated_ids.size());
-            continue_generation = !streamer_ptr->put(token.begin()->second.generated_ids.at(0));
+            for (const auto& gen_token : token.begin()->second.generated_ids) {
+                if (!streamer_ptr->put(gen_token)) {
+                    break;
+                }
+            }
         }
     }

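The hunk above drops the one-token-per-step assumption: every id returned for the step is forwarded to the streamer, and forwarding stops early when put() asks for it. A minimal, self-contained sketch of that pattern follows; the Streamer type and the sample ids are hypothetical stand-ins rather than the pipeline's real classes, and put() is assumed to return whether streaming should continue, which is how the loop in this hunk reads its result.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the pipeline's streamer: put() receives one token id
// and returns whether the caller should keep forwarding tokens (assumed convention).
struct Streamer {
    std::size_t limit;  // stop after this many tokens, to show the early exit
    std::size_t seen;
    bool put(int64_t token_id) {
        std::cout << "token " << token_id << '\n';
        return ++seen < limit;  // false -> caller should stop
    }
};

int main() {
    // One decode step may now yield several ids, so each one is forwarded
    // instead of asserting there is exactly one.
    std::vector<int64_t> generated_ids = {11, 42, 7, 99};
    Streamer streamer{3, 0};

    for (const auto& gen_token : generated_ids) {
        if (!streamer.put(gen_token)) {
            break;  // the consumer requested a stop; remaining ids are not streamed
        }
    }
    return 0;
}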
25 changes: 15 additions & 10 deletions src/cpp/src/lm_encoding.cpp
@@ -125,6 +125,17 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
                                                 active_sequence_groups.end(),
                                                 get_active_sequence_groups),
                                  active_sequence_groups.end());
 
+    auto stream_generated_tokens = [&streamer_ptr, &generations]() {
+        if (streamer_ptr && generations.at(0).get()->can_read()) {
+            std::unordered_map<uint64_t, GenerationOutput> token = generations.at(0).get()->back();
+            for (const auto& gen_token : token.begin()->second.generated_ids) {
+                if (!streamer_ptr->put(gen_token)) {
+                    break;
+                }
+            }
+        }
+    };
+
     while (active_sequence_groups.size() > 0) {
         size_t total_num_tokens = 0;
@@ -202,13 +213,7 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
         raw_perf_counters.m_new_token_times.emplace_back(infer_end);
         raw_perf_counters.m_batch_sizes.emplace_back(batch_size);
 
-        if (streamer_ptr) {
-            // stream data from first sequence
-            int64_t out_token = sequence_groups.at(0).get()->operator[](0)->get_generated_ids().back();
-            if (streamer_ptr->put(out_token)) {
-                break;
-            }
-        }
+        stream_generated_tokens();
 
         sampler_output = sampler.sample(active_sequence_groups, m_llm.get_tensor("logits"));
 
@@ -218,9 +223,9 @@
                                      active_sequence_groups.end());
     }
 
+    // to stream last token
+    stream_generated_tokens();
     if (streamer_ptr) {
-        int64_t out_token = sequence_groups.at(0).get()->operator[](0)->get_generated_ids().back();
-        streamer_ptr->put(out_token);
         streamer_ptr->end();
     }
 
@@ -246,4 +251,4 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
 }
 
 } // namespace genai
-} // namespace ov
\ No newline at end of file
+} // namespace ov
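In lm_encoding.cpp the same loop is factored into a stream_generated_tokens lambda that runs after every decode step and once more after the loop, so the final step's ids are flushed before streamer_ptr->end(). Below is a small sketch of that control flow; the Streamer type and the simulated steps are hypothetical (no real sampling or inference happens), only the call pattern mirrors the diff.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical streamer with the two entry points used above:
// put() receives one token id, end() marks the end of the stream.
struct Streamer {
    bool put(int64_t token_id) {
        std::cout << "stream " << token_id << '\n';
        return true;  // in this sketch, true means "keep streaming"
    }
    void end() { std::cout << "end of stream\n"; }
};

int main() {
    Streamer streamer;
    // Stand-in for the ids each decode step produces; the last step yields two.
    std::vector<std::vector<int64_t>> steps = {{5}, {17}, {3, 8}};
    std::vector<int64_t> pending;  // ids produced since the last flush

    // Factored-out helper, analogous to stream_generated_tokens in the diff.
    auto stream_generated_tokens = [&]() {
        for (const auto& gen_token : pending) {
            if (!streamer.put(gen_token)) {
                break;  // the consumer asked to stop
            }
        }
        pending.clear();
    };

    std::size_t step = 0;
    while (step < steps.size()) {
        stream_generated_tokens();            // stream what the previous step produced
        const auto& sampled = steps[step++];  // stand-in for sampler.sample(...)
        pending.insert(pending.end(), sampled.begin(), sampled.end());
    }
    stream_generated_tokens();  // flush the final step's ids ("to stream last token")
    streamer.end();             // then signal completion
    return 0;
}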
2 changes: 0 additions & 2 deletions
@@ -232,8 +232,6 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
                 continue;
             }
             std::unordered_map<uint64_t, GenerationOutput> token = main_generations.at(0).get()->back();
-            OPENVINO_ASSERT(1 <= token.size());
-            OPENVINO_ASSERT(1 <= token.begin()->second.generated_ids.size());
             for (const auto& gen_token : token.begin()->second.generated_ids) {
                 continue_generation = !streamer_ptr->put(gen_token);
                 if (!continue_generation) {
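In the speculative decoding path the loop over generated_ids is untouched context, and a single read can already carry several accepted tokens per step, so the at-least-one-token assertions were presumably dropped as redundant: the loop handles any number of ids, including none.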
