[ CB ][ SD ] Implement streaming with using
iefode committed Dec 13, 2024
1 parent d17f716 commit 72d9348
Showing 6 changed files with 123 additions and 89 deletions.
@@ -34,7 +34,8 @@ int main(int argc, char* argv[]) try {
main_model_path,
main_device,
ov::genai::draft_model(draft_model_path, draft_device),
ov::genai::scheduler_config(scheduler_config));
ov::genai::scheduler_config(scheduler_config)
);

auto streamer = [](std::string subword) {
std::cout << subword << std::flush;
8 changes: 5 additions & 3 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -285,9 +285,11 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
}
if (streamer_ptr && generations.at(0)->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> token = generations.at(0).get()->back();
OPENVINO_ASSERT(1 == token.size());
OPENVINO_ASSERT(1 == token.begin()->second.generated_ids.size());
continue_generation = !streamer_ptr->put(token.begin()->second.generated_ids.at(0));
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (!streamer_ptr->put(gen_token)) {
break;
}
}
}
}
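Note on the hunk above: the single-token assertion is replaced by a loop because one scheduling step can now surface several freshly accepted tokens (for example when a draft model is used in speculative decoding). Below is a minimal standalone sketch of that consumption pattern; the StepOutput struct and stream_step helper are illustrative stand-ins, not the pipeline's real types.

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Illustrative stand-in for one step's streamed output: a step may now carry
// several freshly accepted tokens instead of exactly one.
struct StepOutput {
    std::vector<int64_t> generated_ids;
};

// Feed every token of the step to the streamer; stop early if the callback
// asks to interrupt generation (mirrors the early break added in the hunk above).
bool stream_step(const StepOutput& step, const std::function<bool(int64_t)>& put) {
    for (int64_t token_id : step.generated_ids) {
        if (!put(token_id)) {
            return false;  // streamer requested to stop generation
        }
    }
    return true;
}

int main() {
    StepOutput step{{17, 42, 108}};  // e.g. three tokens accepted in one speculative step
    bool continue_generation = stream_step(step, [](int64_t id) {
        std::cout << id << ' ' << std::flush;
        return true;  // return false here to cancel generation
    });
    std::cout << '\n' << (continue_generation ? "continue" : "stop") << '\n';
}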

160 changes: 83 additions & 77 deletions src/cpp/src/sampler.cpp
@@ -85,75 +85,43 @@ std::string clean_wrapped_text(const std::string& wrapped_text, const std::strin
return clean_text;
}

// Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned.
int match_stop_string(Tokenizer & tokenizer, const TokenIds & generated_tokens, const std::set<std::string> & stop_strings) {
/*
For catching stop_string hit we run comparisons character-wise to catch cases where stop string
overlaps with part of another token on both sides or is just a part of a single token.
For every stop_string we iterate over generated tokens starting from the last one and going backwards.
Every token is wrapped with prefix tokens to ensure tokenizer doesn't remove prefix whitespace of the actual token.
After that all tokens are decoded and prefix is removed from the decoded text, so we end up with decoded token.
Its characters are compared to the stop_string character at a current_position
(position of a character in the stop_string counting from the last one) - at the beginning position is 0.
When characters match we increase current_position and check if we have a full match already, if not we continue.
If we have already matched some characters (current_position > 0) and next character is not matching
before we reach the full match, then we reset current_position to 0.
*/
std::string prefix = "a";
auto prefix_ov = tokenizer.encode(prefix).input_ids;
std::vector<int64_t> prefix_tokens(prefix_ov.data<int64_t>(), prefix_ov.data<int64_t>() + prefix_ov.get_size());
std::string suffix = "b";
auto suffix_ov = tokenizer.encode(suffix).input_ids;
std::vector<int64_t> suffix_tokens(suffix_ov.data<int64_t>(), suffix_ov.data<int64_t>() + suffix_ov.get_size());

// Since whitespace can be added at the beginning of the suffix we also try to capture that behavior here
// and get suffix string that will actually be part of the decoded string so we can remove it correctly
auto wrapped_suffix_tokens = suffix_tokens;
wrapped_suffix_tokens.insert(wrapped_suffix_tokens.begin(), prefix_tokens.begin(), prefix_tokens.end());
std::string wrapped_suffix = tokenizer.decode(wrapped_suffix_tokens);
auto wrapper_pos = wrapped_suffix.find(prefix);
suffix = wrapped_suffix.substr(wrapper_pos + prefix.size());

for (auto stop_string: stop_strings) {
int current_position = 0;
int num_matched_tokens = 0;
// Getting reverse iterator to check tokens starting from the last one generated and going backwards
auto generated_tokens_rit = generated_tokens.rbegin();
std::vector<int64_t> tokens_buffer;
while (generated_tokens_rit != generated_tokens.rend()) {
num_matched_tokens++;
tokens_buffer.insert(tokens_buffer.begin(), *generated_tokens_rit);

std::vector<int64_t> wrapped_tokens = wrap_tokens(tokens_buffer, prefix_tokens, suffix_tokens);
std::string wrapped_text = tokenizer.decode(wrapped_tokens);
std::string clean_text = clean_wrapped_text(wrapped_text, prefix, suffix);

if (clean_text == "" || (clean_text.size() >= 3 && (clean_text.compare(clean_text.size() - 3, 3, "�") == 0))) {
generated_tokens_rit++;
continue;
} else {
tokens_buffer.clear();
}
// Checking clean_text characters starting from the last one
for (auto clean_text_rit = clean_text.rbegin(); clean_text_rit != clean_text.rend(); clean_text_rit++) {
// On character match increment current_position for the next comparisons
if (*clean_text_rit == *(stop_string.rbegin() + current_position)) {
current_position++;
// If this is the last character from the stop_string we have a match
if ((stop_string.rbegin() + current_position) == stop_string.rend()) {
return num_matched_tokens;
}
} else if (current_position) {
// Already found matching characters, but the last one didn't match, so we reset current_position
current_position = 0;
// Looking for the match will start over from this character so we decrement iterator
clean_text_rit--;
}
// { stop_string, { encoded_stop_string_len, matched_tokens_cnt }}
std::map<std::string, std::pair<size_t, size_t>>
match_stop_string(const TokenIds & generated_tokens, const std::map<std::string, TokenIds>& stop_strings) {
std::map<std::string, std::pair<size_t, size_t>> matched_stop_strings;
for (const auto& stop_string : stop_strings) {
TokenIds encoded_stop_string = stop_string.second;
auto first_matched_token_it = std::find(generated_tokens.rbegin(), generated_tokens.rend(), encoded_stop_string.front());
size_t matched_token_cnt = 0;
if (first_matched_token_it != generated_tokens.rend()) {
auto dist = std::distance(generated_tokens.rbegin(), first_matched_token_it) + 1;

auto generated_pos_start = generated_tokens.begin() + (generated_tokens.size() - dist);
auto generated_pos_end = generated_tokens.end();
TokenIds substring_generated(generated_pos_start, generated_pos_end);

auto stop_string_pos_start = encoded_stop_string.begin();
auto stop_string_pos_end = stop_string_pos_start + dist;
TokenIds substring_stop_str(stop_string_pos_start, stop_string_pos_end);
if (substring_generated == substring_stop_str) {
matched_token_cnt = dist;
}
generated_tokens_rit++;
}
matched_stop_strings.insert({stop_string.first, { encoded_stop_string.size(), matched_token_cnt }});
}
return 0;
return matched_stop_strings;
}

// { is_matched_any_stop_string, { is_matched_any_stop_string ? encoded_stop_string_len : max_matched_len } }
std::pair<bool, size_t> is_stop_generation(const std::map<std::string, std::pair<size_t, size_t>>& match_stop_string_res) {
size_t max_matched_len = 0;
for (const auto& stop_string : match_stop_string_res) {
if (stop_string.second.first == stop_string.second.second) {
return { true, stop_string.second.second};
}
max_matched_len = std::max(max_matched_len, stop_string.second.second);
}
return { false, max_matched_len };
}
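Taken together, the two helpers above replace the old character-wise scan: match_stop_string now checks whether the tail of the generated ids equals a prefix of each pre-encoded stop string, and is_stop_generation reports either a full match (stop now, drop that many tokens) or the longest partial match (tokens to hold back from streaming). A self-contained sketch of the same idea follows; the token ids are hypothetical and, unlike the diff, the sketch adds an explicit bound check so a first-token hit farther back than the stop string's length is ignored.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

using TokenIds = std::vector<int64_t>;

// For every stop string: { encoded length, how many trailing generated tokens match its prefix }.
std::map<std::string, std::pair<size_t, size_t>>
match_stop_string_sketch(const TokenIds& generated, const std::map<std::string, TokenIds>& stop_strings) {
    std::map<std::string, std::pair<size_t, size_t>> result;
    for (const auto& entry : stop_strings) {
        const TokenIds& encoded = entry.second;
        size_t matched = 0;
        // Walk backwards to the most recent occurrence of the stop string's first token.
        auto it = std::find(generated.rbegin(), generated.rend(), encoded.front());
        if (it != generated.rend()) {
            size_t dist = static_cast<size_t>(std::distance(generated.rbegin(), it)) + 1;
            if (dist <= encoded.size() &&
                std::equal(generated.end() - static_cast<std::ptrdiff_t>(dist), generated.end(), encoded.begin())) {
                matched = dist;
            }
        }
        result.emplace(entry.first, std::make_pair(encoded.size(), matched));
    }
    return result;
}

// { any full match?, full-match length or longest partial-match length }.
std::pair<bool, size_t> is_stop_generation_sketch(const std::map<std::string, std::pair<size_t, size_t>>& matches) {
    size_t max_partial = 0;
    for (const auto& m : matches) {
        if (m.second.first == m.second.second)
            return {true, m.second.second};
        max_partial = std::max(max_partial, m.second.second);
    }
    return {false, max_partial};
}

int main() {
    // Hypothetical ids: suppose "</answer>" encodes to {5, 9, 7}.
    std::map<std::string, TokenIds> stop_strings{{"</answer>", {5, 9, 7}}};
    TokenIds generated{11, 3, 5, 9};  // ends with a partial match of length 2
    auto partial = is_stop_generation_sketch(match_stop_string_sketch(generated, stop_strings));
    std::cout << std::boolalpha << partial.first << ' ' << partial.second << '\n';  // false 2

    generated.push_back(7);           // the full stop string is now present
    auto full = is_stop_generation_sketch(match_stop_string_sketch(generated, stop_strings));
    std::cout << full.first << ' ' << full.second << '\n';                          // true 3
}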

// Return number of last tokens that match one of the stop_strings. If there's no match 0 is returned.
@@ -245,7 +213,7 @@ std::map<size_t, int32_t> Sampler::GroupBeamSearcher::get_beam_idxs() {
return next_beams;
}

void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output) {
void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output, const std::map<std::string, TokenIds>& stop_strings) {
assert(m_parameters.num_beams % m_parameters.num_beam_groups == 0 &&
"number of beams should be divisible by number of groups");
size_t group_size = m_parameters.num_beams / m_parameters.num_beam_groups;
@@ -387,25 +355,29 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, Sa
continue;
}

if (!m_parameters.stop_strings.empty()) {
if (!stop_strings.empty()) {
// We need to include candidate token to already generated tokens to check if stop string has been generated
// There's probably a better way to do that, than copying whole vector...
std::vector<int64_t> token_ids = candidate.m_sequence->get_generated_ids();
token_ids.push_back(candidate.m_token_id);
int num_last_matched_tokens = match_stop_string(m_tokenizer, token_ids, m_sequence_group->get_sampling_parameters().stop_strings);
if (num_last_matched_tokens) {
size_t num_last_matched_tokens;
bool is_matched;
std::tie(is_matched, num_last_matched_tokens) = is_stop_generation(match_stop_string(token_ids, stop_strings));
if (is_matched) {
// If beam_token does not belong to top num_beams tokens, it should not be added
if (cand_idx >= group_size)
continue;

if(!m_parameters.include_stop_str_in_output) {
if (!m_parameters.include_stop_str_in_output) {
// remove tokens that match stop_string from output (last token is not included in candidate.m_sequence at this point)
candidate.m_sequence->remove_last_tokens(num_last_matched_tokens - 1);
candidate.m_sequence->remove_last_tokens(num_last_matched_tokens);
}

// try to finish candidate
try_to_finish_candidate(group, candidate, m_parameters.include_stop_str_in_output);
continue;
} else {
// todo: iefode
}
}

@@ -575,14 +547,19 @@ std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
continue;
}

auto stop_strings = get_stop_strings(sequence_group);
if (!sampling_params.stop_strings.empty()) {
int num_matched_last_tokens = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), sampling_params.stop_strings);
if (num_matched_last_tokens) {
size_t num_matched_last_tokens;
bool is_matched;
std::tie(is_matched, num_matched_last_tokens) = is_stop_generation(match_stop_string(running_sequence->get_generated_ids(), stop_strings));
if (is_matched) {
if (!sampling_params.include_stop_str_in_output)
running_sequence->remove_last_tokens(num_matched_last_tokens);
running_sequence->set_status(SequenceStatus::FINISHED);
running_sequence->set_finish_reason(GenerationFinishReason::STOP);
dropped_seq_ids.push_back(running_sequence->get_id());
} else if (!sampling_params.include_stop_str_in_output) {
sequence_group->set_num_tokens_not_to_stream(num_matched_last_tokens);
}
}
}
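The branch above distinguishes a full stop-string match, which finishes the sequence and (unless include_stop_str_in_output is set) trims the matched tokens, from a partial match, which only withholds the matched tail from streaming via set_num_tokens_not_to_stream. A compact, purely illustrative sketch of that decision; the enum and helper names are not part of the pipeline.

#include <cstddef>
#include <iostream>
#include <utility>

enum class StopAction { Finish, HoldBackTail, StreamAll };

// Decide what to do after sampling one more token, given the stop-string match result
// (full_match, matched_len) and whether stop strings should appear in the output.
// This mirrors the branch above only in spirit; names are illustrative.
std::pair<StopAction, size_t> decide(bool full_match, size_t matched_len, bool include_stop_str_in_output) {
    if (full_match) {
        // Finish the sequence; trim the stop string unless it should be kept in the output.
        return {StopAction::Finish, include_stop_str_in_output ? 0 : matched_len};
    }
    if (matched_len > 0 && !include_stop_str_in_output) {
        // Partial match: keep generating, but don't stream the possibly-stop tokens yet.
        return {StopAction::HoldBackTail, matched_len};
    }
    return {StopAction::StreamAll, 0};
}

int main() {
    auto a = decide(false, 2, false);  // partial match of 2 tokens -> hold 2 tokens back
    auto b = decide(true, 3, false);   // full match -> finish and drop 3 tokens
    std::cout << static_cast<int>(a.first) << ' ' << a.second << '\n';  // 1 2
    std::cout << static_cast<int>(b.first) << ' ' << b.second << '\n';  // 0 3
}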
@@ -764,6 +741,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
if (!m_logit_processors.count(request_id)) {
m_logit_processors.insert({request_id, LogitProcessor(sampling_params, sequence_group->get_prompt_ids())});
}
auto stop_strings = get_stop_strings(sequence_group);
auto& logit_processor = m_logit_processors.at(request_id);
const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;
ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);
@@ -873,7 +851,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
}

// current algorithm already adds new tokens to running sequences and
m_beam_search_info.at(request_id).select_next_tokens(sequence_group_logits, sampler_output);
m_beam_search_info.at(request_id).select_next_tokens(sequence_group_logits, sampler_output, stop_strings);

// check max length stop criteria
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
@@ -886,8 +864,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
// Notify handle after sampling is done.
// For non-streaming this is effective only when the generation is finished.
OPENVINO_ASSERT(num_tokens_to_process >= max_removed_tokens_per_request);
size_t num_output_token_to_push = num_tokens_to_process - max_removed_tokens_per_request + 1;
sequence_group->notify_handle(num_output_token_to_push);
sequence_group->notify_handle();
} else {
// we are in prompt processing phase when prompt is split into chunks and processed step by step
}
@@ -918,6 +895,35 @@ LogitProcessor& Sampler::get_logit_processor(uint64_t request_id) {
return m_logit_processors.at(request_id);
}

TokenIds encode_and_process_stop_string(const std::string& stop_string, ov::genai::Tokenizer& tokenizer) {
// encode stop_string
ov::Tensor ov_encoded_stop_string = tokenizer.encode(stop_string).input_ids;
size_t tensor_size = ov_encoded_stop_string.get_size();
TokenIds source_encoded_stop_string(tensor_size), encoded_stop_string;
std::copy_n(ov_encoded_stop_string.data<int64_t>(), tensor_size, source_encoded_stop_string.begin());
// remove special symbols
for (const auto& token_id : source_encoded_stop_string) {
if (token_id != tokenizer.get_bos_token_id() &&
token_id != tokenizer.get_eos_token_id() &&
token_id != tokenizer.get_pad_token_id()) {
encoded_stop_string.push_back(token_id);
}
}
return encoded_stop_string;
}

std::map<std::string, TokenIds>& Sampler::get_stop_strings(const SequenceGroup::Ptr& sequence_group) {
const auto request_id = sequence_group->get_request_id();
if (!m_encoded_stop_strings.count(request_id)) {
std::map<std::string, TokenIds> stop_strings;
for (const auto& stop_string : sequence_group->get_sampling_parameters().stop_strings) {
stop_strings.insert({ stop_string, encode_and_process_stop_string(stop_string, m_tokenizer) });
}
m_encoded_stop_strings.insert({ request_id, stop_strings });
}
return m_encoded_stop_strings.at(request_id);
}
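get_stop_strings encodes each request's stop strings once, strips BOS/EOS/pad ids, and memoizes the result per request id, so the tokenizer is not re-run on every sampling step. The caching pattern in isolation, with a toy encoder standing in for the real Tokenizer:

#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using TokenIds = std::vector<int64_t>;

// Stand-in for Tokenizer::encode plus special-token filtering; real ids would come from the tokenizer.
TokenIds encode_stop_string_stub(const std::string& s) {
    TokenIds ids;
    for (unsigned char c : s) ids.push_back(static_cast<int64_t>(c));  // toy "encoding"
    return ids;
}

class StopStringCache {
public:
    // Encode a request's stop strings on first use, then reuse the cached map.
    const std::map<std::string, TokenIds>& get(uint64_t request_id, const std::set<std::string>& stop_strings) {
        auto it = m_cache.find(request_id);
        if (it == m_cache.end()) {
            std::map<std::string, TokenIds> encoded;
            for (const auto& s : stop_strings)
                encoded.emplace(s, encode_stop_string_stub(s));
            it = m_cache.emplace(request_id, std::move(encoded)).first;
        }
        return it->second;
    }

private:
    std::map<uint64_t, std::map<std::string, TokenIds>> m_cache;  // request_id -> { stop_string -> token ids }
};

int main() {
    StopStringCache cache;
    const auto& encoded = cache.get(/*request_id=*/7, {"\n\n", "</s>"});
    std::cout << "stop strings cached: " << encoded.size() << '\n';  // 2
}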


void Sampler::create_logit_processor(uint64_t request_id, const GenerationConfig& sampling_params, const TokenIds& prompt) {
m_logit_processors.insert({request_id, LogitProcessor(sampling_params, prompt)});
6 changes: 5 additions & 1 deletion src/cpp/src/sampler.hpp
@@ -50,13 +50,17 @@ class Sampler {

bool validate_candidate(Sequence::Ptr running_sequence, size_t& token_idx, Token& sampled_token,
bool& is_extend_sequence, size_t& max_removed_tokens, bool do_sample);
std::map<std::string, TokenIds>& get_stop_strings(const SequenceGroup::Ptr& sequence_group);


// request ID => beam search tracking information
std::map<uint64_t, GroupBeamSearcher> m_beam_search_info;

std::mt19937 rng_engine;
// { request_id, logit_processor }
std::map<uint64_t, LogitProcessor> m_logit_processors;
// { request_id, { stop_string, encoded_stop_strings }}
std::map<uint64_t, std::map<std::string, TokenIds>> m_encoded_stop_strings;

Tokenizer m_tokenizer;

@@ -115,7 +119,7 @@ class Sampler::GroupBeamSearcher {
public:
explicit GroupBeamSearcher(SequenceGroup::Ptr sequence_group, Tokenizer tokenizer);

void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output);
void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output, const std::map<std::string, TokenIds>& stop_strings);
void finalize(SamplerOutput& sampler_output);
std::map<size_t, int32_t> get_beam_idxs();
};
33 changes: 28 additions & 5 deletions src/cpp/src/sequence_group.hpp
@@ -221,6 +221,8 @@ class SequenceGroup {
// flag to enable/disable token generation, e.g. in speculative decoding scenario
bool m_is_gen_paused = false;

size_t m_num_streamed_tokens = 0, m_num_not_streamed_tokens = 0;


SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
: m_request_id(request_id),
@@ -390,6 +392,10 @@ class SequenceGroup {
return m_num_processed_tokens;
}

void set_num_tokens_not_to_stream(size_t k) {
m_num_not_streamed_tokens = k;
}

/**
* Registers within the sequence group that a given amount of tokens
* has been evicted from the underlying KV cache.
@@ -602,6 +608,12 @@ class SequenceGroup {
// todo: check seq.is_finished() to generate without several </s>
// or is it ok to use padding?
auto output = sequence->get_last_generation_output(token_cnt);
if (token_cnt >= m_num_not_streamed_tokens) {
for (size_t i = 0; i < m_num_not_streamed_tokens; ++i) {
output.generated_ids.pop_back();
output.generated_log_probs.pop_back();
}
}
if (m_sampling_params.echo && !m_has_echoed) {
output.generated_ids.insert(output.generated_ids.begin(), m_prompt_ids.begin(), m_prompt_ids.end());
output.generated_log_probs.insert(output.generated_log_probs.begin(), m_prompt_log_probs.begin(), m_prompt_log_probs.end());
@@ -612,7 +624,7 @@ class SequenceGroup {
m_generation_stream->push(std::move(outputs));
}

void notify_handle(size_t num_output_token_to_push = 0) {
void notify_handle() {
if (out_of_memory()) {
set_generation_status(GenerationStatus::IGNORED);
} else if (has_finished()) {
Expand All @@ -626,10 +638,21 @@ class SequenceGroup {
} else if (m_sampling_params.is_greedy_decoding() || m_sampling_params.is_multinomial()) {
// We can stream only when one sequence is returned and we don't use stop strings that would be excluded from the output
// (after stop string is detected its tokens are already sent)
if (num_total_seqs() == 1 &&
(m_sampling_params.stop_strings.empty() || m_sampling_params.include_stop_str_in_output)) {
if (num_output_token_to_push)
push_partial_outputs(num_output_token_to_push);
if (num_total_seqs() == 1) {
const auto generated_len = m_sequences.front()->get_generated_len();
// speculative decoding draft handling
if (generated_len < m_num_streamed_tokens) {
m_num_streamed_tokens = generated_len;
}
OPENVINO_ASSERT(generated_len >= m_num_streamed_tokens);
auto delta = generated_len - m_num_streamed_tokens;

size_t num_output_token_to_push = generated_len - m_num_streamed_tokens;
if (m_sampling_params.include_stop_str_in_output) {
m_num_not_streamed_tokens = 0;
}
push_partial_outputs(num_output_token_to_push);
m_num_streamed_tokens += (num_output_token_to_push - m_num_not_streamed_tokens);
} else if (has_finished() || out_of_memory()) {
push_outputs();
}
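In the streaming branch above, the number of tokens to push is now derived from the gap between the sequence's generated length and the per-group m_num_streamed_tokens counter; the counter is clamped when speculative decoding rolls back draft tokens, and tokens reserved by a partial stop-string match are not counted as streamed. A simplified model of that bookkeeping, assuming a single running sequence (the field names follow the diff, everything else is illustrative):

#include <algorithm>
#include <cstddef>
#include <iostream>

// Minimal model of the per-group streaming counters used in notify_handle().
struct StreamState {
    size_t num_streamed_tokens = 0;      // how many tokens were already pushed to the streamer
    size_t num_not_streamed_tokens = 0;  // tokens held back because of a partial stop-string match

    // Returns how many new tokens to push for a sequence whose generated length is generated_len.
    size_t step(size_t generated_len, bool include_stop_str_in_output) {
        // Speculative decoding may reject draft tokens, shrinking generated_len below the counter.
        num_streamed_tokens = std::min(num_streamed_tokens, generated_len);
        if (include_stop_str_in_output)
            num_not_streamed_tokens = 0;  // nothing needs to be withheld in this mode
        size_t to_push = generated_len - num_streamed_tokens;
        num_streamed_tokens += to_push - num_not_streamed_tokens;
        return to_push;
    }
};

int main() {
    StreamState s;
    std::cout << s.step(3, true) << '\n';  // step 1: 3 generated tokens -> push 3
    std::cout << s.step(5, true) << '\n';  // step 2: 2 more tokens -> push 2
    std::cout << s.step(4, true) << '\n';  // step 3: rollback to 4 tokens -> counter clamped, push 0
}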