Skip to content

Commit

Permalink
always delay last N characters
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Dec 4, 2024
1 parent 4e36982 commit af05de5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
10 changes: 7 additions & 3 deletions samples/python/multinomial_causal_lm/multinomial_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def put(self, token_id: int) -> bool:
text = self.tokenizer.decode(self.tokens_cache)

word = ''
delay_n_chars = 4
if len(text) > self.print_len and '\n' == text[-1]:
# Flush the cache after the new line symbol.
word = text[self.print_len:]
Expand All @@ -93,11 +94,14 @@ def put(self, token_id: int) -> bool:
elif len(text) >= 3 and text[-3:] == chr(65533):
# Don't print incomplete text.
pass
elif len(text) > self.print_len:
elif len(text) > self.print_len + delay_n_chars:
# It is possible to have a shorter text after adding new token.
# Print to output only if text length is increased.
word = text[self.print_len:]
self.print_len = len(text)
# Also, in some cases adding the next token can shorten the text,
# e.g. when an apostrophe-removing regex takes effect after new tokens are added.
# The last several characters are therefore delayed before being flushed to output.
word = text[self.print_len:-delay_n_chars]
self.print_len = len(text) - delay_n_chars
self.put_word(word)

if self.get_stop_flag():
Expand Down
17 changes: 7 additions & 10 deletions src/cpp/src/text_callback_streamer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@ bool TextCallbackStreamer::put(int64_t token) {
m_tokens_cache.push_back(token);
std::string text = m_tokenizer.decode(m_tokens_cache);

// In some cases adding the next token can shorten the text,
// e.g. when apostrophe removing regex had worked after adding new tokens.
// Need to hold on before flushing if apostrophe is the last symbol.
if (text.size() > 1 && text.back() == '\'') {
return on_finalized_subword_callback(res.str());
}

if (!text.empty() && '\n' == text.back() && text.size() > print_len) {
// Flush the cache after the new line symbol
res << std::string_view{text.data() + print_len, text.size() - print_len};
Expand All @@ -31,15 +24,19 @@ bool TextCallbackStreamer::put(int64_t token) {
return on_finalized_subword_callback(res.str());
}

// In some cases adding the next token can shorten the text,
// e.g. when an apostrophe-removing regex takes effect after new tokens are added.
// The last several characters are therefore delayed before being flushed to output.
constexpr size_t delay_n_chars = 4;
constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error.
if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) {
// Don't print incomplete text
return on_finalized_subword_callback(res.str());
} else if (text.size() > print_len) {
} else if (text.size() > print_len + delay_n_chars) {
// It is possible to have a shorter text after adding new token.
// Print to output only if text length is increased.
res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
print_len = text.size();
res << std::string_view{text.data() + print_len, text.size() - print_len - delay_n_chars} << std::flush;
print_len = text.size() - delay_n_chars;
}

return on_finalized_subword_callback(res.str());
Expand Down

0 comments on commit af05de5

Please sign in to comment.