From 4e36982b5089e8c00e0c6aed04609f72244e1cb8 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Wed, 4 Dec 2024 18:12:46 +0100 Subject: [PATCH 1/3] fix TextCallbackStreamer when characters are eaten by regex after apostrophe --- src/cpp/src/text_callback_streamer.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/cpp/src/text_callback_streamer.cpp b/src/cpp/src/text_callback_streamer.cpp index cc7b7ff31f..116c62d6c4 100644 --- a/src/cpp/src/text_callback_streamer.cpp +++ b/src/cpp/src/text_callback_streamer.cpp @@ -15,6 +15,14 @@ bool TextCallbackStreamer::put(int64_t token) { std::stringstream res; m_tokens_cache.push_back(token); std::string text = m_tokenizer.decode(m_tokens_cache); + + // In some cases adding the next token can shorten the text, + // e.g. when apostrophe removing regex had worked after adding new tokens. + // Need to hold on before flushing if apostrophe is the last symbol. + if (text.size() > 1 && text.back() == '\'') { + return on_finalized_subword_callback(res.str()); + } + if (!text.empty() && '\n' == text.back() && text.size() > print_len) { // Flush the cache after the new line symbol res << std::string_view{text.data() + print_len, text.size() - print_len}; From af05de5161c4fe2c0f252739f7d3c3146fad88b7 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Wed, 4 Dec 2024 20:35:21 +0100 Subject: [PATCH 2/3] always delay last N characters --- .../multinomial_causal_lm.py | 10 +++++++--- src/cpp/src/text_callback_streamer.cpp | 17 +++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/samples/python/multinomial_causal_lm/multinomial_causal_lm.py b/samples/python/multinomial_causal_lm/multinomial_causal_lm.py index da51176d06..78a14e1e1d 100755 --- a/samples/python/multinomial_causal_lm/multinomial_causal_lm.py +++ b/samples/python/multinomial_causal_lm/multinomial_causal_lm.py @@ -85,6 +85,7 @@ def put(self, token_id: int) -> bool: text = self.tokenizer.decode(self.tokens_cache) word = '' + delay_n_chars = 
4 if len(text) > self.print_len and '\n' == text[-1]: # Flush the cache after the new line symbol. word = text[self.print_len:] @@ -93,11 +94,14 @@ def put(self, token_id: int) -> bool: elif len(text) >= 3 and text[-3:] == chr(65533): # Don't print incomplete text. pass - elif len(text) > self.print_len: + elif len(text) > self.print_len + delay_n_chars: # It is possible to have a shorter text after adding new token. # Print to output only if text length is increaesed. - word = text[self.print_len:] - self.print_len = len(text) + # Also, in some cases adding the next token can shorten the text, + # e.g. when the apostrophe-removing regex takes effect after new tokens are added. + # The last several characters are held back before being flushed to output. + word = text[self.print_len:-delay_n_chars] + self.print_len = len(text) - delay_n_chars self.put_word(word) if self.get_stop_flag(): diff --git a/src/cpp/src/text_callback_streamer.cpp b/src/cpp/src/text_callback_streamer.cpp index 116c62d6c4..28bd1e95a0 100644 --- a/src/cpp/src/text_callback_streamer.cpp +++ b/src/cpp/src/text_callback_streamer.cpp @@ -16,13 +16,6 @@ bool TextCallbackStreamer::put(int64_t token) { m_tokens_cache.push_back(token); std::string text = m_tokenizer.decode(m_tokens_cache); - // In some cases adding the next token can shorten the text, - // e.g. when apostrophe removing regex had worked after adding new tokens. - // Need to hold on before flushing if apostrophe is the last symbol. - if (text.size() > 1 && text.back() == '\'') { - return on_finalized_subword_callback(res.str()); - } - if (!text.empty() && '\n' == text.back() && text.size() > print_len) { // Flush the cache after the new line symbol res << std::string_view{text.data() + print_len, text.size() - print_len}; @@ -31,15 +24,19 @@ bool TextCallbackStreamer::put(int64_t token) { return on_finalized_subword_callback(res.str()); } + // In some cases adding the next token can shorten the text, + // e.g. 
when the apostrophe-removing regex takes effect after new tokens are added. + // The last several characters are held back before being flushed to output. + constexpr size_t delay_n_chars = 4; constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error. if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) { // Don't print incomplete text return on_finalized_subword_callback(res.str()); - } else if (text.size() > print_len) { + } else if (text.size() > print_len + delay_n_chars) { // It is possible to have a shorter text after adding new token. // Print to output only if text length is increaesed. - res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush; - print_len = text.size(); + res << std::string_view{text.data() + print_len, text.size() - print_len - delay_n_chars} << std::flush; + print_len = text.size() - delay_n_chars; } return on_finalized_subword_callback(res.str()); From 2a0f52107844926d9b51a826223ca33a024afc4d Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Wed, 4 Dec 2024 21:50:46 +0100 Subject: [PATCH 3/3] update test for the new behaviour --- tests/python_tests/test_generate_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_tests/test_generate_api.py b/tests/python_tests/test_generate_api.py index 3291558407..e69369e516 100644 --- a/tests/python_tests/test_generate_api.py +++ b/tests/python_tests/test_generate_api.py @@ -687,7 +687,7 @@ def test_unicode_pybind_decoding_3(): model_id, path = 'katuni4ka/tiny-random-phi3', Path('tiny-random-phi3') pipe = read_model((model_id, path))[4] res_str = [] - pipe.generate(",", max_new_tokens=4, streamer=lambda x: res_str.append(x)) + pipe.generate(",", max_new_tokens=4, streamer=lambda x: res_str.extend(list(x))) assert '�' == res_str[-1]