From c70b909b8c64d39b1cc865bb7a958a542514a74e Mon Sep 17 00:00:00 2001
From: Wovchena
Date: Thu, 18 Jul 2024 18:29:51 +0400
Subject: [PATCH] Add CB naive chat

Merge after https://github.com/openvinotoolkit/openvino.genai/pull/641
---
 .../genai/continuous_batching_pipeline.hpp   | 12 +++++
 src/cpp/src/continuous_batching_pipeline.cpp | 47 ++++++++++++++++---
 src/cpp/src/llm_pipeline.cpp                 | 38 ++++++++++-----
 tests/python_tests/ov_genai_test_utils.py    |  5 ++
 tests/python_tests/test_chat_generate_api.py | 20 +++++++-
 tests/python_tests/test_generate_api.py      | 28 ++++-------
 6 files changed, 114 insertions(+), 36 deletions(-)

diff --git a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp
index 43c3f4f802..3ec23b4393 100644
--- a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp
+++ b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp
@@ -67,5 +67,17 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
     // more high level interface, which can process multiple prompts in continuous batching manner
     std::vector<EncodedGenerationResult> generate(const std::vector<ov::Tensor>& input_ids, const std::vector<ov::genai::GenerationConfig>& sampling_params, const ov::genai::StreamerVariant& streamer=std::monostate{});
     std::vector<GenerationResult> generate(const std::vector<std::string>& prompts, const std::vector<ov::genai::GenerationConfig>& sampling_params, const ov::genai::StreamerVariant& streamer=std::monostate{});
+
+    /**
+     * @brief Start chat, keeping the history in the KV cache.
+     *
+     * @param system_message optional system message.
+     */
+    void start_chat(const std::string& system_message = "");
+
+    /**
+     * @brief Finish chat and clear the KV cache.
+     */
+    void finish_chat();
 };
 }
diff --git a/src/cpp/src/continuous_batching_pipeline.cpp b/src/cpp/src/continuous_batching_pipeline.cpp
index 08a66ef92f..e3a86e2d84 100644
--- a/src/cpp/src/continuous_batching_pipeline.cpp
+++ b/src/cpp/src/continuous_batching_pipeline.cpp
@@ -56,6 +56,8 @@ class ContinuousBatchingPipeline::Impl {
     std::vector<SequenceGroup::Ptr> m_awaiting_requests;
     // Mutex protecting access to m_awaiting_requests, so add_request and step methods can be called from different threads
     std::mutex m_awaiting_requests_mutex;
+    bool m_is_chat_conversation = false;
+    ChatHistory m_history;

     void _free_non_running_requests() {
@@ -305,12 +307,22 @@ class ContinuousBatchingPipeline::Impl {

     std::vector<GenerationResult> generate(const std::vector<std::string>& prompts, std::vector<GenerationConfig> sampling_params, const StreamerVariant& streamer) {
         std::vector<ov::Tensor> input_ids;
-        input_ids.reserve(prompts.size());
-        for (const std::string& prompt : prompts) {
-            static ManualTimer timer("tokenize");
+        static ManualTimer timer("tokenize");
+        if (m_is_chat_conversation) {
+            OPENVINO_ASSERT(1 == prompts.size(), "Can't chat with multiple prompts");
+            m_history.push_back({{"role", "user"}, {"content", prompts.at(0)}});
+            constexpr bool add_generation_prompt = true;
+            std::string history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
             timer.start();
-            input_ids.push_back(m_tokenizer.encode(prompt).input_ids);
+            input_ids.push_back(m_tokenizer.encode(history).input_ids);
             timer.end();
+        } else {
+            input_ids.reserve(prompts.size());
+            for (const std::string& prompt : prompts) {
+                timer.start();
+                input_ids.push_back(m_tokenizer.encode(prompt).input_ids);
+                timer.end();
+            }
         }
         std::vector<EncodedGenerationResult> encoded = generate(input_ids, sampling_params, streamer);
         std::vector<GenerationResult> decoded;
@@ -318,8 +330,11 @@
         for (EncodedGenerationResult& res : encoded) {
             std::vector<std::string> generated;
             generated.reserve(res.m_generation_ids.size());
-            for (const std::vector<int64_t>& tokens : res.m_generation_ids) {
-                generated.push_back(m_tokenizer.decode(tokens));
+            for (size_t idx = 0; idx < res.m_generation_ids.size(); ++idx) {
+                generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
+                if (m_is_chat_conversation && 0 == idx) {
+                    m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
+                }
             }
             decoded.push_back(GenerationResult{
                 res.m_request_id,
@@ -330,6 +345,18 @@
         }
         return decoded;
     }
+
+    void start_chat(const std::string& system_message) {
+        if (!system_message.empty()) {
+            m_history.push_back({{"role", "system"}, {"content", system_message}});
+        }
+        m_is_chat_conversation = true;
+    };
+
+    void finish_chat() {
+        m_is_chat_conversation = false;
+        m_history.clear();
+    };
 };
 ContinuousBatchingPipeline::ContinuousBatchingPipeline(
     const std::string& models_path,
@@ -382,3 +409,11 @@ std::vector<EncodedGenerationResult> ContinuousBatchingPipeline::generate(const
 std::vector<GenerationResult> ContinuousBatchingPipeline::generate(const std::vector<std::string>& prompts, const std::vector<ov::genai::GenerationConfig>& sampling_params, const StreamerVariant& streamer) {
     return m_impl->generate(prompts, sampling_params, streamer);
 }
+
+void ContinuousBatchingPipeline::start_chat(const std::string& system_message) {
+    m_impl->start_chat(system_message);
+};
+
+void ContinuousBatchingPipeline::finish_chat() {
+    m_impl->finish_chat();
+};
diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index acf7059e7d..1d68d4c746 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -115,6 +115,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         EncodedInputs encoded_input;

         if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
+            OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts");
             encoded_input = m_tokenizer.encode(*input_vector);
         } else if (auto input_prompt = std::get_if<std::string>(&inputs)) {
             std::string& prompt = *input_prompt;
@@ -386,16 +387,31 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
         OptionalGenerationConfig generation_config,
         StreamerVariant streamer
     ) override {
-        EncodedInputs input_ids_att = std::visit(overloaded{
-            [this](const std::string& prompt) {
-                return m_tokenizer.encode(prompt);
+        std::vector<std::string> prompts = std::visit(overloaded{
+            [](const std::string& prompt) {
+                return std::vector{prompt};
             },
-            [this](std::vector<std::string>& prompts) {
-                return m_tokenizer.encode(prompts);
+            [](std::vector<std::string>& prompts) {
+                return prompts;
             }
         }, inputs);
-        EncodedResults encoded = generate(input_ids_att, generation_config, streamer);
-        return {m_tokenizer.decode(encoded.tokens), encoded.scores};
+        const GenerationConfig& config = generation_config.has_value() ? *generation_config : m_generation_config;
+        // The -1 == config.eos_token_id case and config.validate() are handled in m_impl.
+        std::vector<GenerationResult> generated = m_impl.generate(
+            prompts,
+            std::vector<GenerationConfig>{prompts.size(), config},
+            streamer
+        );
+        std::vector<std::string> plain_replies;
+        std::vector<float> plain_scores;
+        for (GenerationResult& res : generated) {
+            if (GenerationStatus::FINISHED != res.m_status) {
+                OPENVINO_THROW("Got unfinished GenerationStatus");
+            }
+            std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_replies));
+            std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
+        }
+        return {std::move(plain_replies), std::move(plain_scores)};
     }

     EncodedResults generate(
@@ -457,12 +473,12 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
     }

     void start_chat(const std::string& system_message) override {
-        OPENVINO_THROW("start_chat() isn't implemented.");
-    }
+        m_impl.start_chat(system_message);
+    };

     void finish_chat() override {
-        OPENVINO_THROW("finish_chat() isn't implemented.");
-    }
+        m_impl.finish_chat();
+    };
 };

 }
diff --git a/tests/python_tests/ov_genai_test_utils.py b/tests/python_tests/ov_genai_test_utils.py
index 4ba71a1d48..c513353e4a 100644
--- a/tests/python_tests/ov_genai_test_utils.py
+++ b/tests/python_tests/ov_genai_test_utils.py
@@ -215,3 +215,8 @@ def load_pipe(configs: List[Tuple], temp_path):
         with (temp_path / config_name).open('w') as f:
             json.dump(config_json, f)
     return ov_genai.LLMPipeline(str(temp_path))
+
+
+@functools.lru_cache(1)
+def get_continuous_batching(path):
+    return ov_genai.LLMPipeline(str(path), ov_genai.Tokenizer(str(path)), 'CB')
diff --git a/tests/python_tests/test_chat_generate_api.py b/tests/python_tests/test_chat_generate_api.py
index 94de8f6cc2..814bde076c 100644
--- a/tests/python_tests/test_chat_generate_api.py
+++ b/tests/python_tests/test_chat_generate_api.py
@@ -1,6 +1,7 @@
 # Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

+import math
 import openvino
 import openvino_tokenizers
 import openvino_genai as ov_genai
@@ -12,7 +13,8 @@
     read_model,
     load_tok,
     model_tmp_path,
-    get_chat_templates
+    get_chat_templates,
+    get_continuous_batching,
 )


@@ -163,3 +165,19 @@ def test_apply_chat_template(model_tmp_path, chat_config: Tuple[str, Dict]):
     print(f'hf reference: {full_history_str_hf}')
     print(f'ov_genai out: {full_history_str}')
     assert full_history_str == full_history_str_hf
+
+
+@pytest.mark.parametrize("generation_config", configs[1:])
+@pytest.mark.parametrize("model_descr", get_chat_models_list())
+@pytest.mark.precommit
+def test_chat_continuous_batching_vs_stateful(model_descr, generation_config: Dict):
+    model_id, path, tokenizer, model, stateful = read_model(model_descr)
+    cb = get_continuous_batching(path)
+    stateful.start_chat()
+    cb.start_chat()
+    for question in quenstions:
+        generated = cb.generate(question, **generation_config)
+        reference = stateful.generate(question, **generation_config)
+        assert generated == reference
+    # Check that finish_chat() doesn't fail, just in case.
+    cb.finish_chat()
diff --git a/tests/python_tests/test_generate_api.py b/tests/python_tests/test_generate_api.py
index a796aa07e1..6b859b05d5 100644
--- a/tests/python_tests/test_generate_api.py
+++ b/tests/python_tests/test_generate_api.py
@@ -11,7 +11,6 @@
 import sys
 from pathlib import Path
 import torch
-import functools
 import math
 from ov_genai_test_utils import (
     get_models_list,
@@ -20,6 +19,7 @@
     load_tok,
     model_tmp_path,
     STOP_CRITERIA_MAP,
+    get_continuous_batching,
 )


@@ -675,39 +675,31 @@ def test_left_pad():
     run_hf_ov_genai_comparison_batched(models, config, prompts)


-@functools.lru_cache(1)
-def get_continuous_batching(path):
-    return ov_genai.LLMPipeline(str(path), ov_genai.Tokenizer(str(path)), 'CB')
-
-
 @pytest.mark.parametrize("generation_config", test_configs)
 @pytest.mark.parametrize("prompt", batched_prompts)
+@pytest.mark.parametrize("model_descr", get_models_list())
 @pytest.mark.precommit
-def test_continuous_batching_vs_stateful(prompt, generation_config):
-    model_id, path, tokenizer, model, stateful = read_model((
-        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-        Path("TinyLlama-1.1B-Chat-v1.0")
-    ))
+def test_continuous_batching_vs_stateful(model_descr, prompt, generation_config):
+    model_id, path, tokenizer, model, stateful = read_model(model_descr)
     config = ov_genai.GenerationConfig()
     config.max_new_tokens = 100
     cb = get_continuous_batching(path)
     generated = cb.generate(prompt, **generation_config)
     reference = stateful.generate(prompt, **generation_config)
     assert generated.texts == reference.texts
-    if 1 != generation_config.get("num_beams", 1):
+    if 1 != generation_config.get("num_return_sequences", 1):
         # Stateful puts zeroes to generated.scores. Don't compare them.
         for gen, ref in zip(generated.scores, reference.scores):
             assert math.isclose(gen, ref, abs_tol=0.0003)


 @pytest.mark.parametrize("prompt", prompts)
+@pytest.mark.parametrize("model_descr", get_models_list())
 @pytest.mark.precommit
-def test_cb_streamer_vs_return_vs_stateful(prompt):
-    model_id, path, tokenizer, model, stateful = read_model((
-        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-        Path("TinyLlama-1.1B-Chat-v1.0")
-    ))
+def test_cb_streamer_vs_return_vs_stateful(model_descr, prompt):
+    model_id, path, tokenizer, model, stateful = read_model(model_descr)
     cb = get_continuous_batching(path)
     streamed = []
     generated = cb.generate(prompt, max_new_tokens=20, streamer=lambda subword: streamed.append(subword))
     reference = stateful.generate(prompt, max_new_tokens=20)
-    assert generated == "".join(streamed) == reference
+    assert generated == "".join(streamed)
+    assert "".join(streamed) == reference
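
Usage note (not part of the patch): a minimal Python sketch of the chat flow this change enables, assuming a model directory already exported for OpenVINO at "TinyLlama-1.1B-Chat-v1.0" (hypothetical path). The 'CB' device string routes LLMPipeline through the continuous batching adapter, exactly as get_continuous_batching() does in the tests above.

    # Hypothetical usage sketch; the model path and prompts are illustrative only.
    import openvino_genai as ov_genai

    model_path = "TinyLlama-1.1B-Chat-v1.0"  # assumption: an exported OpenVINO model directory
    pipe = ov_genai.LLMPipeline(model_path, ov_genai.Tokenizer(model_path), 'CB')

    pipe.start_chat()  # later generate() calls accumulate history and re-apply the chat template
    print(pipe.generate("1+1=", max_new_tokens=20))
    print(pipe.generate("What was my previous question?", max_new_tokens=20))
    pipe.finish_chat()  # drops the accumulated history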