diff --git a/samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp b/samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp
index dc6761879c..3c2c5b04a9 100644
--- a/samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp
+++ b/samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp
@@ -34,7 +34,8 @@ int main(int argc, char* argv[]) try {
         main_model_path,
         main_device,
         ov::genai::draft_model(draft_model_path, draft_device),
-        ov::genai::scheduler_config(scheduler_config));
+        ov::genai::scheduler_config(scheduler_config)
+    );
 
     auto streamer = [](std::string subword) {
         std::cout << subword << std::flush;
diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp
index d27e8934dc..1e42f5b2d9 100644
--- a/src/cpp/src/continuous_batching_impl.cpp
+++ b/src/cpp/src/continuous_batching_impl.cpp
@@ -285,9 +285,11 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<ov::Tensor>& input_ids,
         if (streamer_ptr && generations.at(0)->can_read()) {
             std::unordered_map<uint64_t, GenerationOutput> token = generations.at(0).get()->back();
-            OPENVINO_ASSERT(1 == token.size());
-            OPENVINO_ASSERT(1 == token.begin()->second.generated_ids.size());
-            continue_generation = !streamer_ptr->put(token.begin()->second.generated_ids.at(0));
+            for (const auto& gen_token : token.begin()->second.generated_ids) {
+                if (!streamer_ptr->put(gen_token)) {
+                    break;
+                }
+            }
         }
     }
diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index f77463d767..294d9b6ace 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -85,75 +85,43 @@ std::string clean_wrapped_text(const std::string& wrapped_text, const std::string& prefix, const std::string& suffix) {
     return clean_text;
 }
 
-// Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned.
-int match_stop_string(Tokenizer & tokenizer, const TokenIds & generated_tokens, const std::set<std::string> & stop_strings) {
-    /*
-    For catching stop_string hit we run comparisons character-wise to catch cases where stop string
-    overlaps with part of another token on both sides or is just a part of a single token.
-    For every stop_string we iterate over generated tokens starting from the last one and going backwards.
-    Every token is wrapped with prefix tokens to ensure tokenizer doesn't remove prefix whitespace of the actual token.
-    After that all tokens are decoded and prefix is removed from the decoded text, so we end up with decoded token.
-    Its characters are compared to the stop_string character at a current_position
-    (position of a character in the stop_string counting from the last one) - at the beginning position is 0.
-    When characters match we increase current_position and check if we have a full match already, if not we continue.
-    If we have already matched some characters (current_position > 0) and next character is not matching
-    before we reach the full match, then we reset current_position to 0.
-    */
-    std::string prefix = "a";
-    auto prefix_ov = tokenizer.encode(prefix).input_ids;
-    std::vector<int64_t> prefix_tokens(prefix_ov.data<int64_t>(), prefix_ov.data<int64_t>() + prefix_ov.get_size());
-    std::string suffix = "b";
-    auto suffix_ov = tokenizer.encode(suffix).input_ids;
-    std::vector<int64_t> suffix_tokens(suffix_ov.data<int64_t>(), suffix_ov.data<int64_t>() + suffix_ov.get_size());
-
-    // Since whitespace can be added at the beginning of the suffix we also try to capture that behavior here
-    // and get suffix string that will actually be part of the decoded string so we can remove it correctly
-    auto wrapped_suffix_tokens = suffix_tokens;
-    wrapped_suffix_tokens.insert(wrapped_suffix_tokens.begin(), prefix_tokens.begin(), prefix_tokens.end());
-    std::string wrapped_suffix = tokenizer.decode(wrapped_suffix_tokens);
-    auto wrapper_pos = wrapped_suffix.find(prefix);
-    suffix = wrapped_suffix.substr(wrapper_pos + prefix.size());
-
-    for (auto stop_string: stop_strings) {
-        int current_position = 0;
-        int num_matched_tokens = 0;
-        // Getting reverse iterator to check tokens starting from the last one generated and going backwards
-        auto generated_tokens_rit = generated_tokens.rbegin();
-        std::vector<int64_t> tokens_buffer;
-        while (generated_tokens_rit != generated_tokens.rend()) {
-            num_matched_tokens++;
-            tokens_buffer.insert(tokens_buffer.begin(), *generated_tokens_rit);
-
-            std::vector<int64_t> wrapped_tokens = wrap_tokens(tokens_buffer, prefix_tokens, suffix_tokens);
-            std::string wrapped_text = tokenizer.decode(wrapped_tokens);
-            std::string clean_text = clean_wrapped_text(wrapped_text, prefix, suffix);
-
-            if (clean_text == "" || (clean_text.size() >= 3 && (clean_text.compare(clean_text.size() - 3, 3, "�") == 0))) {
-                generated_tokens_rit++;
-                continue;
-            } else {
-                tokens_buffer.clear();
-            }
-            // Checking clean_text characters starting from the last one
-            for (auto clean_text_rit = clean_text.rbegin(); clean_text_rit != clean_text.rend(); clean_text_rit++) {
-                // On character match increment current_position for the next comparisons
-                if (*clean_text_rit == *(stop_string.rbegin() + current_position)) {
-                    current_position++;
-                    // If this is the last character from the stop_string we have a match
-                    if ((stop_string.rbegin() + current_position) == stop_string.rend()) {
-                        return num_matched_tokens;
-                    }
-                } else if (current_position) {
-                    // Already found matching characters, but the last one didn't match, so we reset current_position
-                    current_position = 0;
-                    // Looking for the match will start over from this character so we decrement iterator
-                    clean_text_rit--;
-                }
+// { stop_string, { encoded_stop_string_len, matched_tokens_cnt }}
+std::map<std::string, std::pair<size_t, size_t>>
+match_stop_string(const TokenIds & generated_tokens, const std::map<std::string, TokenIds>& stop_strings) {
+    std::map<std::string, std::pair<size_t, size_t>> matched_stop_strings;
+    for (const auto& stop_string : stop_strings) {
+        TokenIds encoded_stop_string = stop_string.second;
+        auto first_matched_token_it = std::find(generated_tokens.rbegin(), generated_tokens.rend(), encoded_stop_string.front());
+        size_t matched_token_cnt = 0;
+        if (first_matched_token_it != generated_tokens.rend()) {
+            auto dist = std::distance(generated_tokens.rbegin(), first_matched_token_it) + 1;
+
+            auto generated_pos_start = generated_tokens.begin() + (generated_tokens.size() - dist);
+            auto generated_pos_end = generated_tokens.end();
+            TokenIds substring_generated(generated_pos_start, generated_pos_end);
+
+            auto stop_string_pos_start = encoded_stop_string.begin();
+            auto stop_string_pos_end = stop_string_pos_start + dist;
+            TokenIds substring_stop_str(stop_string_pos_start, stop_string_pos_end);
+            if (substring_generated == substring_stop_str) {
+                matched_token_cnt = dist;
             }
-            generated_tokens_rit++;
         }
+        matched_stop_strings.insert({stop_string.first, { encoded_stop_string.size(), matched_token_cnt }});
     }
-    return 0;
+    return matched_stop_strings;
+}
+
+// { is_matched_any_stop_string, { is_matched_any_stop_string ? encoded_stop_string_len : max_matched_len } }
+std::pair<bool, size_t> is_stop_generation(const std::map<std::string, std::pair<size_t, size_t>>& match_stop_string_res) {
+    size_t max_matched_len = 0;
+    for (const auto& stop_string : match_stop_string_res) {
+        if (stop_string.second.first == stop_string.second.second) {
+            return { true, stop_string.second.second };
+        }
+        max_matched_len = std::max(max_matched_len, stop_string.second.second);
+    }
+    return { false, max_matched_len };
 }
 
 // Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned.
@@ -245,7 +213,7 @@ std::map<size_t, int32_t> Sampler::GroupBeamSearcher::get_beam_idxs() {
     return next_beams;
 }
 
-void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output) {
+void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output, const std::map<std::string, TokenIds>& stop_strings) {
     assert(m_parameters.num_beams % m_parameters.num_beam_groups == 0 &&
         "number of beams should be divisible by number of groups");
     size_t group_size = m_parameters.num_beams / m_parameters.num_beam_groups;
@@ -387,25 +355,29 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output) {
                 continue;
             }
 
-            if (!m_parameters.stop_strings.empty()) {
+            if (!stop_strings.empty()) {
                 // We need to include candidate token to already generated tokens to check if stop string has been generated
                 // There's probably a better way to do that, than copying whole vector...
                 std::vector<int64_t> token_ids = candidate.m_sequence->get_generated_ids();
                 token_ids.push_back(candidate.m_token_id);
-                int num_last_matched_tokens = match_stop_string(m_tokenizer, token_ids, m_sequence_group->get_sampling_parameters().stop_strings);
-                if (num_last_matched_tokens) {
+                size_t num_last_matched_tokens;
+                bool is_matched;
+                std::tie(is_matched, num_last_matched_tokens) = is_stop_generation(match_stop_string(token_ids, stop_strings));
+                if (is_matched) {
                     // If beam_token does not belong to top num_beams tokens, it should not be added
                     if (cand_idx >= group_size)
                         continue;
-                    if(!m_parameters.include_stop_str_in_output) {
+                    if (!m_parameters.include_stop_str_in_output) {
                         // remove tokens that match stop_string from output (last token is not included in candidate.m_sequence at this point)
-                        candidate.m_sequence->remove_last_tokens(num_last_matched_tokens - 1);
+                        candidate.m_sequence->remove_last_tokens(num_last_matched_tokens);
                     }
                     // try to finish candidate
                     try_to_finish_candidate(group, candidate, m_parameters.include_stop_str_in_output);
                     continue;
+                } else {
+                    // todo: iefode
                 }
             }
 
@@ -575,14 +547,19 @@ std::vector<uint64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequence_group) {
             continue;
         }
 
+        auto stop_strings = get_stop_strings(sequence_group);
         if (!sampling_params.stop_strings.empty()) {
-            int num_matched_last_tokens = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), sampling_params.stop_strings);
-            if (num_matched_last_tokens) {
+            size_t num_matched_last_tokens;
+            bool is_matched;
+            std::tie(is_matched, num_matched_last_tokens) = is_stop_generation(match_stop_string(running_sequence->get_generated_ids(), stop_strings));
+            if (is_matched) {
                 if (!sampling_params.include_stop_str_in_output)
                     running_sequence->remove_last_tokens(num_matched_last_tokens);
                 running_sequence->set_status(SequenceStatus::FINISHED);
                 running_sequence->set_finish_reason(GenerationFinishReason::STOP);
                 dropped_seq_ids.push_back(running_sequence->get_id());
+            } else if (!sampling_params.include_stop_str_in_output) {
+                sequence_group->set_num_tokens_not_to_stream(num_matched_last_tokens);
             }
         }
     }
@@ -764,6 +741,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
         if (!m_logit_processors.count(request_id)) {
            m_logit_processors.insert({request_id, LogitProcessor(sampling_params, sequence_group->get_prompt_ids())});
         }
+        auto stop_strings = get_stop_strings(sequence_group);
         auto& logit_processor = m_logit_processors.at(request_id);
         const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;
         ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);
@@ -873,7 +851,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
             }
 
             // current algorithm already adds new tokens to running sequences and
-            m_beam_search_info.at(request_id).select_next_tokens(sequence_group_logits, sampler_output);
+            m_beam_search_info.at(request_id).select_next_tokens(sequence_group_logits, sampler_output, stop_strings);
 
             // check max length stop criteria
             std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
@@ -886,8 +864,7 @@
             // Notify handle after sampling is done.
            // For non-streaming this is effective only when the generation is finished.
             OPENVINO_ASSERT(num_tokens_to_process >= max_removed_tokens_per_request);
-            size_t num_output_token_to_push = num_tokens_to_process - max_removed_tokens_per_request + 1;
-            sequence_group->notify_handle(num_output_token_to_push);
+            sequence_group->notify_handle();
         } else {
             // we are in prompt processing phase when prompt is split into chunks and processed step by step
         }
@@ -918,6 +895,35 @@ LogitProcessor& Sampler::get_logit_processor(uint64_t request_id) {
     return m_logit_processors.at(request_id);
 }
 
+TokenIds encode_and_process_stop_string(const std::string& stop_string, ov::genai::Tokenizer& tokenizer) {
+    // encode stop_string
+    ov::Tensor ov_encoded_stop_string = tokenizer.encode(stop_string).input_ids;
+    size_t tensor_size = ov_encoded_stop_string.get_size();
+    TokenIds source_encoded_stop_string(tensor_size), encoded_stop_string;
+    std::copy_n(ov_encoded_stop_string.data<int64_t>(), tensor_size, source_encoded_stop_string.begin());
+    // remove special symbols
+    for (const auto& token_id : source_encoded_stop_string) {
+        if (token_id != tokenizer.get_bos_token_id() &&
+            token_id != tokenizer.get_eos_token_id() &&
+            token_id != tokenizer.get_pad_token_id()) {
+            encoded_stop_string.push_back(token_id);
+        }
+    }
+    return encoded_stop_string;
+}
+
+std::map<std::string, TokenIds>& Sampler::get_stop_strings(const SequenceGroup::Ptr& sequence_group) {
+    const auto request_id = sequence_group->get_request_id();
+    if (!m_encoded_stop_strings.count(request_id)) {
+        std::map<std::string, TokenIds> stop_strings;
+        for (const auto& stop_string : sequence_group->get_sampling_parameters().stop_strings) {
+            stop_strings.insert({ stop_string, encode_and_process_stop_string(stop_string, m_tokenizer) });
+        }
+        m_encoded_stop_strings.insert({ request_id, stop_strings });
+    }
+    return m_encoded_stop_strings.at(request_id);
+}
+
 void Sampler::create_logit_processor(uint64_t request_id, const GenerationConfig& sampling_params, const TokenIds& prompt) {
     m_logit_processors.insert({request_id, LogitProcessor(sampling_params, prompt)});
diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp
index 0f7876cbf9..a49524ed74 100644
--- a/src/cpp/src/sampler.hpp
+++ b/src/cpp/src/sampler.hpp
@@ -50,6 +50,8 @@ class Sampler {
     bool validate_candidate(Sequence::Ptr running_sequence, size_t& token_idx, Token& sampled_token,
                             bool& is_extend_sequence, size_t& max_removed_tokens, bool do_sample);
 
+    std::map<std::string, TokenIds>& get_stop_strings(const SequenceGroup::Ptr& sequence_group);
+
     // request ID => beam search tracking information
     std::map<uint64_t, GroupBeamSearcher> m_beam_search_info;
 
@@ -57,6 +59,8 @@ class Sampler {
     std::mt19937 rng_engine;
     // { request_id, logit_processor }
     std::map<uint64_t, LogitProcessor> m_logit_processors;
+    // { request_id, { stop_string, encoded_stop_strings }}
+    std::map<uint64_t, std::map<std::string, TokenIds>> m_encoded_stop_strings;
 
     Tokenizer m_tokenizer;
 
@@ -115,7 +119,7 @@ class Sampler::GroupBeamSearcher {
 public:
     explicit GroupBeamSearcher(SequenceGroup::Ptr sequence_group, Tokenizer tokenizer);
 
-    void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output);
+    void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output, const std::map<std::string, TokenIds>& stop_strings);
     void finalize(SamplerOutput& sampler_output);
     std::map<size_t, int32_t> get_beam_idxs();
 };
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index 6755255fe8..bc50dbbf7c 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -221,6 +221,8 @@ class SequenceGroup {
     // flag to enable/disable token generation, e.g. in speculative decoding scenario
     bool m_is_gen_paused = false;
 
+    size_t m_num_streamed_tokens = 0, m_num_not_streamed_tokens = 0;
+
     SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
         : m_request_id(request_id),
@@ -390,6 +392,10 @@ class SequenceGroup {
         return m_num_processed_tokens;
     }
 
+    void set_num_tokens_not_to_stream(size_t k) {
+        m_num_not_streamed_tokens = k;
+    }
+
     /**
      * Registers within the sequence group that a given amount of tokens
      * has been evicted from the underlying KV cache.
@@ -602,6 +608,12 @@ class SequenceGroup {
             // todo: check seq.is_finished() to generate without several
             // or is it ok to use padding?
             auto output = sequence->get_last_generation_output(token_cnt);
+            if (token_cnt >= m_num_not_streamed_tokens) {
+                for (size_t i = 0; i < m_num_not_streamed_tokens; ++i) {
+                    output.generated_ids.pop_back();
+                    output.generated_log_probs.pop_back();
+                }
+            }
             if (m_sampling_params.echo && !m_has_echoed) {
                 output.generated_ids.insert(output.generated_ids.begin(), m_prompt_ids.begin(), m_prompt_ids.end());
                 output.generated_log_probs.insert(output.generated_log_probs.begin(), m_prompt_log_probs.begin(), m_prompt_log_probs.end());
@@ -612,7 +624,7 @@ class SequenceGroup {
         m_generation_stream->push(std::move(outputs));
     }
 
-    void notify_handle(size_t num_output_token_to_push = 0) {
+    void notify_handle() {
         if (out_of_memory()) {
             set_generation_status(GenerationStatus::IGNORED);
         } else if (has_finished()) {
@@ -626,10 +638,21 @@ class SequenceGroup {
         } else if (m_sampling_params.is_greedy_decoding() || m_sampling_params.is_multinomial()) {
             // We can stream only when one sequence is returned and we don't use stop strings that would be excluded from the output
             // (after stop string is detected its tokens are already sent)
-            if (num_total_seqs() == 1 &&
-                (m_sampling_params.stop_strings.empty() || m_sampling_params.include_stop_str_in_output)) {
-                if (num_output_token_to_push)
-                    push_partial_outputs(num_output_token_to_push);
+            if (num_total_seqs() == 1) {
+                const auto generated_len = m_sequences.front()->get_generated_len();
+                // speculative decoding draft handling
+                if (generated_len < m_num_streamed_tokens) {
+                    m_num_streamed_tokens = generated_len;
+                }
+                OPENVINO_ASSERT(generated_len >= m_num_streamed_tokens);
+                auto delta = generated_len - m_num_streamed_tokens;
+
+                size_t num_output_token_to_push = generated_len - m_num_streamed_tokens;
+                if (m_sampling_params.include_stop_str_in_output) {
+                    m_num_not_streamed_tokens = 0;
+                }
+                push_partial_outputs(num_output_token_to_push);
+                m_num_streamed_tokens += (num_output_token_to_push - m_num_not_streamed_tokens);
             } else if (has_finished() || out_of_memory()) {
                 push_outputs();
             }
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
index 2be67320a9..e4f3b1ad1f 100644
--- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -232,8 +232,6 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<ov::Tensor>& input_ids,
             continue;
         }
         std::unordered_map<uint64_t, GenerationOutput> token = main_generations.at(0).get()->back();
-        OPENVINO_ASSERT(1 <= token.size());
-        OPENVINO_ASSERT(1 <= token.begin()->second.generated_ids.size());
        for (const auto& gen_token : token.begin()->second.generated_ids) {
             continue_generation = !streamer_ptr->put(gen_token);
             if (!continue_generation) {
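Usage note (illustrative, not part of the patch): with this change, stop strings are matched against encoded token ids (get_stop_strings / match_stop_string / is_stop_generation) and, when include_stop_str_in_output is false, the matched tail is withheld from streaming, so greedy and multinomial generation can stream token-by-token even when stop strings are configured. Below is a minimal caller-side sketch of that behavior; the model directory and the stop strings themselves are placeholders, and the sketch relies only on the public GenerationConfig fields already referenced in the diff (stop_strings, include_stop_str_in_output) plus the bool-returning streamer callback used in the existing samples (returning true asks generation to stop).

#include <iostream>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Placeholder model directory; any LLM exported for OpenVINO GenAI works here.
    ov::genai::LLMPipeline pipe("TinyLlama-1.1B-Chat-v1.0", "CPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 256;
    config.stop_strings = {"###", "\n\n"};      // stop once either string is generated (placeholder values)
    config.include_stop_str_in_output = false;  // matched stop-string tokens are neither streamed nor returned

    // Streaming callback: prints each subword; returning false keeps generation running.
    auto streamer = [](std::string subword) {
        std::cout << subword << std::flush;
        return false;
    };

    pipe.generate("Explain speculative decoding in one paragraph:", config, streamer);
    return 0;
}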