use tbb instead of threadpool
mzegla committed Nov 20, 2024
1 parent 52d391a · commit 59a4e6d
Showing 2 changed files with 71 additions and 50 deletions.
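At its core, the commit swaps the sampler's one-future-per-sequence-group fan-out (submitted to a hand-rolled thread pool and joined in order) for a tbb::parallel_for over the sequence groups, with results merged under a tbb::spin_mutex. The sketch below contrasts the two patterns in isolation; it is hypothetical code, with process_group as an invented stand-in for per-group sampling and std::async standing in for the repository's removed thread pool.

// before_after_tbb.cpp -- hypothetical, self-contained; process_group is an
// invented stand-in, and std::async stands in for the removed thread pool.
#include <future>
#include <iostream>
#include <numeric>
#include <vector>

#include "oneapi/tbb.h"

static int process_group(int g) { return g * g; }  // per-group "sampling"

int main() {
    std::vector<int> groups(8);
    std::iota(groups.begin(), groups.end(), 0);

    // Before: one future per work item, joined in submission order.
    std::vector<std::future<int>> futures;
    for (int g : groups)
        futures.push_back(std::async(std::launch::async, process_group, g));
    int before = 0;
    for (auto& f : futures)
        before += f.get();

    // After: TBB partitions the index range; results merge under a spin lock.
    tbb::spin_mutex mutex;
    int after = 0;
    tbb::parallel_for(tbb::blocked_range<size_t>(0, groups.size()),
                      [&](const tbb::blocked_range<size_t>& r) {
        for (size_t i = r.begin(); i != r.end(); ++i) {
            int local = process_group(groups[i]);      // work done outside the lock
            tbb::spin_mutex::scoped_lock lock(mutex);  // short critical section
            after += local;
        }
    });

    std::cout << before << " == " << after << '\n';   // both print 140
    return before == after ? 0 : 1;
}

With parallel_for, TBB chooses how to partition the index range (blocked_range) and reuses its own worker pool, instead of paying allocation and synchronization costs for one future per item.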
src/cpp/CMakeLists.txt (3 additions, 1 deletion)

@@ -65,13 +65,15 @@ if(TARGET openvino_tokenizers)
 endif()
 add_library(openvino::genai ALIAS ${TARGET_NAME})
 
+find_package(TBB REQUIRED)
+
 target_include_directories(${TARGET_NAME}
     PUBLIC "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>" "$<INSTALL_INTERFACE:runtime/include>"
     PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src")
 
 target_include_directories(${TARGET_NAME} SYSTEM PRIVATE "${safetensors.h_SOURCE_DIR}")
 
-target_link_libraries(${TARGET_NAME} PUBLIC openvino::runtime PRIVATE openvino::threading nlohmann_json::nlohmann_json jinja2cpp)
+target_link_libraries(${TARGET_NAME} PUBLIC openvino::runtime TBB::tbb PRIVATE openvino::threading nlohmann_json::nlohmann_json jinja2cpp)
 
 target_compile_features(${TARGET_NAME} PUBLIC cxx_std_17)
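find_package(TBB REQUIRED) resolves an installed oneTBB and imports the TBB::tbb target; placing it in the PUBLIC section of target_link_libraries propagates the headers and link dependency to anything that links openvino::genai. A minimal smoke test that the imported target works (hypothetical file, not part of the repository):

// tbb_smoke_test.cpp -- hypothetical; verifies that the TBB::tbb target
// imported by find_package(TBB REQUIRED) compiles, links, and runs.
#include <functional>
#include <iostream>

#include "oneapi/tbb.h"

int main() {
    // parallel_reduce sums 0..999 across TBB worker threads.
    long sum = tbb::parallel_reduce(
        tbb::blocked_range<long>(0, 1000), 0L,
        [](const tbb::blocked_range<long>& r, long acc) {
            for (long i = r.begin(); i != r.end(); ++i)
                acc += i;
            return acc;
        },
        std::plus<long>());
    std::cout << "sum = " << sum << '\n';  // expect 499500
    return sum == 499500 ? 0 : 1;
}

A downstream project would consume the library the same way: find_package(TBB REQUIRED) followed by, e.g., target_link_libraries(my_app PRIVATE TBB::tbb).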
src/cpp/src/sampler.cpp (68 additions, 49 deletions)

@@ -2,6 +2,7 @@
 // SPDX-License-Identifier: Apache-2.0
 
 #include <future>
+#include "oneapi/tbb.h"
 #include "sampler.hpp"
 
 namespace ov::genai {
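The single umbrella header oneapi/tbb.h provides everything the rewritten sampler uses: tbb::parallel_for, tbb::blocked_range, and tbb::spin_mutex. Before the main hunk, here is the merge-under-spin-lock idiom it adopts, reduced to a hypothetical toy:

// merge_under_spinlock.cpp -- hypothetical toy mirroring the sampler's merge
// step; 'merged' plays the role of sampler_output.
#include <cstddef>
#include <vector>

#include "oneapi/tbb.h"

int main() {
    std::vector<int> merged;          // shared result, like sampler_output
    tbb::spin_mutex merged_mutex;

    tbb::parallel_for(tbb::blocked_range<size_t>(0, 100),
                      [&](const tbb::blocked_range<size_t>& r) {
        for (size_t i = r.begin(); i != r.end(); ++i) {
            // Do the real work lock-free on thread-local data...
            std::vector<int> local{static_cast<int>(i), static_cast<int>(2 * i)};
            // ...then splice it into the shared vector under a short spin lock.
            tbb::spin_mutex::scoped_lock lock(merged_mutex);
            merged.insert(merged.end(), local.begin(), local.end());
        }
    });

    return merged.size() == 200 ? 0 : 1;  // 100 iterations x 2 elements each
}

tbb::spin_mutex busy-waits rather than putting the thread to sleep, a sensible trade when the critical section is just a few vector splices, as in the sampler's merge; for long critical sections a blocking std::mutex is usually the better default.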
@@ -880,69 +881,87 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
     size_t batch_seq_len = logits_shape[1], vocab_size = logits_shape[2];
 
     SamplerOutput sampler_output;
-    std::unordered_map<SequenceGroup::Ptr, std::future<SequenceGroupSamplingInfo>> future_map;
-    for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
+    //std::mutex sampler_output_mutex;
+    tbb::spin_mutex sampler_output_mutex;
+    std::unordered_map<size_t, size_t> sequence_group_offsets;
+
+    // First sequential pass to collect metadata and prepare for parallel processing
+    size_t last_request_id = 0;
+    for (size_t sequence_group_id = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
         SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
         if (!sequence_group->is_scheduled())
             continue;
 
         size_t num_running_sequences = sequence_group->num_running_seqs();
         size_t actual_seq_len = sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
         size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
 
         if (sequence_group->requires_sampling()) {
-            const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();
-            const auto request_id = sequence_group->get_request_id();
-            if (!m_logit_processors.count(request_id)) {
-                m_logit_processors.insert({request_id, LogitProcessor(sampling_params, sequence_group->get_prompt_ids())});
-            }
-
-            const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;
-            ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);
-
-            // Call sample_from_sequence_group asynchronously
-            future_map[sequence_group] = m_thread_pool.submit(&Sampler::sample_from_sequence_group, this, sequence_group, sequence_group_logits,
-                m_logit_processors.at(sequence_group->get_request_id()), is_validation_mode_enabled);
+            const auto request_id = sequence_group->get_request_id();
+            if (sequence_group_id == 0) {
+                sequence_group_offsets[request_id] = 0;
+                last_request_id = request_id;
+            } else {
+                sequence_group_offsets[request_id] = sequence_group_offsets[last_request_id] + padded_amount_of_processed_tokens * num_running_sequences;
+                last_request_id = request_id;
+            }
+            if (!m_logit_processors.count(request_id)) {
+                const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();
+                m_logit_processors.insert({request_id, LogitProcessor(sampling_params, sequence_group->get_prompt_ids())});
+            }
         }
-        // accumulate a number of processed tokens
-        currently_processed_tokens += padded_amount_of_processed_tokens * num_running_sequences;
     }
 
-    for (auto& sequence_group : sequence_groups) {
-        if (future_map.find(sequence_group) != future_map.end()) {
-            // If there is a future assigned to a sequence group we read it's result (blocking if results not available yet)
-            auto sequence_group_sampling_info = future_map[sequence_group].get();
-
-            // Merge sampler output from sequence group to the main one
-            sampler_output.m_dropped_sequences.insert(
-                sampler_output.m_dropped_sequences.end(),
-                sequence_group_sampling_info.sampler_output.m_dropped_sequences.begin(),
-                sequence_group_sampling_info.sampler_output.m_dropped_sequences.end()
-            );
-
-            for (const auto& forked_seq : sequence_group_sampling_info.sampler_output.m_forked_sequences) {
-                sampler_output.m_forked_sequences[forked_seq.first].insert(
-                    sampler_output.m_forked_sequences[forked_seq.first].end(),
-                    forked_seq.second.begin(),
-                    forked_seq.second.end()
-                );
-            }
-
-            // NOTE: it should be before 'get_num_scheduled_tokens' is used
-            // update internal state of sequence group to reset scheduler tokens and update currently processed ones
-            sequence_group->finish_iteration();
-            // decrease sequence_group context in case of candidates generated by draft_model were not accepted by main_model
-            if (sequence_group_sampling_info.max_removed_tokens_per_request) {
-                auto min_processed_tokens = sequence_group->get_prompt_len() + sequence_group_sampling_info.min_generated_len - 1;
-                sequence_group->update_processed_tokens_num(min_processed_tokens);
-                auto& logit_processor = m_logit_processors.at(sequence_group->get_request_id());
-                logit_processor.update_generated_len(min_processed_tokens);
-            }
-        } else {
-            // update internal state of sequence group to reset scheduler tokens and update currently processed ones
-            sequence_group->finish_iteration();
-        }
-    }
+    // Parallel sampling execution
+    tbb::parallel_for(tbb::blocked_range<size_t>(0, sequence_groups.size()), [&](const tbb::blocked_range<size_t>& r) {
+        for (size_t sequence_group_id = r.begin(); sequence_group_id != r.end(); ++sequence_group_id) {
+            SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
+            if (!sequence_group->is_scheduled())
+                continue;
+
+            size_t num_running_sequences = sequence_group->num_running_seqs();
+            size_t actual_seq_len = sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
+
+            if (sequence_group->requires_sampling()) {
+                const auto request_id = sequence_group->get_request_id();
+                const void * sequence_group_logits_data = logits_data + vocab_size * sequence_group_offsets[request_id];
+                ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);
+
+                // Call sample_from_sequence_group synchronously
+                auto sequence_group_sampling_info = sample_from_sequence_group(sequence_group, sequence_group_logits,
+                    m_logit_processors.at(request_id), is_validation_mode_enabled);
+
+                // Merge sampler output from sequence group to the main one
+                {
+                    tbb::spin_mutex::scoped_lock lock(sampler_output_mutex);
+                    sampler_output.m_dropped_sequences.insert(
+                        sampler_output.m_dropped_sequences.end(),
+                        sequence_group_sampling_info.sampler_output.m_dropped_sequences.begin(),
+                        sequence_group_sampling_info.sampler_output.m_dropped_sequences.end()
+                    );
+
+                    for (const auto& forked_seq : sequence_group_sampling_info.sampler_output.m_forked_sequences) {
+                        sampler_output.m_forked_sequences[forked_seq.first].insert(
+                            sampler_output.m_forked_sequences[forked_seq.first].end(),
+                            forked_seq.second.begin(),
+                            forked_seq.second.end()
+                        );
+                    }
+                }
+
+                // NOTE: it should be before 'get_num_scheduled_tokens' is used
+                // update internal state of sequence group to reset scheduler tokens and update currently processed ones
+                sequence_group->finish_iteration();
+                // decrease sequence_group context in case of candidates generated by draft_model were not accepted by main_model
+                if (sequence_group_sampling_info.max_removed_tokens_per_request) {
+                    auto min_processed_tokens = sequence_group->get_prompt_len() + sequence_group_sampling_info.min_generated_len - 1;
+                    sequence_group->update_processed_tokens_num(min_processed_tokens);
+                    auto& logit_processor = m_logit_processors.at(sequence_group->get_request_id());
+                    logit_processor.update_generated_len(min_processed_tokens);
+                }
+            } else {
+                // update internal state of sequence group to reset scheduler tokens and update currently processed ones
+                sequence_group->finish_iteration();
+            }
+        }
+    });
 
     return sampler_output;
 }
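The structural idea in the hunk above is worth isolating: a sequential first pass assigns every request an exclusive offset into the shared logits buffer (sequence_group_offsets), so the parallel second pass can build disjoint tensor views without any locking on the buffer itself; only the merge into sampler_output is serialized. A self-contained sketch of that layout trick, with invented names and sizes:

// two_pass_offsets.cpp -- hypothetical sketch of the commit's two-pass
// structure: a sequential pass computes disjoint offsets into one flat
// buffer (like sequence_group_offsets into the logits tensor), then a
// parallel pass touches only its own slice, so the buffer needs no lock.
#include <cstddef>
#include <vector>

#include "oneapi/tbb.h"

int main() {
    const std::vector<size_t> item_sizes = {3, 1, 4, 1, 5};  // invented sizes

    // Pass 1 (sequential): exclusive prefix sum over item sizes.
    std::vector<size_t> offsets(item_sizes.size());
    size_t total = 0;
    for (size_t i = 0; i < item_sizes.size(); ++i) {
        offsets[i] = total;
        total += item_sizes[i];
    }

    std::vector<float> buffer(total);  // stands in for the batched logits

    // Pass 2 (parallel): item i writes only buffer[offsets[i] .. offsets[i] + size).
    tbb::parallel_for(tbb::blocked_range<size_t>(0, item_sizes.size()),
                      [&](const tbb::blocked_range<size_t>& r) {
        for (size_t i = r.begin(); i != r.end(); ++i)
            for (size_t j = 0; j < item_sizes[i]; ++j)
                buffer[offsets[i] + j] = static_cast<float>(i);
    });

    return buffer[offsets[4]] == 4.0f ? 0 : 1;
}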
