Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ CB ][ SD ] Support streaming with using stop_strings and include_stop_strings #1382

Merged
merged 24 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/cpp/src/lm_encoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,4 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
}

} // namespace genai
} // namespace ov
} // namespace ov
166 changes: 87 additions & 79 deletions src/cpp/src/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,75 +85,63 @@ std::string clean_wrapped_text(const std::string& wrapped_text, const std::strin
return clean_text;
}

std::vector<int64_t> encode_and_process_string(const std::string& stop_string, ov::genai::Tokenizer& tokenizer) {
// encode stop_string
std::string stop_string_copy = stop_string;
ov::Tensor ov_encoded_stop_string = tokenizer.encode(stop_string_copy, ov::genai::add_special_tokens(false)).input_ids;
size_t tensor_size = ov_encoded_stop_string.get_size();
std::vector<int64_t> encoded_stop_string(tensor_size);
std::copy_n(ov_encoded_stop_string.data<int64_t>(), tensor_size, encoded_stop_string.begin());
return encoded_stop_string;
}

struct MatchStopStringResult {
size_t to_remove = 0;
// int64_t last_token_id = 0;
// bool is_to_update_last_token = false;
bool is_matched = false;
};

// Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned.
int match_stop_string(Tokenizer & tokenizer, const TokenIds & generated_tokens, const std::set<std::string> & stop_strings) {
/*
For catching stop_string hit we run comparisons character-wise to catch cases where stop string
overlaps with part of another token on both sides or is just a part of a single token.
For every stop_string we iterate over generated tokens starting from the last one and going backwards.
Every token is wrapped with prefix tokens to ensure tokenizer doesn't remove prefix whitespace of the actual token.
After that all tokens are decoded and prefix is removed from the decoded text, so we end up with decoded token.
Its characters are compared to the stop_string character at a current_position
(position of a character in the stop_string counting from the last one) - at the beginning position is 0.
When characters match we increase current_position and check if we have a full match already, if not we continue.
If we have already matched some characters (current_position > 0) and next character is not matching
before we reach the full match, then we reset current_position to 0.
*/
std::string prefix = "a";
auto prefix_ov = tokenizer.encode(prefix).input_ids;
std::vector<int64_t> prefix_tokens(prefix_ov.data<int64_t>(), prefix_ov.data<int64_t>() + prefix_ov.get_size());
std::string suffix = "b";
auto suffix_ov = tokenizer.encode(suffix).input_ids;
std::vector<int64_t> suffix_tokens(suffix_ov.data<int64_t>(), suffix_ov.data<int64_t>() + suffix_ov.get_size());

// Since whitespace can be added at the beginning of the suffix we also try to capture that behavior here
// and get suffix string that will actually be part of the decoded string so we can remove it correctly
auto wrapped_suffix_tokens = suffix_tokens;
wrapped_suffix_tokens.insert(wrapped_suffix_tokens.begin(), prefix_tokens.begin(), prefix_tokens.end());
std::string wrapped_suffix = tokenizer.decode(wrapped_suffix_tokens);
auto wrapper_pos = wrapped_suffix.find(prefix);
suffix = wrapped_suffix.substr(wrapper_pos + prefix.size());

for (auto stop_string: stop_strings) {
int current_position = 0;
int num_matched_tokens = 0;
// Getting reverse iterator to check tokens starting from the last one generated and going backwards
auto generated_tokens_rit = generated_tokens.rbegin();
std::vector<int64_t> tokens_buffer;
while (generated_tokens_rit != generated_tokens.rend()) {
num_matched_tokens++;
tokens_buffer.insert(tokens_buffer.begin(), *generated_tokens_rit);

std::vector<int64_t> wrapped_tokens = wrap_tokens(tokens_buffer, prefix_tokens, suffix_tokens);
std::string wrapped_text = tokenizer.decode(wrapped_tokens);
std::string clean_text = clean_wrapped_text(wrapped_text, prefix, suffix);

if (clean_text == "" || (clean_text.size() >= 3 && (clean_text.compare(clean_text.size() - 3, 3, "") == 0))) {
generated_tokens_rit++;
continue;
} else {
tokens_buffer.clear();
}
// Checking clean_text characters starting from the last one
for (auto clean_text_rit = clean_text.rbegin(); clean_text_rit != clean_text.rend(); clean_text_rit++) {
// On character match increment current_position for the next comparisons
if (*clean_text_rit == *(stop_string.rbegin() + current_position)) {
current_position++;
// If this is the last character from the stop_string we have a match
if ((stop_string.rbegin() + current_position) == stop_string.rend()) {
return num_matched_tokens;
}
} else if (current_position) {
// Already found matching characters, but the last one didn't match, so we reset current_position
current_position = 0;
// Looking for the match will start over from this character so we decrement iterator
clean_text_rit--;
MatchStopStringResult match_stop_string(Tokenizer& tokenizer,
const TokenIds& generated_tokens,
const std::pair<size_t, std::set<std::string>>& stop_strings,
bool is_include_to_output) {
MatchStopStringResult result;
if (generated_tokens.size() >= stop_strings.first) {
size_t offset = generated_tokens.size() - stop_strings.first;
TokenIds buffer(generated_tokens.begin() + offset, generated_tokens.end());
std::string decoded_buffer = tokenizer.decode(buffer);
for (const auto& stop_string : stop_strings.second) {
auto pos = decoded_buffer.find(stop_string);
if (pos != std::string::npos) {
result.is_matched = true;

auto stop_string_len = is_include_to_output ? stop_string.length() : 0;
decoded_buffer = decoded_buffer.substr(0, pos + stop_string_len);
// to remove word splitting symbols from tail
while (decoded_buffer.back() == ' ' || decoded_buffer.back() == '\n') {
decoded_buffer.pop_back();
}
if (decoded_buffer.empty()) {
result.to_remove = buffer.size();
return result;
}

// find token cnt to be removed from sequence by decoding token by token
std::string decoded_partially_string = "";
for (size_t i = 0; i < buffer.size(); ++i) {
decoded_partially_string += tokenizer.decode(TokenIds{buffer[i]});
if (decoded_partially_string.find(decoded_buffer) != std::string::npos) {
result.to_remove = buffer.size() - i - 1;
break;
}
}
return result;
}
generated_tokens_rit++;
}
}
return 0;
return result;
}

// Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned.
Expand Down Expand Up @@ -245,7 +233,9 @@ std::map<size_t, int32_t> Sampler::GroupBeamSearcher::get_beam_idxs() {
return next_beams;
}

void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output) {
void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits,
SamplerOutput& sampler_output,
const std::pair<size_t, std::set<std::string>>& stop_strings) {
assert(m_parameters.num_beams % m_parameters.num_beam_groups == 0 &&
"number of beams should be divisible by number of groups");
size_t group_size = m_parameters.num_beams / m_parameters.num_beam_groups;
Expand Down Expand Up @@ -392,19 +382,17 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, Sa
// There's probably a better way to do that, than copying whole vector...
std::vector<int64_t> token_ids = candidate.m_sequence->get_generated_ids();
token_ids.push_back(candidate.m_token_id);
int num_last_matched_tokens = match_stop_string(m_tokenizer, token_ids, m_sequence_group->get_sampling_parameters().stop_strings);
if (num_last_matched_tokens) {
auto match_result = match_stop_string(m_tokenizer, token_ids, stop_strings, m_parameters.include_stop_str_in_output);
if (match_result.is_matched) {
// If beam_token does not belong to top num_beams tokens, it should not be added
if (cand_idx >= group_size)
continue;

if(!m_parameters.include_stop_str_in_output) {
// remove tokens that match stop_string from output (last token is not included in candidate.m_sequence at this point)
candidate.m_sequence->remove_last_tokens(num_last_matched_tokens - 1);
}
// remove tokens that match stop_string from output (last token is not included in candidate.m_sequence at this point)
candidate.m_sequence->remove_last_tokens(match_result.to_remove);

// try to finish candidate
try_to_finish_candidate(group, candidate, m_parameters.include_stop_str_in_output);
try_to_finish_candidate(group, candidate);
continue;
}
}
Expand Down Expand Up @@ -576,10 +564,11 @@ std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
}

if (!sampling_params.stop_strings.empty()) {
int num_matched_last_tokens = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), sampling_params.stop_strings);
if (num_matched_last_tokens) {
if (!sampling_params.include_stop_str_in_output)
running_sequence->remove_last_tokens(num_matched_last_tokens);
auto& stop_strings = m_stop_strings.at(sequence_group->get_request_id());
auto match_result = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), stop_strings, sampling_params.include_stop_str_in_output);
if (match_result.is_matched) {
running_sequence->remove_last_tokens(match_result.to_remove);

running_sequence->set_status(SequenceStatus::FINISHED);
running_sequence->set_finish_reason(GenerationFinishReason::STOP);
dropped_seq_ids.push_back(running_sequence->get_id());
Expand Down Expand Up @@ -741,6 +730,19 @@ float get_p_prime(Sequence::Ptr& running_sequence,
return p_prime;
}

std::pair<size_t, std::set<std::string>>
process_stop_strings(const std::set<std::string>& stop_strings, Tokenizer& tokenizer) {
std::pair<size_t, std::set<std::string>> result;
for (const auto& stop_string : stop_strings) {
auto encoded_stop_string = encode_and_process_string(stop_string, tokenizer);
if (result.first < encoded_stop_string.size()) {
result.first = encoded_stop_string.size();
}
result.second.insert(stop_string);
}
return result;
}

SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
ov::Tensor logits,
bool is_validation_mode_enabled) {
Expand All @@ -764,6 +766,12 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
if (!m_logit_processors.count(request_id)) {
m_logit_processors.insert({request_id, LogitProcessor(sampling_params, sequence_group->get_prompt_ids())});
}
if (!m_stop_strings.count(request_id)) {
auto processed_stop_string = process_stop_strings(sampling_params.stop_strings, m_tokenizer);
m_stop_strings.insert({request_id, processed_stop_string});
sequence_group->set_stream_window_size(processed_stop_string.first);
}
auto& stop_strings = m_stop_strings.at(request_id);
auto& logit_processor = m_logit_processors.at(request_id);
const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;
ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);
Expand Down Expand Up @@ -873,7 +881,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
}

// current algorithm already adds new tokens to running sequences and
m_beam_search_info.at(request_id).select_next_tokens(sequence_group_logits, sampler_output);
m_beam_search_info.at(request_id).select_next_tokens(sequence_group_logits, sampler_output, stop_strings);

// check max length stop criteria
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
Expand All @@ -886,8 +894,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
// Notify handle after sampling is done.
// For non-streaming this is effective only when the generation is finished.
OPENVINO_ASSERT(num_tokens_to_process >= max_removed_tokens_per_request);
size_t num_output_token_to_push = num_tokens_to_process - max_removed_tokens_per_request + 1;
sequence_group->notify_handle(num_output_token_to_push);
sequence_group->notify_handle();
} else {
// we are in prompt processing phase when prompt is split into chunks and processed step by step
}
Expand Down Expand Up @@ -926,6 +933,7 @@ void Sampler::create_logit_processor(uint64_t request_id, const GenerationConfig
void Sampler::clear_request_info(uint64_t request_id) {
m_beam_search_info.erase(request_id);
m_logit_processors.erase(request_id);
m_stop_strings.erase(request_id);
}

int64_t Sampler::GroupBeamSearcher::Group::finish(Beam beam, const ov::genai::GenerationConfig& sampling_params) {
Expand Down
4 changes: 3 additions & 1 deletion src/cpp/src/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class Sampler {
size_t seed = rng_engine.default_seed;
// { request_id, logit_processor }
std::map<uint64_t, LogitProcessor> m_logit_processors;
// { request_id, { max_encoded_len, { stop_strings }}}
std::map<int64_t, std::pair<size_t, std::set<std::string>>> m_stop_strings;
iefode marked this conversation as resolved.
Show resolved Hide resolved

Tokenizer m_tokenizer;

Expand Down Expand Up @@ -120,7 +122,7 @@ class Sampler::GroupBeamSearcher {
public:
explicit GroupBeamSearcher(SequenceGroup::Ptr sequence_group, Tokenizer tokenizer);

void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output);
void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output, const std::pair<size_t, std::set<std::string>>& stop_strings);
void finalize(SamplerOutput& sampler_output);
std::map<size_t, int32_t> get_beam_idxs();
};
Expand Down
Loading
Loading