Add sampling to VLM pipeline via Sampler
sbalandi committed Oct 8, 2024
1 parent 68e9870 commit 40b4924
Showing 4 changed files with 222 additions and 56 deletions.
11 changes: 9 additions & 2 deletions samples/cpp/visual_language_chat/visual_language_chat.cpp
@@ -31,13 +31,20 @@ int main(int argc, char* argv[]) try {
}
pipe.generate(
prompt,
ov::genai::image(std::move(image)),
// ov::genai::image(std::move(image)),
ov::genai::generation_config(ov::genai::beam_search()),
// ov::genai::generation_config(ov::genai::greedy()),
// ov::genai::generation_config(ov::genai::multinomial()),
ov::genai::streamer(print_subword)
);
std::cout << "\n----------\n"
"question:\n";
while (std::getline(std::cin, prompt)) {
pipe.generate(prompt, ov::genai::streamer(print_subword));
pipe.generate(prompt,
ov::genai::generation_config(ov::genai::beam_search()),
// ov::genai::generation_config(ov::genai::greedy()),
// ov::genai::generation_config(ov::genai::multinomial()),
ov::genai::streamer(print_subword));
std::cout << "\n----------\n"
"question:\n";
}
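For context, here is one way a sample could pick between the three decoding presets at runtime instead of commenting lines in and out. This is a hypothetical helper, not part of the commit; it assumes only the beam_search(), greedy(), and multinomial() preset functions used above:

#include <string>
#include "openvino/genai/generation_config.hpp"

// Hypothetical helper: map a mode string to one of the presets exercised above.
static ov::genai::GenerationConfig config_for_mode(const std::string& mode) {
    if (mode == "beam")        return ov::genai::beam_search();
    if (mode == "multinomial") return ov::genai::multinomial();
    return ov::genai::greedy();  // default to greedy decoding
}

// Usage inside the chat loop:
//   pipe.generate(prompt,
//                 ov::genai::generation_config(config_for_mode(mode)),
//                 ov::genai::streamer(print_subword));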
26 changes: 26 additions & 0 deletions src/cpp/src/sampler.cpp
@@ -230,6 +230,22 @@ Sampler::GroupBeamSearcher::GroupBeamSearcher(SequenceGroup::Ptr sequence_group,
}
}


std::vector<int32_t> Sampler::GroupBeamSearcher::get_beam_idxs() {
std::vector<int32_t> next_beams;

for (Group& group : m_groups) {
if (!group.done) {
for (Beam& beam : group.ongoing) {
next_beams.push_back(beam.m_global_beam_idx);
}
}
}

return next_beams;
}


void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output) {
assert(m_parameters.num_beams % m_parameters.num_beam_groups == 0 &&
"number of beams should be divisible by number of groups");
@@ -561,6 +577,16 @@ std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
}


std::vector<int32_t> Sampler::get_beam_idxs(uint64_t request_id) {
    std::vector<int32_t> beams;
    auto it = m_beam_search_info.find(request_id);
    if (it != m_beam_search_info.end()) {
        beams = it->second.get_beam_idxs();  // use the searcher stored in the map; no copy
    }
    return beams;
}


SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits) {
const float * logits_data = logits.data<float>();
ov::Shape logits_shape = logits.get_shape();
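The new get_beam_idxs entry points report, for each running beam of a request, the global index of the parent beam it was forked from in the previous step; a stateful language model consumes those indices through its beam_idx input to reorder KV-cache rows. A minimal sketch of that hand-off, assuming a Sampler, the language-model ov::InferRequest, and a request id as in the pipeline change below:

#include <algorithm>
#include <cstdint>
#include <vector>
#include <openvino/runtime/infer_request.hpp>
#include "sampler.hpp"

// Sketch: feed the sampler's parent-beam indices into the model's beam_idx input.
void apply_beam_reorder(Sampler& sampler, ov::InferRequest& language,
                        uint64_t request_id, size_t num_sequences) {
    std::vector<int32_t> idxs = sampler.get_beam_idxs(request_id);
    if (idxs.empty())
        idxs.assign(num_sequences, 0);  // greedy/multinomial: keep cache rows in place
    ov::Tensor beam_idx(ov::element::i32, {idxs.size()});
    std::copy(idxs.begin(), idxs.end(), beam_idx.data<int32_t>());
    language.set_tensor("beam_idx", beam_idx);  // reorders the KV cache before the next infer()
}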
2 changes: 2 additions & 0 deletions src/cpp/src/sampler.hpp
@@ -61,6 +61,7 @@ class Sampler {
SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits);
void set_seed(size_t seed) { rng_engine.seed(seed); }
void clear_beam_search_info(uint64_t request_id);
std::vector<int32_t> get_beam_idxs(uint64_t request_id);
};

class Sampler::GroupBeamSearcher {
@@ -105,5 +106,6 @@ class Sampler::GroupBeamSearcher {

void select_next_tokens(const ov::Tensor& logits, SamplerOutput& sampler_output);
void finalize(SamplerOutput& sampler_output);
std::vector<int32_t> get_beam_idxs();
};
}
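Taken together, the Sampler API above is driven in a schedule/forward/sample loop. A condensed sketch of that protocol under the same assumptions as get_lm_encoded_results below (one prompt, one SequenceGroup; encoded_prompt, config, tokenizer, and the logits tensors are assumed to come from the surrounding pipeline, and the model calls are elided):

// One prompt -> one SequenceGroup; block_size 1, no prefix caching,
// as in the pipeline change below.
SequenceGroup::Ptr group = std::make_shared<SequenceGroup>(
    /*request_id=*/0, encoded_prompt, config, /*block_size=*/1,
    /*enable_prefix_caching=*/false);
group->set_sequence_group_ptr(group);
std::vector<SequenceGroup::Ptr> requests{group};
Sampler sampler(tokenizer);

// prefill: run the model on the whole prompt, then account for its tokens
group->schedule_tokens(prompt_logits.get_shape().at(1));
SamplerOutput out = sampler.sample(requests, prompt_logits);

while (!group->has_finished()) {
    group->schedule_tokens(1);                     // one decode step
    // ... forward pass producing step_logits for the scheduled token ...
    out = sampler.sample(requests, step_logits);
}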
239 changes: 185 additions & 54 deletions src/cpp/src/vlm_pipeline.cpp
@@ -11,6 +11,10 @@
#include <optional>
#include <random>

#include "sampler.hpp"

#include "debug_utils.hpp"

using namespace ov::genai;

namespace {
@@ -296,6 +300,163 @@ ov::Tensor resample(VLMPipeline& pipe, const ov::Tensor& encoded_image, const st
}
}


void forward_embeddings_and_lm(SequenceGroup::CPtr sequence_group, ov::InferRequest& embedding, ov::InferRequest& language, const VLMConfig& m_vlm_config, const std::shared_ptr<Sampler> sampler) {
// compute aggregated values
size_t num_sequences = sequence_group->num_running_seqs();
size_t batch_size_in_sequences = num_sequences;
size_t total_num_tokens = sequence_group->get_num_scheduled_tokens() * num_sequences;
size_t total_num_blocks = sequence_group->get_num_blocks() * num_sequences;
size_t max_context_len_val = sequence_group->get_context_len();

ov::Tensor
input_ids(ov::element::i64, {total_num_tokens, 1}),
position_ids(ov::element::i64, {total_num_tokens, 1}),
beam_idx(ov::element::i32, { total_num_tokens });

// get raw pointers to copy to
int64_t
* input_ids_data = input_ids.data<int64_t>(),
* position_ids_data = position_ids.data<int64_t>();
int32_t
* beam_idx_data = beam_idx.data<int32_t>();

std::vector<Sequence::CPtr> running_sequences = sequence_group->get_running_sequences();
size_t num_running_sequences = running_sequences.size();
size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
size_t group_position_id = sequence_group->get_num_processed_tokens();

// spec: In case of multiple input tokens for current sequence (prompt_len > 1),
// context_len corresponds to first token within subgroup of scheduled tokens
size_t group_context_len = group_position_id;

for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id) {
Sequence::CPtr sequence = running_sequences[seq_id];

for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id) {
// compute token for current sequence
input_ids_data[token_id] = position_id < sequence_group->get_prompt_len() ?
sequence_group->get_prompt_ids()[position_id] :
sequence->get_generated_ids()[position_id - sequence_group->get_prompt_len()];

position_ids_data[token_id] = position_id;
}

// apply strides to shift to a next sequence
input_ids_data += num_scheduled_tokens;
position_ids_data += num_scheduled_tokens;
}

embedding.set_input_tensor(input_ids);

embedding.infer();
const ov::Tensor& embed_prompt_tensor = embedding.get_output_tensor();
float* embed_data = embed_prompt_tensor.data<float>();
for (size_t idx = 0; idx < embed_prompt_tensor.get_size(); idx++) {
embed_data[idx] = embed_data[idx] * m_vlm_config.scale_emb;
}

language.set_tensor("inputs_embeds", embed_prompt_tensor);

language.get_tensor("attention_mask").set_shape({ total_num_tokens, language.get_tensor("attention_mask").get_shape()[1] + 1 });
std::fill_n(language.get_tensor("attention_mask").data<int64_t>(), language.get_tensor("attention_mask").get_size(), 1);

language.set_tensor("position_ids", position_ids);
std::vector<int32_t> beam_idxs = sampler->get_beam_idxs(sequence_group->get_request_id());
if (beam_idxs.empty()) {
for (size_t i = 0; i < num_sequences; i++) {
beam_idx_data[i] = 0;
}
} else {
for (size_t i = 0; i < beam_idxs.size(); i++) {
beam_idx_data[i] = beam_idxs.at(i);
}
}
language.set_tensor("beam_idx", beam_idx);

// print_tensor("input_ids", input_ids);
// print_tensor("position_ids", position_ids);
// print_tensor("attention_mask", language.get_tensor("attention_mask"));
// print_tensor("beam_idx", beam_idx);

language.infer();
}


EncodedGenerationResult get_lm_encoded_results(
ov::InferRequest& language,
ov::InferRequest& embedding,
ov::Tensor inputs_embeds,
const VLMConfig m_vlm_config,
const std::shared_ptr<StreamerBase> streamer_ptr,
const std::shared_ptr<Sampler> sampler,
std::vector<SequenceGroup::Ptr> requests
) {
SequenceGroup::Ptr request = requests.back();
GenerationHandle generation = std::make_shared<GenerationHandleImpl>(request->get_generation_stream(), request->get_sampling_parameters());

language.set_tensor("inputs_embeds", inputs_embeds);

size_t history_len = language.get_tensor("attention_mask").get_shape().at(1);
language.get_tensor("attention_mask").set_shape({1, history_len + inputs_embeds.get_shape()[1]});
std::fill_n(language.get_tensor("attention_mask").data<int64_t>(), language.get_tensor("attention_mask").get_size(), 1);

language.get_tensor("position_ids").set_shape({1, inputs_embeds.get_shape().at(1)});
std::iota(language.get_tensor("position_ids").data<int64_t>(), language.get_tensor("position_ids").data<int64_t>() + language.get_tensor("position_ids").get_size(), history_len);

language.get_tensor("beam_idx").set_shape({ BATCH_SIZE });
language.get_tensor("beam_idx").data<int32_t>()[0] = 0;

language.infer();

int64_t sequence_len = language.get_tensor("logits").get_shape().at(1);
request->schedule_tokens(sequence_len);

SamplerOutput sampler_output = sampler->sample(requests, language.get_tensor("logits"));

language.get_tensor("inputs_embeds").set_shape({BATCH_SIZE, 1, m_vlm_config.hidden_size});
language.get_tensor("position_ids").set_shape({ BATCH_SIZE, 1 });


while (!request->has_finished()) {
request->schedule_tokens(1);

forward_embeddings_and_lm(request, embedding, language, m_vlm_config, sampler);

if (streamer_ptr) {
// stream the newest token of the first sequence
int64_t out_token = (*request)[0]->get_generated_ids().back();
if (streamer_ptr->put(out_token)) {
break;
}
}

sampler_output = sampler->sample(requests, language.get_tensor("logits"));
}

if (streamer_ptr) {
streamer_ptr->end();
}

EncodedGenerationResult result;
result.m_request_id = 1;
std::vector<GenerationOutput> generation_outputs = generation->read_all();
std::sort(generation_outputs.begin(), generation_outputs.end(), [=] (GenerationOutput& r1, GenerationOutput& r2) {
return r1.score > r2.score;
});

auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, generation_outputs.size());
for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) {
auto& generation_output = generation_outputs[generation_output_idx];  // non-const so generated_ids can be moved out
result.m_generation_ids.push_back(std::move(generation_output.generated_ids));
result.m_scores.push_back(generation_output.score);
}
result.m_status = generation->get_status();

return result;
}


class ov::genai::VLMPipeline::VLMPipelineImpl {
};

@@ -447,31 +608,16 @@ DecodedResults VLMPipeline::generate(
}
}
}
m_language.set_tensor("inputs_embeds", inputs_embeds);
size_t history_len = m_language.get_tensor("attention_mask").get_shape().at(1);
m_language.get_tensor("attention_mask").set_shape({1, history_len + inputs_embeds.get_shape()[1]});
std::fill_n(m_language.get_tensor("attention_mask").data<int64_t>(), m_language.get_tensor("attention_mask").get_size(), 1);
m_language.get_tensor("position_ids").set_shape({1, inputs_embeds.get_shape().at(1)});
std::iota(m_language.get_tensor("position_ids").data<int64_t>(), m_language.get_tensor("position_ids").data<int64_t>() + m_language.get_tensor("position_ids").get_size(), history_len);
m_language.get_tensor("beam_idx").set_shape({ BATCH_SIZE });
m_language.get_tensor("beam_idx").data<int32_t>()[0] = 0;

std::shared_ptr<Sampler> sampler = std::make_shared<Sampler>(m_tokenizer);

m_language.infer();
std::vector<SequenceGroup::Ptr> requests;
auto attention_size = m_language.get_tensor("attention_mask").get_size();
// now we have one prompt as input, so we need one request
// SequenceGroup arguments: request_id, input_ids, generation_config, block_size, enable_prefix_caching
SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(0, encoded_input, generation_config, 1, false);
sequence_group->set_sequence_group_ptr(sequence_group);
requests.push_back(sequence_group);

ov::Shape logits_shape = m_language.get_tensor("logits").get_shape();
auto attention_size = m_language.get_tensor("attention_mask").get_size();

int64_t sequence_len = m_language.get_tensor("logits").get_shape().at(1) - 1;
size_t vocab_size = m_language.get_tensor("logits").get_shape().back();
float* logits = m_language.get_tensor("logits").data<float>() + sequence_len * vocab_size;
int64_t out_token = std::max_element(logits, logits + vocab_size) - logits;

m_language.get_tensor("inputs_embeds").set_shape({BATCH_SIZE, 1, m_vlm_config.hidden_size});
m_language.get_tensor("position_ids").set_shape({ BATCH_SIZE, 1 });

m_embedding.get_input_tensor().set_shape({ 1, 1 });

int64_t eos_token_id = m_tokenizer.get_eos_token_id();
std::shared_ptr<StreamerBase> streamer_ptr = std::visit(overloaded{
[&m_tokenizer = m_tokenizer](
const std::function<bool(std::string)>& callback
@@ -485,40 +631,16 @@
return std::shared_ptr<StreamerBase>{nullptr};
},
}, streamer);
std::vector<int64_t> generated;
while (true) { //(out_token != eos_token_id)
m_embedding.get_input_tensor().data<int64_t>()[0] = out_token;
m_embedding.infer();
const ov::Tensor& embed_prompt_tensor = m_embedding.get_output_tensor();
float* embed_data = embed_prompt_tensor.data<float>();
for (auto idx = 0; idx < embed_prompt_tensor.get_size(); idx++) {
embed_data[idx] = embed_data[idx] * m_vlm_config.scale_emb;
}

m_language.set_tensor("inputs_embeds", embed_prompt_tensor);
m_language.get_tensor("attention_mask").set_shape({ BATCH_SIZE, m_language.get_tensor("attention_mask").get_shape()[1] + 1 });
std::fill_n(m_language.get_tensor("attention_mask").data<int64_t>(), m_language.get_tensor("attention_mask").get_size(), 1);
m_language.get_tensor("position_ids").data<int64_t>()[0] = int64_t(m_language.get_tensor("attention_mask").get_size() - 2);
EncodedGenerationResult encoded_result = get_lm_encoded_results(m_language, m_embedding, inputs_embeds, m_vlm_config, streamer_ptr, sampler, requests);

m_language.infer();

generated.push_back(out_token);
if (streamer_ptr && streamer_ptr->put(out_token)) {
break;
}
logits = m_language.get_tensor("logits").data<float>();

out_token = std::max_element(logits, logits + vocab_size) - logits;
if (out_token == eos_token_id) {
break;
}
}

if (streamer_ptr) {
streamer_ptr->end();
DecodedResults decoded;
for (size_t idx = 0; idx < encoded_result.m_generation_ids.size(); ++idx) {
decoded.texts.push_back(m_tokenizer.decode(encoded_result.m_generation_ids.at(idx)));
decoded.scores.push_back(encoded_result.m_scores.at(idx));
}

std::string decoded_results = m_tokenizer.decode(generated);
std::string decoded_results = decoded.texts.at(0);
if (m_is_chat_conversation) {
// Tail of chat template is missing in KV cache.
// Find the tail to concatenate it with the next input prompt.
Expand All @@ -530,7 +652,7 @@ DecodedResults VLMPipeline::generate(
}
m_language.get_tensor("attention_mask").set_shape({1, 0});
}
return {{std::move(decoded_results)}};
return decoded;
}

DecodedResults VLMPipeline::generate(
@@ -552,6 +674,15 @@
ov::genai::OptionalGenerationConfig config_arg = utils::get_config_from_map(config_map);
GenerationConfig config = (config_arg.has_value()) ? *config_arg : get_generation_config();
config.update_generation_config(config_map);

// If eos_token_id was not provided, take the value from the tokenizer
if (config.eos_token_id == -1)
config.set_eos_token_id(m_tokenizer.get_eos_token_id());

// if (is_chat_conversation && config.num_return_sequences > 1) {
// config.num_return_sequences = 1;
// }

return generate(
prompt,
rgbs,
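End to end, the new path lets callers request beam search and read back several scored completions. A minimal sketch (the model directory, prompt, and public header path are placeholders/assumptions; the texts and scores fields follow the DecodedResults populated above):

#include <iostream>
#include "openvino/genai/vlm_pipeline.hpp"

int main() {
    ov::genai::VLMPipeline pipe("./model_dir", "CPU");  // placeholder model directory
    ov::genai::DecodedResults results = pipe.generate(
        "Describe the image.",
        ov::genai::generation_config(ov::genai::beam_search()));
    // With beam search, up to num_return_sequences completions come back,
    // sorted by score in get_lm_encoded_results.
    for (size_t i = 0; i < results.texts.size(); ++i)
        std::cout << results.scores[i] << ": " << results.texts[i] << "\n";
}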
