Skip to content

Commit

Permalink
[ CB ][ SD ] Support streaming with using stop_strings and `include…
Browse files Browse the repository at this point in the history
…_stop_strings` (#1382)

*Details:*:
* Implement streaming with using `stop_strings` in CB like pipelines
* Change `stop_string_match` logic to encode them only once per request
* Do not stream tokens which are matched to the part of a `stop_string`
(Tests was a bit changes in this case according HF does not support
exclude `stop_strings`)

*Tickets:*
* CVS-158463
  • Loading branch information
iefode authored Dec 20, 2024
1 parent dc558c2 commit 83f5638
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 106 deletions.
2 changes: 1 addition & 1 deletion src/cpp/src/lm_encoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,4 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
}

} // namespace genai
} // namespace ov
} // namespace ov
166 changes: 87 additions & 79 deletions src/cpp/src/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,75 +85,63 @@ std::string clean_wrapped_text(const std::string& wrapped_text, const std::strin
return clean_text;
}

std::vector<int64_t> encode_and_process_string(const std::string& stop_string, ov::genai::Tokenizer& tokenizer) {
// encode stop_string
std::string stop_string_copy = stop_string;
ov::Tensor ov_encoded_stop_string = tokenizer.encode(stop_string_copy, ov::genai::add_special_tokens(false)).input_ids;
size_t tensor_size = ov_encoded_stop_string.get_size();
std::vector<int64_t> encoded_stop_string(tensor_size);
std::copy_n(ov_encoded_stop_string.data<int64_t>(), tensor_size, encoded_stop_string.begin());
return encoded_stop_string;
}

struct MatchStopStringResult {
size_t to_remove = 0;
// int64_t last_token_id = 0;
// bool is_to_update_last_token = false;
bool is_matched = false;
};

// Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned.
int match_stop_string(Tokenizer & tokenizer, const TokenIds & generated_tokens, const std::set<std::string> & stop_strings) {
/*
For catching stop_string hit we run comparisons character-wise to catch cases where stop string
overlaps with part of another token on both sides or is just a part of a single token.
For every stop_string we iterate over generated tokens starting from the last one and going backwards.
Every token is wrapped with prefix tokens to ensure tokenizer doesn't remove prefix whitespace of the actual token.
After that all tokens are decoded and prefix is removed from the decoded text, so we end up with decoded token.
Its characters are compared to the stop_string character at a current_position
(position of a character in the stop_string counting from the last one) - at the beginning position is 0.
When characters match we increase current_position and check if we have a full match already, if not we continue.
If we have already matched some characters (current_position > 0) and next character is not matching
before we reach the full match, then we reset current_position to 0.
*/
std::string prefix = "a";
auto prefix_ov = tokenizer.encode(prefix).input_ids;
std::vector<int64_t> prefix_tokens(prefix_ov.data<int64_t>(), prefix_ov.data<int64_t>() + prefix_ov.get_size());
std::string suffix = "b";
auto suffix_ov = tokenizer.encode(suffix).input_ids;
std::vector<int64_t> suffix_tokens(suffix_ov.data<int64_t>(), suffix_ov.data<int64_t>() + suffix_ov.get_size());

// Since whitespace can be added at the beginning of the suffix we also try to capture that behavior here
// and get suffix string that will actually be part of the decoded string so we can remove it correctly
auto wrapped_suffix_tokens = suffix_tokens;
wrapped_suffix_tokens.insert(wrapped_suffix_tokens.begin(), prefix_tokens.begin(), prefix_tokens.end());
std::string wrapped_suffix = tokenizer.decode(wrapped_suffix_tokens);
auto wrapper_pos = wrapped_suffix.find(prefix);
suffix = wrapped_suffix.substr(wrapper_pos + prefix.size());

for (auto stop_string: stop_strings) {
int current_position = 0;
int num_matched_tokens = 0;
// Getting reverse iterator to check tokens starting from the last one generated and going backwards
auto generated_tokens_rit = generated_tokens.rbegin();
std::vector<int64_t> tokens_buffer;
while (generated_tokens_rit != generated_tokens.rend()) {
num_matched_tokens++;
tokens_buffer.insert(tokens_buffer.begin(), *generated_tokens_rit);

std::vector<int64_t> wrapped_tokens = wrap_tokens(tokens_buffer, prefix_tokens, suffix_tokens);
std::string wrapped_text = tokenizer.decode(wrapped_tokens);
std::string clean_text = clean_wrapped_text(wrapped_text, prefix, suffix);

if (clean_text == "" || (clean_text.size() >= 3 && (clean_text.compare(clean_text.size() - 3, 3, "") == 0))) {
generated_tokens_rit++;
continue;
} else {
tokens_buffer.clear();
}
// Checking clean_text characters starting from the last one
for (auto clean_text_rit = clean_text.rbegin(); clean_text_rit != clean_text.rend(); clean_text_rit++) {
// On character match increment current_position for the next comparisons
if (*clean_text_rit == *(stop_string.rbegin() + current_position)) {
current_position++;
// If this is the last character from the stop_string we have a match
if ((stop_string.rbegin() + current_position) == stop_string.rend()) {
return num_matched_tokens;
}
} else if (current_position) {
// Already found matching characters, but the last one didn't match, so we reset current_position
current_position = 0;
// Looking for the match will start over from this character so we decrement iterator
clean_text_rit--;
MatchStopStringResult match_stop_string(Tokenizer& tokenizer,
const TokenIds& generated_tokens,
const std::pair<size_t, std::set<std::string>>& stop_strings,
bool is_include_to_output) {
MatchStopStringResult result;
if (generated_tokens.size() >= stop_strings.first) {
size_t offset = generated_tokens.size() - stop_strings.first;
TokenIds buffer(generated_tokens.begin() + offset, generated_tokens.end());
std::string decoded_buffer = tokenizer.decode(buffer);
for (const auto& stop_string : stop_strings.second) {
auto pos = decoded_buffer.find(stop_string);
if (pos != std::string::npos) {
result.is_matched = true;

auto stop_string_len = is_include_to_output ? stop_string.length() : 0;
decoded_buffer = decoded_buffer.substr(0, pos + stop_string_len);
// to remove word splitting symbols from tail
while (decoded_buffer.back() == ' ' || decoded_buffer.back() == '\n') {
decoded_buffer.pop_back();
}
if (decoded_buffer.empty()) {
result.to_remove = buffer.size();
return result;
}

// find token cnt to be removed from sequence by decoding token by token
std::string decoded_partially_string = "";
for (size_t i = 0; i < buffer.size(); ++i) {
decoded_partially_string += tokenizer.decode(TokenIds{buffer[i]});
if (decoded_partially_string.find(decoded_buffer) != std::string::npos) {
result.to_remove = buffer.size() - i - 1;
break;
}
}
return result;
}
generated_tokens_rit++;
}
}
return 0;
return result;
}

// Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned.
Expand Down Expand Up @@ -245,7 +233,9 @@ std::map<size_t, int32_t> Sampler::GroupBeamSearcher::get_beam_idxs() {
return next_beams;
}

void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output) {
void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits,
SamplerOutput& sampler_output,
const std::pair<size_t, std::set<std::string>>& stop_strings) {
assert(m_parameters.num_beams % m_parameters.num_beam_groups == 0 &&
"number of beams should be divisible by number of groups");
size_t group_size = m_parameters.num_beams / m_parameters.num_beam_groups;
Expand Down Expand Up @@ -392,19 +382,17 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, Sa
// There's probably a better way to do that, than copying whole vector...
std::vector<int64_t> token_ids = candidate.m_sequence->get_generated_ids();
token_ids.push_back(candidate.m_token_id);
int num_last_matched_tokens = match_stop_string(m_tokenizer, token_ids, m_sequence_group->get_sampling_parameters().stop_strings);
if (num_last_matched_tokens) {
auto match_result = match_stop_string(m_tokenizer, token_ids, stop_strings, m_parameters.include_stop_str_in_output);
if (match_result.is_matched) {
// If beam_token does not belong to top num_beams tokens, it should not be added
if (cand_idx >= group_size)
continue;

if(!m_parameters.include_stop_str_in_output) {
// remove tokens that match stop_string from output (last token is not included in candidate.m_sequence at this point)
candidate.m_sequence->remove_last_tokens(num_last_matched_tokens - 1);
}
// remove tokens that match stop_string from output (last token is not included in candidate.m_sequence at this point)
candidate.m_sequence->remove_last_tokens(match_result.to_remove);

// try to finish candidate
try_to_finish_candidate(group, candidate, m_parameters.include_stop_str_in_output);
try_to_finish_candidate(group, candidate);
continue;
}
}
Expand Down Expand Up @@ -576,10 +564,11 @@ std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
}

if (!sampling_params.stop_strings.empty()) {
int num_matched_last_tokens = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), sampling_params.stop_strings);
if (num_matched_last_tokens) {
if (!sampling_params.include_stop_str_in_output)
running_sequence->remove_last_tokens(num_matched_last_tokens);
auto& stop_strings = m_stop_strings.at(sequence_group->get_request_id());
auto match_result = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), stop_strings, sampling_params.include_stop_str_in_output);
if (match_result.is_matched) {
running_sequence->remove_last_tokens(match_result.to_remove);

running_sequence->set_status(SequenceStatus::FINISHED);
running_sequence->set_finish_reason(GenerationFinishReason::STOP);
dropped_seq_ids.push_back(running_sequence->get_id());
Expand Down Expand Up @@ -741,6 +730,19 @@ float get_p_prime(Sequence::Ptr& running_sequence,
return p_prime;
}

std::pair<size_t, std::set<std::string>>
process_stop_strings(const std::set<std::string>& stop_strings, Tokenizer& tokenizer) {
std::pair<size_t, std::set<std::string>> result;
for (const auto& stop_string : stop_strings) {
auto encoded_stop_string = encode_and_process_string(stop_string, tokenizer);
if (result.first < encoded_stop_string.size()) {
result.first = encoded_stop_string.size();
}
result.second.insert(stop_string);
}
return result;
}

SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
ov::Tensor logits,
bool is_validation_mode_enabled) {
Expand All @@ -764,6 +766,12 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
if (!m_logit_processors.count(request_id)) {
m_logit_processors.insert({request_id, LogitProcessor(sampling_params, sequence_group->get_prompt_ids())});
}
if (!m_stop_strings.count(request_id)) {
auto processed_stop_string = process_stop_strings(sampling_params.stop_strings, m_tokenizer);
m_stop_strings.insert({request_id, processed_stop_string});
sequence_group->set_stream_window_size(processed_stop_string.first);
}
auto& stop_strings = m_stop_strings.at(request_id);
auto& logit_processor = m_logit_processors.at(request_id);
const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;
ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);
Expand Down Expand Up @@ -873,7 +881,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
}

// current algorithm already adds new tokens to running sequences and
m_beam_search_info.at(request_id).select_next_tokens(sequence_group_logits, sampler_output);
m_beam_search_info.at(request_id).select_next_tokens(sequence_group_logits, sampler_output, stop_strings);

// check max length stop criteria
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
Expand All @@ -886,8 +894,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
// Notify handle after sampling is done.
// For non-streaming this is effective only when the generation is finished.
OPENVINO_ASSERT(num_tokens_to_process >= max_removed_tokens_per_request);
size_t num_output_token_to_push = num_tokens_to_process - max_removed_tokens_per_request + 1;
sequence_group->notify_handle(num_output_token_to_push);
sequence_group->notify_handle();
} else {
// we are in prompt processing phase when prompt is split into chunks and processed step by step
}
Expand Down Expand Up @@ -926,6 +933,7 @@ void Sampler::create_logit_processor(uint64_t request_id, const GenerationConfig
void Sampler::clear_request_info(uint64_t request_id) {
m_beam_search_info.erase(request_id);
m_logit_processors.erase(request_id);
m_stop_strings.erase(request_id);
}

int64_t Sampler::GroupBeamSearcher::Group::finish(Beam beam, const ov::genai::GenerationConfig& sampling_params) {
Expand Down
4 changes: 3 additions & 1 deletion src/cpp/src/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class Sampler {
size_t seed = rng_engine.default_seed;
// { request_id, logit_processor }
std::map<uint64_t, LogitProcessor> m_logit_processors;
// { request_id, { max_encoded_len, { stop_strings }}}
std::map<int64_t, std::pair<size_t, std::set<std::string>>> m_stop_strings;

Tokenizer m_tokenizer;

Expand Down Expand Up @@ -120,7 +122,7 @@ class Sampler::GroupBeamSearcher {
public:
explicit GroupBeamSearcher(SequenceGroup::Ptr sequence_group, Tokenizer tokenizer);

void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output);
void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output, const std::pair<size_t, std::set<std::string>>& stop_strings);
void finalize(SamplerOutput& sampler_output);
std::map<size_t, int32_t> get_beam_idxs();
};
Expand Down
Loading

0 comments on commit 83f5638

Please sign in to comment.