From 680e3623c2200b801f9968df61532c9ee36d96ca Mon Sep 17 00:00:00 2001
From: Pavel Esir
Date: Fri, 31 May 2024 10:29:21 +0200
Subject: [PATCH] add return bool to streamer to stop generation

---
 .github/workflows/genai_python_lib.yml        |  2 +-
 src/README.md                                 |  4 ++-
 .../include/openvino/genai/llm_pipeline.hpp   |  2 +-
 .../include/openvino/genai/streamer_base.hpp  |  5 +--
 src/cpp/src/greedy_decoding.cpp               | 13 ++++---
 src/cpp/src/llm_pipeline.cpp                  | 10 +++---
 src/cpp/src/multinomial_decoding.cpp          | 10 +++---
 src/cpp/src/text_callback_streamer.cpp        | 27 +++++----------
 src/cpp/src/text_callback_streamer.hpp        |  8 ++---
 src/python/py_generate_pipeline.cpp           |  4 +--
 tests/python_tests/test_generate_api.py       | 34 ++++++++++++++-----
 text_generation/causal_lm/cpp/chat_sample.cpp |  2 +-
 .../causal_lm/cpp/greedy_causal_lm.cpp        |  2 +-
 .../causal_lm/cpp/multinomial_causal_lm.cpp   |  1 +
 14 files changed, 69 insertions(+), 55 deletions(-)

diff --git a/.github/workflows/genai_python_lib.yml b/.github/workflows/genai_python_lib.yml
index b70e5dff90..e2227adfaa 100644
--- a/.github/workflows/genai_python_lib.yml
+++ b/.github/workflows/genai_python_lib.yml
@@ -34,7 +34,7 @@ jobs:
           echo "$models" | while read -r model_name model_path; do
             optimum-cli export openvino --trust-remote-code --weight-format fp16 --model "$model_name" "$model_path"
           done
-          GENAI_BUILD_DIR=../../build python -m pytest test_generate_api.py -v
+          GENAI_BUILD_DIR=../../build python -m pytest test_generate_api.py -v -m precommit

   windows_genai_python_lib:
     runs-on: windows-latest
diff --git a/src/README.md b/src/README.md
index 06a649a752..854908684c 100644
--- a/src/README.md
+++ b/src/README.md
@@ -140,12 +140,14 @@ Streaming with a custom class

 class CustomStreamer: public ov::genai::StreamerBase {
 public:
-    void put(int64_t token) {
+    bool put(int64_t token) {
+        bool stop_flag = false;
         /* custom decoding/tokens processing code
         tokens_cache.push_back(token);
         std::string text = m_tokenizer.decode(tokens_cache);
         ...
         */
+        return stop_flag;
     };

     void end() {
diff --git a/src/cpp/include/openvino/genai/llm_pipeline.hpp b/src/cpp/include/openvino/genai/llm_pipeline.hpp
index 73905616e5..9ee9d8d1b7 100644
--- a/src/cpp/include/openvino/genai/llm_pipeline.hpp
+++ b/src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -14,7 +14,7 @@ namespace ov {
 namespace genai {

-using StreamerVariant = std::variant<std::function<void(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
+using StreamerVariant = std::variant<std::function<bool(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
 using OptionalGenerationConfig = std::optional<GenerationConfig>;
 using EncodedInputs = std::variant<ov::Tensor, TokenizedInputs>;
 using StringInputs = std::variant<std::string, std::vector<std::string>>;

diff --git a/src/cpp/include/openvino/genai/streamer_base.hpp b/src/cpp/include/openvino/genai/streamer_base.hpp
index ba6287c66a..04d350cc5d 100644
--- a/src/cpp/include/openvino/genai/streamer_base.hpp
+++ b/src/cpp/include/openvino/genai/streamer_base.hpp
@@ -15,8 +15,9 @@ namespace genai {
  */
 class StreamerBase {
 public:
-    /// @brief put is called every time new token is decoded
-    virtual void put(int64_t token) = 0;
+    /// @brief put is called every time a new token is decoded,
+    /// @return bool flag indicating whether generation should be stopped; if it returns true, generation stops
+    virtual bool put(int64_t token) = 0;

     /// @brief end is called at the end of generation. It can be used to flush cache if your own streamer has one
     virtual void end() = 0;
diff --git a/src/cpp/src/greedy_decoding.cpp b/src/cpp/src/greedy_decoding.cpp
index 48cedf09f0..f68df0bf6b 100644
--- a/src/cpp/src/greedy_decoding.cpp
+++ b/src/cpp/src/greedy_decoding.cpp
@@ -79,8 +79,9 @@ EncodedResults greedy_decoding(
         eos_met[batch] = (out_token == generation_config.eos_token_id);
         m_model_runner.get_tensor("input_ids").data<int64_t>()[batch] = out_token;
     }
-    if (streamer)
-        streamer->put(token_iter_results[0]);
+    if (streamer && streamer->put(token_iter_results[0])) {
+        return results;
+    }

     bool all_are_eos = std::all_of(eos_met.begin(), eos_met.end(), [](int elem) { return elem == 1; });
     if (!generation_config.ignore_eos && all_are_eos)
@@ -107,8 +108,9 @@
             m_model_runner.get_tensor("input_ids").data<int64_t>()[batch] = out_token;
         }

-        if (streamer)
-            streamer->put(token_iter_results[0]);
+        if (streamer && streamer->put(token_iter_results[0])) {
+            return results;
+        }

         // Filter out the eos met batches
         std::vector<int32_t> beam_idx(running_batch_size);
@@ -126,8 +128,9 @@
         if (!generation_config.ignore_eos && all_are_eos)
             break;
     }
-    if (streamer)
+    if (streamer) {
         streamer->end();
+    }

     return results;
 }
diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index 10b7da499a..df0f007445 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -66,8 +66,8 @@ ov::genai::StreamerVariant get_streamer_from_map(const ov::AnyMap& config_map) {
         auto any_val = config_map.at(STREAMER_ARG_NAME);
         if (any_val.is<std::shared_ptr<ov::genai::StreamerBase>>()) {
             streamer = any_val.as<std::shared_ptr<ov::genai::StreamerBase>>();
-        } else if (any_val.is<std::function<void(std::string)>>()) {
-            streamer = any_val.as<std::function<void(std::string)>>();
+        } else if (any_val.is<std::function<bool(std::string)>>()) {
+            streamer = any_val.as<std::function<bool(std::string)>>();
         }
     }
     return streamer;
@@ -227,7 +227,7 @@ class LLMPipeline::LLMPipelineImpl {
             streamer_ptr = nullptr;
         } else if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&streamer)) {
             streamer_ptr = *streamer_obj;
-        } else if (auto callback = std::get_if<std::function<void(std::string)>>(&streamer)) {
+        } else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
             streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
         }

@@ -316,8 +316,8 @@ std::pair<std::string, Any> streamer(StreamerVariant func) {
     if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&func)) {
         return {STREAMER_ARG_NAME, Any::make<std::shared_ptr<StreamerBase>>(*streamer_obj)};
     } else {
-        auto callback = std::get<std::function<void(std::string)>>(func);
-        return {STREAMER_ARG_NAME, Any::make<std::function<void(std::string)>>(callback)};
+        auto callback = std::get<std::function<bool(std::string)>>(func);
+        return {STREAMER_ARG_NAME, Any::make<std::function<bool(std::string)>>(callback)};
     }
 }
diff --git a/src/cpp/src/multinomial_decoding.cpp b/src/cpp/src/multinomial_decoding.cpp
index 33b7e5e378..de9bd3cd12 100644
--- a/src/cpp/src/multinomial_decoding.cpp
+++ b/src/cpp/src/multinomial_decoding.cpp
@@ -212,8 +212,8 @@ ov::genai::EncodedResults multinominal_decoding(ov::InferRequest& m_model_runner
     results.tokens[0].push_back(out_token.id);
     results.scores[0] += out_token.score;

-    if (streamer) {
-        streamer->put(out_token.id);
+    if (streamer && streamer->put(out_token.id)) {
+        return results;
     }

     if (!config.ignore_eos && out_token.id == config.eos_token_id) {
@@ -242,10 +242,10 @@ ov::genai::EncodedResults multinominal_decoding(ov::InferRequest& m_model_runner
         results.tokens[0].push_back(out_token.id);
         results.scores[0] += out_token.score;

-        if (streamer) {
-            streamer->put(out_token.id);
+        if (streamer && streamer->put(out_token.id)) {
+            return results;
         }
-
+
         if (!config.ignore_eos && out_token.id == config.eos_token_id) {
             break;
         }
diff --git a/src/cpp/src/text_callback_streamer.cpp b/src/cpp/src/text_callback_streamer.cpp
index 4169967bc3..9361cadca7 100644
--- a/src/cpp/src/text_callback_streamer.cpp
+++ b/src/cpp/src/text_callback_streamer.cpp
@@ -6,18 +6,17 @@
 namespace ov {
 namespace genai {

-TextCallbackStreamer::TextCallbackStreamer(const Tokenizer& tokenizer, std::function<void(std::string)> callback, bool print_eos_token) {
+TextCallbackStreamer::TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback, bool print_eos_token) {
     m_tokenizer = tokenizer;
     m_print_eos_token = print_eos_token;
-    on_decoded_text_callback = callback;
-    m_enabled = true;
+    on_finalized_subword_callback = callback;
 }

-void TextCallbackStreamer::put(int64_t token) {
+bool TextCallbackStreamer::put(int64_t token) {
     std::stringstream res;
     // do nothing if <eos> token is met and if print_eos_token=false
     if (!m_print_eos_token && token == m_tokenizer.get_eos_token_id())
-        return;
+        return false;

     m_tokens_cache.push_back(token);
     std::string text = m_tokenizer.decode(m_tokens_cache);
@@ -26,18 +25,15 @@ void TextCallbackStreamer::put(int64_t token) {
         res << std::string_view{text.data() + print_len, text.size() - print_len};
         m_tokens_cache.clear();
         print_len = 0;
-        on_finalized_text(res.str());
-        return;
+        return on_finalized_subword_callback(res.str());
     }
     if (text.size() >= 3 && text.compare(text.size() - 3, 3, "�") == 0) {
         // Don't print incomplete text
-        on_finalized_text(res.str());
-        return;
+        return on_finalized_subword_callback(res.str());
     }
     res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
     print_len = text.size();
-    on_finalized_text(res.str());
-    return;
+    return on_finalized_subword_callback(res.str());
 }

 void TextCallbackStreamer::end() {
@@ -46,13 +42,8 @@ void TextCallbackStreamer::end() {
     res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
     m_tokens_cache.clear();
     print_len = 0;
-    on_finalized_text(res.str());
-}
-
-void TextCallbackStreamer::on_finalized_text(const std::string& subword) {
-    if (m_enabled) {
-        on_decoded_text_callback(subword);
-    }
+    on_finalized_subword_callback(res.str());
+    return;
 }

 }  // namespace genai
diff --git a/src/cpp/src/text_callback_streamer.hpp b/src/cpp/src/text_callback_streamer.hpp
index 76e8fdb37c..0f2462a619 100644
--- a/src/cpp/src/text_callback_streamer.hpp
+++ b/src/cpp/src/text_callback_streamer.hpp
@@ -11,20 +11,18 @@ namespace genai {

 class TextCallbackStreamer: public StreamerBase {
 public:
-    void put(int64_t token) override;
+    bool put(int64_t token) override;
     void end() override;

-    TextCallbackStreamer(const Tokenizer& tokenizer, std::function<void(std::string)> callback, bool print_eos_token = false);
+    TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback, bool print_eos_token = false);

-    std::function<void(std::string)> on_decoded_text_callback = [](std::string words){};
-    bool m_enabled = false;
+    std::function<bool(std::string)> on_finalized_subword_callback = [](std::string words)->bool { return false; };
     int64_t m_eos_token;
 private:
     bool m_print_eos_token = false;
     Tokenizer m_tokenizer;
     std::vector<int64_t> m_tokens_cache;
     size_t print_len = 0;
-    void on_finalized_text(const std::string& subword);
 };

 }  // namespace genai
diff --git a/src/python/py_generate_pipeline.cpp b/src/python/py_generate_pipeline.cpp
index 3b9b80897f..12cc3136bb 100644
--- a/src/python/py_generate_pipeline.cpp
+++ b/src/python/py_generate_pipeline.cpp
@@ -180,9 +180,9 @@ std::string ov_tokenizers_module_path() {

 class EmptyStreamer: public StreamerBase {
     // It's impossible to create an instance of pure virtual class. Define EmptyStreamer instead.
-    void put(int64_t token) override {
+    bool put(int64_t token) override {
         PYBIND11_OVERRIDE_PURE(
-            void,  // Return type
+            bool,  // Return type
             StreamerBase,  // Parent class
             put,  // Name of function in C++ (must match Python name)
             token  // Argument(s)
diff --git a/tests/python_tests/test_generate_api.py b/tests/python_tests/test_generate_api.py
index c8bcf25a73..953059fcaa 100644
--- a/tests/python_tests/test_generate_api.py
+++ b/tests/python_tests/test_generate_api.py
@@ -119,6 +119,7 @@ def stop_criteria_map():

 test_cases = [
     (dict(max_new_tokens=20), 'table is made of'),  # generation_config, prompt
+    (dict(max_new_tokens=20), '你好! 你好嗎?'),  # generation_config, prompt
     (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=20, diversity_penalty=1.0), 'Alan Turing was a'),
     (dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=30, diversity_penalty=1.0), 'Alan Turing was a'),
     (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.0), 'table is made of'),
@@ -126,6 +127,7 @@ def stop_criteria_map():
     (dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.5), 'The Sun is yellow because'),
 ]
 @pytest.mark.parametrize("generation_config,prompt", test_cases)
+@pytest.mark.precommit
 def test_decoding(model_fixture, generation_config, prompt):
     run_hf_ov_genai_comparison(model_fixture, generation_config, prompt)

@@ -134,21 +136,24 @@ def test_decoding(model_fixture, generation_config, prompt):
     dict(max_new_tokens=20),
     dict( max_new_tokens=20, num_beam_groups=3, num_beams=15,diversity_penalty=1.0)
 ]
-batched_prompts = [['table is made of', 'They sky is blue because', 'Difference between Jupiter and Marks is that'],
-                   ['hello', 'Here is the longest nowel ever: ']]
+batched_prompts = [['table is made of', 'They sky is blue because', 'Difference between Jupiter and Mars is that'],
+                   ['hello', 'Here is the longest nowel ever: '],
+                   ['Alan Turing was a', 'return 0', '你好! 你好嗎?']]
 @pytest.mark.parametrize("generation_config", test_configs)
 @pytest.mark.parametrize("prompts", batched_prompts)
+@pytest.mark.precommit
 def test_multibatch(model_fixture, generation_config, prompts):
     generation_config['pad_token_id'] = 2
     run_hf_ov_genai_comparison_batched(model_fixture, generation_config, prompts)


-prompts = ['The Sun is yellow because', 'Difference between Jupiter and Marks is that', 'table is made of']
+prompts = ['The Sun is yellow because', 'Difference between Jupiter and Mars is that', 'table is made of']
 @pytest.mark.parametrize("num_beam_groups", [2, 3, 8])
 @pytest.mark.parametrize("group_size", [5, 3, 10])
 @pytest.mark.parametrize("max_new_tokens", [20, 15])
 @pytest.mark.parametrize("diversity_penalty", [1.0 , 1.5])
 @pytest.mark.parametrize("prompt", prompts)
+@pytest.mark.precommit
 def test_beam_search_decoding(model_fixture, num_beam_groups, group_size,
                               max_new_tokens, diversity_penalty, prompt):
     generation_config = dict(
@@ -164,6 +169,7 @@ def test_beam_search_decoding(model_fixture, num_beam_groups, group_size,
 @pytest.mark.parametrize("stop_criteria", [StopCriteria.NEVER, StopCriteria.EARLY, StopCriteria.HEURISTIC])
 @pytest.mark.parametrize("prompt", prompts)
 @pytest.mark.parametrize("max_new_tokens", [10, 80])
+@pytest.mark.precommit
 def test_stop_criteria(model_fixture, stop_criteria, prompt, max_new_tokens):
     # todo: with EARLY stop_criteria looks like HF return unvalid out with sentence<eos>
     # while genai ends sentence with <eos>
@@ -185,7 +191,7 @@ def test_stop_criteria(model_fixture, stop_criteria, prompt, max_new_tokens):
 @pytest.mark.parametrize("group_size", [5])
 @pytest.mark.parametrize("max_new_tokens", [800, 2000])
 @pytest.mark.parametrize("prompt", prompts)
-@pytest.mark.skip  # will be enabled in nightly since are computationally expensive
+@pytest.mark.nightly
 def test_beam_search_long_sentences(model_fixture, num_beam_groups, group_size,
                                     max_new_tokens, prompt):
     generation_config = dict(
@@ -203,12 +209,14 @@ def user_defined_callback(subword):


 @pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
+@pytest.mark.precommit
 def test_callback_one_string(model_fixture, callback):
     pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
     pipe.generate('', openvino_genai.GenerationConfig(), callback)


 @pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
+@pytest.mark.precommit
 def test_callback_batch_fail(model_fixture, callback):
     pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
     with pytest.raises(RuntimeError):
@@ -216,12 +224,14 @@ def test_callback_batch_fail(model_fixture, callback):


 @pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
+@pytest.mark.precommit
 def test_callback_kwargs_one_string(model_fixture, callback):
     pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
     pipe.generate('', max_new_tokens=10, streamer=callback)


 @pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
+@pytest.mark.precommit
 def test_callback_kwargs_batch_fail(model_fixture, callback):
     pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
     with pytest.raises(RuntimeError):
@@ -239,12 +249,14 @@ def end(self):
         print('end')


+@pytest.mark.precommit
 def test_streamer_one_string(model_fixture):
     pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
     printer = Printer(pipe.get_tokenizer())
     pipe.generate('', openvino_genai.GenerationConfig(), printer)


+@pytest.mark.precommit
 def test_streamer_batch_fail(model_fixture):
     pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
     printer = Printer(pipe.get_tokenizer())
@@ -252,12 +264,14 @@ def test_streamer_batch_fail(model_fixture):
     with pytest.raises(RuntimeError):
         pipe.generate(['1', '2'], openvino_genai.GenerationConfig(), printer)


+@pytest.mark.precommit
 def test_streamer_kwargs_one_string(model_fixture):
     pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
     printer = Printer(pipe.get_tokenizer())
     pipe.generate('', do_sample=True, streamer=printer)


+@pytest.mark.precommit
 def test_streamer_kwargs_batch_fail(model_fixture):
     pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
     printer = Printer(pipe.get_tokenizer())
@@ -265,26 +279,30 @@ def test_streamer_kwargs_batch_fail(model_fixture):
     with pytest.raises(RuntimeError):
         pipe.generate('', num_beams=2, streamer=printer)


+@pytest.mark.precommit
 @pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
-def test_operator_wit_callback_one_string(model_fixture, callback):
+def test_operator_with_callback_one_string(model_fixture, callback):
     pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
     pipe('', openvino_genai.GenerationConfig(), callback)


+@pytest.mark.precommit
 @pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
-def test_operator_wit_callback_batch_fail(model_fixture, callback):
+def test_operator_with_callback_batch_fail(model_fixture, callback):
     pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
     with pytest.raises(Exception):
         pipe(['1', '2'], openvino_genai.GenerationConfig(), callback)


-def test_operator_wit_streamer_kwargs_one_string(model_fixture):
+@pytest.mark.precommit
+def test_operator_with_streamer_kwargs_one_string(model_fixture):
     pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
     printer = Printer(pipe.get_tokenizer())
     pipe('', do_sample=True, streamer=printer)


-def test_operator_wit_streamer_kwargs_batch_fail(model_fixture):
+@pytest.mark.precommit
+def test_operator_with_streamer_kwargs_batch_fail(model_fixture):
     pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
     printer = Printer(pipe.get_tokenizer())
     with pytest.raises(RuntimeError):
diff --git a/text_generation/causal_lm/cpp/chat_sample.cpp b/text_generation/causal_lm/cpp/chat_sample.cpp
index a288f1f23b..75cf609afe 100644
--- a/text_generation/causal_lm/cpp/chat_sample.cpp
+++ b/text_generation/causal_lm/cpp/chat_sample.cpp
@@ -12,7 +12,7 @@ int main(int argc, char* argv[]) try {

     ov::genai::GenerationConfig config = pipe.get_generation_config();
     config.max_new_tokens = 10000;
-    std::function<void(std::string)> streamer = [](std::string word) { std::cout << word << std::flush; };
+    std::function<bool(std::string)> streamer = [](std::string word) { std::cout << word << std::flush; return false; };

     pipe.start_chat();
     for (;;) {
diff --git a/text_generation/causal_lm/cpp/greedy_causal_lm.cpp b/text_generation/causal_lm/cpp/greedy_causal_lm.cpp
index 0fea9b36d3..dd309af8f9 100644
--- a/text_generation/causal_lm/cpp/greedy_causal_lm.cpp
+++ b/text_generation/causal_lm/cpp/greedy_causal_lm.cpp
@@ -18,7 +18,7 @@ int main(int argc, char* argv[]) try {
     ov::genai::GenerationConfig config = pipe.get_generation_config();
     config.max_new_tokens = 100;
     config.do_sample = false;
-    auto streamer = [](std::string subword){std::cout << subword << std::flush;};
+    auto streamer = [](std::string subword){ std::cout << subword << std::flush; return false; };

     // since streamer is set results will be printed each time a new token is generated
     pipe.generate(prompt, config, streamer);
diff --git a/text_generation/causal_lm/cpp/multinomial_causal_lm.cpp b/text_generation/causal_lm/cpp/multinomial_causal_lm.cpp
index ffbfc6b2c3..6cbab7d2f5 100644
--- a/text_generation/causal_lm/cpp/multinomial_causal_lm.cpp
+++ b/text_generation/causal_lm/cpp/multinomial_causal_lm.cpp
@@ -24,6 +24,7 @@ int main(int argc, char* argv[]) try {
     config.top_k = 30;
     auto streamer = [](std::string subword) {
         std::cout << subword << std::flush;
+        return false;
     };

     // since streamer is set results will be printed each time a new token is generated
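
Usage note: below is a minimal caller-side sketch of the new stop mechanism, mirroring how the updated samples pass a streamer alongside the config. The model directory path, the prompt, and the 30-subword cutoff are placeholders chosen only for illustration and are not part of this patch.

    #include "openvino/genai/llm_pipeline.hpp"
    #include <iostream>
    #include <string>

    int main() {
        // Placeholder model directory; any exported OpenVINO model dir works here.
        ov::genai::LLMPipeline pipe("./model_dir", "CPU");
        ov::genai::GenerationConfig config = pipe.get_generation_config();
        config.max_new_tokens = 1000;

        size_t streamed_subwords = 0;
        auto streamer = [&streamed_subwords](std::string subword) {
            std::cout << subword << std::flush;
            // Returning true asks the pipeline to stop producing further tokens;
            // the 30-subword cutoff is an arbitrary example condition.
            return ++streamed_subwords >= 30;
        };

        // As in the updated samples, the streamer is passed together with the config.
        pipe.generate("Why is the Sun yellow?", config, streamer);
        return 0;
    }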