From 8ce5eb389179ba82da6523f849944ea3dc8c93e0 Mon Sep 17 00:00:00 2001 From: Irina Efode Date: Mon, 16 Dec 2024 15:49:38 +0400 Subject: [PATCH] Update streaming in LM Encoding & CB (#1377) --- src/cpp/src/continuous_batching_impl.cpp | 8 +++--- src/cpp/src/lm_encoding.cpp | 25 +++++++++++-------- .../speculative_decoding_impl.cpp | 2 -- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index d27e8934dc..1e42f5b2d9 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -285,9 +285,11 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vectorcan_read()) { std::unordered_map token = generations.at(0).get()->back(); - OPENVINO_ASSERT(1 == token.size()); - OPENVINO_ASSERT(1 == token.begin()->second.generated_ids.size()); - continue_generation = !streamer_ptr->put(token.begin()->second.generated_ids.at(0)); + for (const auto& gen_token : token.begin()->second.generated_ids) { + if (!streamer_ptr->put(gen_token)) { + break; + } + } } } diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp index c76d9f7edf..3ab041fa58 100644 --- a/src/cpp/src/lm_encoding.cpp +++ b/src/cpp/src/lm_encoding.cpp @@ -125,6 +125,17 @@ std::pair get_lm_encoded_results( active_sequence_groups.end(), get_active_sequence_groups), active_sequence_groups.end()); + + auto stream_generated_tokens = [&streamer_ptr, &generations]() { + if (streamer_ptr && generations.at(0).get()->can_read()) { + std::unordered_map token = generations.at(0).get()->back(); + for (const auto& gen_token : token.begin()->second.generated_ids) { + if (!streamer_ptr->put(gen_token)) { + break; + } + } + } + }; while (active_sequence_groups.size() > 0) { size_t total_num_tokens = 0; @@ -202,13 +213,7 @@ std::pair get_lm_encoded_results( raw_perf_counters.m_new_token_times.emplace_back(infer_end); raw_perf_counters.m_batch_sizes.emplace_back(batch_size); - if (streamer_ptr) { - // stream data from first sequence - int64_t out_token = sequence_groups.at(0).get()->operator[](0)->get_generated_ids().back(); - if (streamer_ptr->put(out_token)) { - break; - } - } + stream_generated_tokens(); sampler_output = sampler.sample(active_sequence_groups, m_llm.get_tensor("logits")); @@ -218,9 +223,9 @@ std::pair get_lm_encoded_results( active_sequence_groups.end()); } + // to stream last token + stream_generated_tokens(); if (streamer_ptr) { - int64_t out_token = sequence_groups.at(0).get()->operator[](0)->get_generated_ids().back(); - streamer_ptr->put(out_token); streamer_ptr->end(); } @@ -246,4 +251,4 @@ std::pair get_lm_encoded_results( } } // namespace genai -} // namespace ov +} // namespace ov \ No newline at end of file diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 2be67320a9..e4f3b1ad1f 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -232,8 +232,6 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< continue; } std::unordered_map token = main_generations.at(0).get()->back(); - OPENVINO_ASSERT(1 <= token.size()); - OPENVINO_ASSERT(1 <= token.begin()->second.generated_ids.size()); for (const auto& gen_token : token.begin()->second.generated_ids) { continue_generation = !streamer_ptr->put(gen_token); if (!continue_generation) {