Skip to content

Commit

Permalink
add return bool to streamer to stop generation
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed May 31, 2024
1 parent 7021c87 commit 680e362
Show file tree
Hide file tree
Showing 14 changed files with 69 additions and 55 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/genai_python_lib.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
echo "$models" | while read -r model_name model_path; do
optimum-cli export openvino --trust-remote-code --weight-format fp16 --model "$model_name" "$model_path"
done
GENAI_BUILD_DIR=../../build python -m pytest test_generate_api.py -v
GENAI_BUILD_DIR=../../build python -m pytest test_generate_api.py -v -m precommit
windows_genai_python_lib:
runs-on: windows-latest
Expand Down
4 changes: 3 additions & 1 deletion src/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,14 @@ Streaming with a custom class

class CustomStreamer: public ov::genai::StreamerBase {
public:
void put(int64_t token) {
bool put(int64_t token) {
bool stop_flag = false;
/* custom decoding/tokens processing code
tokens_cache.push_back(token);
std::string text = m_tokenizer.decode(tokens_cache);
...
*/
return stop_flag;
};

void end() {
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/llm_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
namespace ov {
namespace genai {

using StreamerVariant = std::variant<std::function<void(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using OptionalGenerationConfig = std::optional<GenerationConfig>;
using EncodedInputs = std::variant<ov::Tensor, std::pair<ov::Tensor, ov::Tensor>, TokenizedInputs>;
using StringInputs = std::variant<std::string, std::vector<std::string>>;
Expand Down
5 changes: 3 additions & 2 deletions src/cpp/include/openvino/genai/streamer_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ namespace genai {
*/
class StreamerBase {
public:
/// @brief put is called every time new token is decoded
virtual void put(int64_t token) = 0;
/// @brief put is called every time new token is decoded,
/// @return bool flag to indicate whether generation should be stoped, if return true generation stops
virtual bool put(int64_t token) = 0;

/// @brief end is called at the end of generation. It can be used to flush cache if your own streamer has one
virtual void end() = 0;
Expand Down
13 changes: 8 additions & 5 deletions src/cpp/src/greedy_decoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ EncodedResults greedy_decoding(
eos_met[batch] = (out_token == generation_config.eos_token_id);
m_model_runner.get_tensor("input_ids").data<int64_t>()[batch] = out_token;
}
if (streamer)
streamer->put(token_iter_results[0]);
if (streamer && streamer->put(token_iter_results[0])) {
return results;
}

bool all_are_eos = std::all_of(eos_met.begin(), eos_met.end(), [](int elem) { return elem == 1; });
if (!generation_config.ignore_eos && all_are_eos)
Expand All @@ -107,8 +108,9 @@ EncodedResults greedy_decoding(

m_model_runner.get_tensor("input_ids").data<int64_t>()[batch] = out_token;
}
if (streamer)
streamer->put(token_iter_results[0]);
if (streamer && streamer->put(token_iter_results[0])) {
return results;
}

// Filter out the eos met batches
std::vector<int32_t> beam_idx(running_batch_size);
Expand All @@ -126,8 +128,9 @@ EncodedResults greedy_decoding(
if (!generation_config.ignore_eos && all_are_eos)
break;
}
if (streamer)
if (streamer) {
streamer->end();
}
return results;
}

Expand Down
10 changes: 5 additions & 5 deletions src/cpp/src/llm_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ ov::genai::StreamerVariant get_streamer_from_map(const ov::AnyMap& config_map) {
auto any_val = config_map.at(STREAMER_ARG_NAME);
if (any_val.is<std::shared_ptr<ov::genai::StreamerBase>>()) {
streamer = any_val.as<std::shared_ptr<ov::genai::StreamerBase>>();
} else if (any_val.is<std::function<void(std::string)>>()) {
streamer = any_val.as<std::function<void(std::string)>>();
} else if (any_val.is<std::function<bool(std::string)>>()) {
streamer = any_val.as<std::function<bool(std::string)>>();
}
}
return streamer;
Expand Down Expand Up @@ -227,7 +227,7 @@ class LLMPipeline::LLMPipelineImpl {
streamer_ptr = nullptr;
} else if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&streamer)) {
streamer_ptr = *streamer_obj;
} else if (auto callback = std::get_if<std::function<void(std::string)>>(&streamer)) {
} else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
}

Expand Down Expand Up @@ -316,8 +316,8 @@ std::pair<std::string, Any> streamer(StreamerVariant func) {
if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&func)) {
return {STREAMER_ARG_NAME, Any::make<std::shared_ptr<StreamerBase>>(*streamer_obj)};
} else {
auto callback = std::get<std::function<void(std::string)>>(func);
return {STREAMER_ARG_NAME, Any::make<std::function<void(std::string)>>(callback)};
auto callback = std::get<std::function<bool(std::string)>>(func);
return {STREAMER_ARG_NAME, Any::make<std::function<bool(std::string)>>(callback)};
}
}

Expand Down
10 changes: 5 additions & 5 deletions src/cpp/src/multinomial_decoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ ov::genai::EncodedResults multinominal_decoding(ov::InferRequest& m_model_runner
results.tokens[0].push_back(out_token.id);
results.scores[0] += out_token.score;

if (streamer) {
streamer->put(out_token.id);
if (streamer && streamer->put(out_token.id)) {
return results;
}

if (!config.ignore_eos && out_token.id == config.eos_token_id) {
Expand Down Expand Up @@ -242,10 +242,10 @@ ov::genai::EncodedResults multinominal_decoding(ov::InferRequest& m_model_runner
results.tokens[0].push_back(out_token.id);
results.scores[0] += out_token.score;

if (streamer) {
streamer->put(out_token.id);
if (streamer && streamer->put(out_token.id)) {
return results;
}

if (!config.ignore_eos && out_token.id == config.eos_token_id) {
break;
}
Expand Down
27 changes: 9 additions & 18 deletions src/cpp/src/text_callback_streamer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,17 @@
namespace ov {
namespace genai {

TextCallbackStreamer::TextCallbackStreamer(const Tokenizer& tokenizer, std::function<void (std::string)> callback, bool print_eos_token) {
TextCallbackStreamer::TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback, bool print_eos_token) {
m_tokenizer = tokenizer;
m_print_eos_token = print_eos_token;
on_decoded_text_callback = callback;
m_enabled = true;
on_finalized_subword_callback = callback;
}

void TextCallbackStreamer::put(int64_t token) {
bool TextCallbackStreamer::put(int64_t token) {
std::stringstream res;
// do nothing if <eos> token is met and if print_eos_token=false
if (!m_print_eos_token && token == m_tokenizer.get_eos_token_id())
return;
return false;

m_tokens_cache.push_back(token);
std::string text = m_tokenizer.decode(m_tokens_cache);
Expand All @@ -26,18 +25,15 @@ void TextCallbackStreamer::put(int64_t token) {
res << std::string_view{text.data() + print_len, text.size() - print_len};
m_tokens_cache.clear();
print_len = 0;
on_finalized_text(res.str());
return;
return on_finalized_subword_callback(res.str());
}
if (text.size() >= 3 && text.compare(text.size() - 3, 3, "") == 0) {
// Don't print incomplete text
on_finalized_text(res.str());
return;
return on_finalized_subword_callback(res.str());
}
res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
print_len = text.size();
on_finalized_text(res.str());
return;
return on_finalized_subword_callback(res.str());
}

void TextCallbackStreamer::end() {
Expand All @@ -46,13 +42,8 @@ void TextCallbackStreamer::end() {
res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
m_tokens_cache.clear();
print_len = 0;
on_finalized_text(res.str());
}

void TextCallbackStreamer::on_finalized_text(const std::string& subword) {
if (m_enabled) {
on_decoded_text_callback(subword);
}
on_finalized_subword_callback(res.str());
return;
}

} // namespace genai
Expand Down
8 changes: 3 additions & 5 deletions src/cpp/src/text_callback_streamer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,18 @@ namespace genai {

class TextCallbackStreamer: public StreamerBase {
public:
void put(int64_t token) override;
bool put(int64_t token) override;
void end() override;

TextCallbackStreamer(const Tokenizer& tokenizer, std::function<void(std::string)> callback, bool print_eos_token = false);
TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback, bool print_eos_token = false);

std::function<void (std::string)> on_decoded_text_callback = [](std::string words){};
bool m_enabled = false;
std::function<bool(std::string)> on_finalized_subword_callback = [](std::string words)->bool { return false; };
int64_t m_eos_token;
private:
bool m_print_eos_token = false;
Tokenizer m_tokenizer;
std::vector<int64_t> m_tokens_cache;
size_t print_len = 0;
void on_finalized_text(const std::string& subword);
};

} // namespace genai
Expand Down
4 changes: 2 additions & 2 deletions src/python/py_generate_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ std::string ov_tokenizers_module_path() {

class EmptyStreamer: public StreamerBase {
// It's impossible to create an instance of pure virtual class. Define EmptyStreamer instead.
void put(int64_t token) override {
bool put(int64_t token) override {
PYBIND11_OVERRIDE_PURE(
void, // Return type
bool, // Return type
StreamerBase, // Parent class
put, // Name of function in C++ (must match Python name)
token // Argument(s)
Expand Down
34 changes: 26 additions & 8 deletions tests/python_tests/test_generate_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,15 @@ def stop_criteria_map():

test_cases = [
(dict(max_new_tokens=20), 'table is made of'), # generation_config, prompt
(dict(max_new_tokens=20), '你好! 你好嗎?'), # generation_config, prompt
(dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=20, diversity_penalty=1.0), 'Alan Turing was a'),
(dict(num_beam_groups=3, num_beams=15, num_return_sequences=15, max_new_tokens=30, diversity_penalty=1.0), 'Alan Turing was a'),
(dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.0), 'table is made of'),
(dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.0), 'The Sun is yellow because'),
(dict(num_beam_groups=2, num_beams=8, num_return_sequences=8, max_new_tokens=20, diversity_penalty=1.5), 'The Sun is yellow because'),
]
@pytest.mark.parametrize("generation_config,prompt", test_cases)
@pytest.mark.precommit
def test_decoding(model_fixture, generation_config, prompt):
run_hf_ov_genai_comparison(model_fixture, generation_config, prompt)

Expand All @@ -134,21 +136,24 @@ def test_decoding(model_fixture, generation_config, prompt):
dict(max_new_tokens=20),
dict( max_new_tokens=20, num_beam_groups=3, num_beams=15,diversity_penalty=1.0)
]
batched_prompts = [['table is made of', 'They sky is blue because', 'Difference between Jupiter and Marks is that'],
['hello', 'Here is the longest nowel ever: ']]
batched_prompts = [['table is made of', 'They sky is blue because', 'Difference between Jupiter and Mars is that'],
['hello', 'Here is the longest nowel ever: '],
['Alan Turing was a', 'return 0', '你好! 你好嗎?']]
@pytest.mark.parametrize("generation_config", test_configs)
@pytest.mark.parametrize("prompts", batched_prompts)
@pytest.mark.precommit
def test_multibatch(model_fixture, generation_config, prompts):
generation_config['pad_token_id'] = 2
run_hf_ov_genai_comparison_batched(model_fixture, generation_config, prompts)


prompts = ['The Sun is yellow because', 'Difference between Jupiter and Marks is that', 'table is made of']
prompts = ['The Sun is yellow because', 'Difference between Jupiter and Mars is that', 'table is made of']
@pytest.mark.parametrize("num_beam_groups", [2, 3, 8])
@pytest.mark.parametrize("group_size", [5, 3, 10])
@pytest.mark.parametrize("max_new_tokens", [20, 15])
@pytest.mark.parametrize("diversity_penalty", [1.0 , 1.5])
@pytest.mark.parametrize("prompt", prompts)
@pytest.mark.precommit
def test_beam_search_decoding(model_fixture, num_beam_groups, group_size,
max_new_tokens, diversity_penalty, prompt):
generation_config = dict(
Expand All @@ -164,6 +169,7 @@ def test_beam_search_decoding(model_fixture, num_beam_groups, group_size,
@pytest.mark.parametrize("stop_criteria", [StopCriteria.NEVER, StopCriteria.EARLY, StopCriteria.HEURISTIC])
@pytest.mark.parametrize("prompt", prompts)
@pytest.mark.parametrize("max_new_tokens", [10, 80])
@pytest.mark.precommit
def test_stop_criteria(model_fixture, stop_criteria, prompt, max_new_tokens):
# todo: with EARLY stop_criteria looks like HF return unvalid out with sentence<eos><unk><unk>
# while genai ends sentence with <eos>
Expand All @@ -185,7 +191,7 @@ def test_stop_criteria(model_fixture, stop_criteria, prompt, max_new_tokens):
@pytest.mark.parametrize("group_size", [5])
@pytest.mark.parametrize("max_new_tokens", [800, 2000])
@pytest.mark.parametrize("prompt", prompts)
@pytest.mark.skip # will be enabled in nightly since are computationally expensive
@pytest.mark.nightly
def test_beam_search_long_sentences(model_fixture, num_beam_groups, group_size,
max_new_tokens, prompt):
generation_config = dict(
Expand All @@ -203,25 +209,29 @@ def user_defined_callback(subword):


@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
@pytest.mark.precommit
def test_callback_one_string(model_fixture, callback):
pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
pipe.generate('', openvino_genai.GenerationConfig(), callback)


@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
@pytest.mark.precommit
def test_callback_batch_fail(model_fixture, callback):
pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
with pytest.raises(RuntimeError):
pipe.generate(['1', '2'], openvino_genai.GenerationConfig(), callback)


@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
@pytest.mark.precommit
def test_callback_kwargs_one_string(model_fixture, callback):
pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
pipe.generate('', max_new_tokens=10, streamer=callback)


@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
@pytest.mark.precommit
def test_callback_kwargs_batch_fail(model_fixture, callback):
pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
with pytest.raises(RuntimeError):
Expand All @@ -239,52 +249,60 @@ def end(self):
print('end')


@pytest.mark.precommit
def test_streamer_one_string(model_fixture):
pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
printer = Printer(pipe.get_tokenizer())
pipe.generate('', openvino_genai.GenerationConfig(), printer)


@pytest.mark.precommit
def test_streamer_batch_fail(model_fixture):
pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
printer = Printer(pipe.get_tokenizer())
with pytest.raises(RuntimeError):
pipe.generate(['1', '2'], openvino_genai.GenerationConfig(), printer)


@pytest.mark.precommit
def test_streamer_kwargs_one_string(model_fixture):
pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
printer = Printer(pipe.get_tokenizer())
pipe.generate('', do_sample=True, streamer=printer)


@pytest.mark.precommit
def test_streamer_kwargs_batch_fail(model_fixture):
pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
printer = Printer(pipe.get_tokenizer())
with pytest.raises(RuntimeError):
pipe.generate('', num_beams=2, streamer=printer)


@pytest.mark.precommit
@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
def test_operator_wit_callback_one_string(model_fixture, callback):
def test_operator_with_callback_one_string(model_fixture, callback):
pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
pipe('', openvino_genai.GenerationConfig(), callback)


@pytest.mark.precommit
@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
def test_operator_wit_callback_batch_fail(model_fixture, callback):
def test_operator_with_callback_batch_fail(model_fixture, callback):
pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
with pytest.raises(Exception):
pipe(['1', '2'], openvino_genai.GenerationConfig(), callback)


def test_operator_wit_streamer_kwargs_one_string(model_fixture):
@pytest.mark.precommit
def test_operator_with_streamer_kwargs_one_string(model_fixture):
pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
printer = Printer(pipe.get_tokenizer())
pipe('', do_sample=True, streamer=printer)


def test_operator_wit_streamer_kwargs_batch_fail(model_fixture):
@pytest.mark.precommit
def test_operator_with_streamer_kwargs_batch_fail(model_fixture):
pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
printer = Printer(pipe.get_tokenizer())
with pytest.raises(RuntimeError):
Expand Down
2 changes: 1 addition & 1 deletion text_generation/causal_lm/cpp/chat_sample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ int main(int argc, char* argv[]) try {

ov::genai::GenerationConfig config = pipe.get_generation_config();
config.max_new_tokens = 10000;
std::function<void(std::string)> streamer = [](std::string word) { std::cout << word << std::flush; };
std::function<bool(std::string)> streamer = [](std::string word) { std::cout << word << std::flush; return false; };

pipe.start_chat();
for (;;) {
Expand Down
Loading

0 comments on commit 680e362

Please sign in to comment.