refactor
mzegla committed Nov 20, 2024
1 parent 7d5dfb3 commit 52d391a
Showing 3 changed files with 63 additions and 71 deletions.
63 changes: 28 additions & 35 deletions src/cpp/src/sampler.cpp
@@ -741,12 +741,11 @@ float get_p_prime(Sequence::Ptr& running_sequence,
return p_prime;
}

std::tuple<SamplerOutput, size_t, size_t> Sampler::sample_from_sequence_group(SequenceGroup::Ptr sequence_group,
ov::Tensor sequence_group_logits,
LogitProcessor& logit_processor,
bool is_validation_mode_enabled) {
SamplerOutput sampler_output;
size_t max_removed_tokens_per_request = 0, min_generated_len = std::numeric_limits<size_t>::max();
SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr sequence_group,
ov::Tensor sequence_group_logits,
LogitProcessor& logit_processor,
bool is_validation_mode_enabled) {
SequenceGroupSamplingInfo sampling_info;
auto num_running_sequences = sequence_group->num_running_seqs();
auto sampling_params = sequence_group->get_sampling_parameters();
// get number of tokens to be validated
@@ -769,7 +768,7 @@ std::tuple<SamplerOutput, size_t, size_t> Sampler::sample_from_sequence_group(Se
OPENVINO_ASSERT(sampling_params.max_new_tokens >= generated_and_verified_len);
size_t max_num_sampled_token = sampling_params.max_new_tokens - generated_and_verified_len;
if (max_num_sampled_token == 0) {
stop_sample_tokens(running_sequence, token_offset, max_num_sampled_token, max_removed_tokens_per_request);
stop_sample_tokens(running_sequence, token_offset, max_num_sampled_token, sampling_info.max_removed_tokens_per_request);
break;
}

@@ -797,13 +796,13 @@ std::tuple<SamplerOutput, size_t, size_t> Sampler::sample_from_sequence_group(Se
// to create n sequence just in case of `sequence_group->num_total_seqs() == 1` and `sampling_params.num_return_sequences > 1`
if (is_generate_n_tokens) {
const auto forked_seq_ids = create_n_forked_sequences(sequence_group, logit_processor, sampled_token_ids);
sampler_output.m_forked_sequences.insert({running_sequences[0]->get_id(), forked_seq_ids});
sampling_info.sampler_output.m_forked_sequences.insert({running_sequences[0]->get_id(), forked_seq_ids});
}
sampled_token = sampled_token_ids.front();
// make `_speculative_sampling` in case the previous token was not accepted in speculative decoding
if (!is_validation_passed) {
float p_prime = get_p_prime(running_sequence, sampled_token, token_offset + 1);
max_removed_tokens_per_request = std::max(max_removed_tokens_per_request, token_offset);
sampling_info.max_removed_tokens_per_request = std::max(sampling_info.max_removed_tokens_per_request, token_offset);
// update prob only in case candidate prob > sampled token prob
if (p_prime > 0.f) {
auto prob = std::exp(sampled_token.m_log_prob);
@@ -816,7 +815,7 @@ std::tuple<SamplerOutput, size_t, size_t> Sampler::sample_from_sequence_group(Se
bool is_extend_sequence = token_offset == 0 || is_generate_n_tokens || !is_validation_passed;
if (is_validation_mode_enabled && !is_extend_sequence) {
is_validation_passed = validate_candidate(running_sequences[running_sequence_id], token_offset, sampled_token,
is_extend_sequence, max_removed_tokens_per_request, sampling_params.do_sample);
is_extend_sequence, sampling_info.max_removed_tokens_per_request, sampling_params.do_sample);
// doing resample in case of non accepted tokens in speculative sampling
if (!is_validation_passed && sampling_params.do_sample) {
continue;
@@ -833,11 +832,11 @@ std::tuple<SamplerOutput, size_t, size_t> Sampler::sample_from_sequence_group(Se
break;
}
}
min_generated_len = std::min(min_generated_len, running_sequence->get_generated_len());
sampling_info.min_generated_len = std::min(sampling_info.min_generated_len, running_sequence->get_generated_len());
}
align_all_sequence_len(sequence_group, min_generated_len, logit_processor);
align_all_sequence_len(sequence_group, sampling_info.min_generated_len, logit_processor);
for (const auto& dropped_seq_id : _try_finish_generation(sequence_group)) {
sampler_output.m_dropped_sequences.push_back(dropped_seq_id);
sampling_info.sampler_output.m_dropped_sequences.push_back(dropped_seq_id);
}
} else if (sampling_params.is_beam_search()) {
uint64_t request_id = sequence_group->get_request_id();
@@ -853,23 +852,23 @@ std::tuple<SamplerOutput, size_t, size_t> Sampler::sample_from_sequence_group(Se
}

// current algorithm already adds new tokens to running sequences and
beam_searcher->select_next_tokens(sequence_group_logits, sampler_output);
beam_searcher->select_next_tokens(sequence_group_logits, sampling_info.sampler_output);

// check max length stop criteria
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
if (!sequence_group->has_finished() &&
running_sequences[0]->get_generated_len() == sampling_params.max_new_tokens) {
// stop sequence by max_new_tokens
beam_searcher->finalize(sampler_output);
beam_searcher->finalize(sampling_info.sampler_output);
}
}
// Notify handle after sampling is done.
// For non-streaming this is effective only when the generation is finished.
OPENVINO_ASSERT(num_tokens_to_process >= max_removed_tokens_per_request);
size_t num_output_token_to_push = num_tokens_to_process - max_removed_tokens_per_request + 1;
OPENVINO_ASSERT(num_tokens_to_process >= sampling_info.max_removed_tokens_per_request);
size_t num_output_token_to_push = num_tokens_to_process - sampling_info.max_removed_tokens_per_request + 1;
sequence_group->notify_handle(num_output_token_to_push);

return std::make_tuple(sampler_output, min_generated_len, max_removed_tokens_per_request);
return sampling_info;
}

SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
@@ -881,7 +880,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
size_t batch_seq_len = logits_shape[1], vocab_size = logits_shape[2];

SamplerOutput sampler_output;
std::unordered_map<SequenceGroup::Ptr, std::future<std::tuple<SamplerOutput, size_t, size_t>>> future_map;
std::unordered_map<SequenceGroup::Ptr, std::future<SequenceGroupSamplingInfo>> future_map;
for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
if (!sequence_group->is_scheduled())
@@ -898,36 +897,30 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
m_logit_processors.insert({request_id, LogitProcessor(sampling_params, sequence_group->get_prompt_ids())});
}

//std::cout << "\nSequence group ID: " << sequence_group_id << std::endl;
//std::cout << "Sequence group data valid capacity: " << logits.get_size() << std::endl;
//std::cout << "Sequence group data offset: " << vocab_size * currently_processed_tokens << std::endl;

const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;
ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);
//std::cout << "Sequence group logits tensor size: " << sequence_group_logits.get_size() << std::endl;

// Call sample_from_sequence_group asynchronously
//future_map[sequence_group] = std::async(std::launch::async, &Sampler::sample_from_sequence_group, this, sequence_group, sequence_group_logits, is_validation_mode_enabled);
future_map[sequence_group] = m_thread_pool.enqueue(&Sampler::sample_from_sequence_group, this, sequence_group, sequence_group_logits,
m_logit_processors.at(sequence_group->get_request_id()), is_validation_mode_enabled);
future_map[sequence_group] = m_thread_pool.submit(&Sampler::sample_from_sequence_group, this, sequence_group, sequence_group_logits,
m_logit_processors.at(sequence_group->get_request_id()), is_validation_mode_enabled);
}
// accumulate a number of processed tokens
currently_processed_tokens += padded_amount_of_processed_tokens * num_running_sequences;
}

// Iterate over sequence groups and check if future_map contains the key
for (auto& sequence_group : sequence_groups) {
if (future_map.find(sequence_group) != future_map.end()) {
auto [sequence_group_sampler_output, min_generated_len, max_removed_tokens_per_request] = future_map[sequence_group].get();
// If there is a future assigned to a sequence group, we read its result (blocking if the result is not available yet)
auto sequence_group_sampling_info = future_map[sequence_group].get();

// Merge sequence_group_sampler_output into sampler_output
// Merge sampler output from sequence group to the main one
sampler_output.m_dropped_sequences.insert(
sampler_output.m_dropped_sequences.end(),
sequence_group_sampler_output.m_dropped_sequences.begin(),
sequence_group_sampler_output.m_dropped_sequences.end()
sequence_group_sampling_info.sampler_output.m_dropped_sequences.begin(),
sequence_group_sampling_info.sampler_output.m_dropped_sequences.end()
);

for (const auto& forked_seq : sequence_group_sampler_output.m_forked_sequences) {
for (const auto& forked_seq : sequence_group_sampling_info.sampler_output.m_forked_sequences) {
sampler_output.m_forked_sequences[forked_seq.first].insert(
sampler_output.m_forked_sequences[forked_seq.first].end(),
forked_seq.second.begin(),
Expand All @@ -939,8 +932,8 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
// update internal state of sequence group to reset scheduler tokens and update currently processed ones
sequence_group->finish_iteration();
// decrease sequence_group context in case candidates generated by draft_model were not accepted by main_model
if (max_removed_tokens_per_request) {
auto min_processed_tokens = sequence_group->get_prompt_len() + min_generated_len - 1;
if (sequence_group_sampling_info.max_removed_tokens_per_request) {
auto min_processed_tokens = sequence_group->get_prompt_len() + sequence_group_sampling_info.min_generated_len - 1;
sequence_group->update_processed_tokens_num(min_processed_tokens);
auto& logit_processor = m_logit_processors.at(sequence_group->get_request_id());
logit_processor.update_generated_len(min_processed_tokens);
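For context, the pattern used above (submit one task per sequence group, then block on each future and merge its SequenceGroupSamplingInfo into the overall SamplerOutput) can be sketched standalone as follows. GroupResult and sample_one_group are hypothetical simplifications, and std::async stands in for the ThreadPool so the snippet compiles on its own; it is not part of the commit.

// Minimal sketch of the fan-out/merge pattern used by Sampler::sample (simplified).
#include <cstddef>
#include <future>
#include <limits>
#include <list>
#include <vector>

struct GroupResult {                       // simplified stand-in for SequenceGroupSamplingInfo
    std::list<int> dropped;                // stand-in for SamplerOutput::m_dropped_sequences
    size_t max_removed_tokens = 0;
    size_t min_generated_len = std::numeric_limits<size_t>::max();
};

GroupResult sample_one_group(int group_id) {
    // Placeholder for the per-group sampling work that runs on a worker thread.
    GroupResult result;
    result.min_generated_len = static_cast<size_t>(group_id);   // dummy value
    return result;
}

int main() {
    std::vector<std::future<GroupResult>> futures;
    for (int group = 0; group < 4; ++group)                     // fan out: one task per group
        futures.push_back(std::async(std::launch::async, sample_one_group, group));

    std::list<int> merged_dropped;
    for (auto& future : futures) {                              // get() blocks until the task has finished
        GroupResult result = future.get();
        merged_dropped.splice(merged_dropped.end(), result.dropped);
    }
    return static_cast<int>(merged_dropped.size());
}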
10 changes: 8 additions & 2 deletions src/cpp/src/sampler.hpp
@@ -41,6 +41,12 @@ struct SamplerOutput {
std::unordered_map<uint64_t, std::list<uint64_t>> m_forked_sequences;
};

struct SequenceGroupSamplingInfo {
SamplerOutput sampler_output;
size_t max_removed_tokens_per_request = 0;
size_t min_generated_len = std::numeric_limits<size_t>::max();
};

class Sampler {
class GroupBeamSearcher;

@@ -68,8 +74,8 @@ class Sampler {
Sampler() = default;
Sampler(Tokenizer & tokenizer) : m_tokenizer(tokenizer) {};

std::tuple<SamplerOutput, size_t, size_t> sample_from_sequence_group(SequenceGroup::Ptr sequence_group, ov::Tensor sequence_group_logits,
LogitProcessor& logit_processor, bool is_validation_mode_enabled = false);
SequenceGroupSamplingInfo sample_from_sequence_group(SequenceGroup::Ptr sequence_group, ov::Tensor sequence_group_logits,
LogitProcessor& logit_processor, bool is_validation_mode_enabled = false);
SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
void set_seed(size_t seed) { rng_engine.seed(seed); }

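Returning the new SequenceGroupSamplingInfo struct instead of a std::tuple<SamplerOutput, size_t, size_t> gives the results names at the call site and lets the type grow without breaking callers. A small illustration of the difference, using hypothetical compute_* functions rather than the actual sampler API:

#include <cstddef>
#include <limits>
#include <tuple>

struct Info {                                    // analogous to SequenceGroupSamplingInfo
    size_t max_removed_tokens = 0;
    size_t min_generated_len = std::numeric_limits<size_t>::max();
};

// Before: positional results, the caller has to remember the order.
std::tuple<size_t, size_t> compute_tuple() { return {0, 5}; }

// After: named fields document themselves.
Info compute_struct() { return {0, 5}; }

int main() {
    auto [removed, generated] = compute_tuple();   // which value is which is not obvious here
    Info info = compute_struct();                  // info.min_generated_len is explicit
    return static_cast<int>(removed + generated + info.min_generated_len);
}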
61 changes: 27 additions & 34 deletions src/cpp/src/threadpool.hpp
@@ -6,72 +6,65 @@
#include <queue>
#include <thread>
#include <utility>
using namespace std;

// Class that represents a simple thread pool
class ThreadPool {

private:
vector<thread> threads_;
queue<function<void()>> tasks_;
mutex queue_mutex_;
condition_variable cv_;
bool stop_ = false;
std::vector<std::thread> threads;
std::queue<std::function<void()>> tasks;
std::mutex queue_mutex;
std::condition_variable cv;
bool stop = false;

public:
// Constructor to create a thread pool with given
// number of threads
ThreadPool(size_t num_threads = thread::hardware_concurrency())
ThreadPool(size_t num_threads = std::thread::hardware_concurrency())
{
// Creating worker threads
for (size_t i = 0; i < num_threads; ++i) {
threads_.emplace_back([this] {
threads.emplace_back([this] {
while (true) {
function<void()> task;
std::function<void()> task;
{
unique_lock<mutex> lock(queue_mutex_);
cv_.wait(lock, [this] {
return !tasks_.empty() || stop_;
std::unique_lock<std::mutex> lock(queue_mutex);
cv.wait(lock, [this] {
return !tasks.empty() || stop;
});
if (stop_ && tasks_.empty()) {
if (stop && tasks.empty()) {
return;
}
task = move(tasks_.front());
tasks_.pop();
task = std::move(tasks.front());
tasks.pop();
}
task();
}
});
}
}

// Destructor to stop the thread pool
~ThreadPool()
{
{
unique_lock<mutex> lock(queue_mutex_);
stop_ = true;
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
cv_.notify_all();
for (auto& thread : threads_) {
cv.notify_all();
for (auto& thread : threads) {
thread.join();
}
}

// Enqueue task for execution by the thread pool
template <typename F, typename... Args>
auto enqueue(F&& f, Args&&... args) -> future<result_of_t<F(Args...)>>
auto submit(F&& f, Args&&... args) -> std::future<std::invoke_result_t<F, Args...>>
{
using return_type = invoke_result_t<F, Args...>;
auto task = make_shared<packaged_task<return_type()>>(
bind(forward<F>(f), forward<Args>(args)...)
using return_type = std::invoke_result_t<F, Args...>;
auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
);
future<return_type> res = task->get_future();
std::future<return_type> result = task->get_future();
{
unique_lock<mutex> lock(queue_mutex_);
tasks_.emplace([task]() { (*task)(); });
std::unique_lock<std::mutex> lock(queue_mutex);
tasks.emplace([task]() { (*task)(); });
}
cv_.notify_one();
return res;
cv.notify_one();
return result;
}
};
