Make Sampler a member of the class for llm/vlm pipelines (#1412)

cherry-pick #1347 to master

sbalandi authored Dec 20, 2024
1 parent 19c66f5 commit 4d18f8b
Showing 5 changed files with 35 additions and 8 deletions.
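
In short, the Sampler that was previously created as a local inside each generate() call becomes a member of the stateful LLM and VLM pipeline implementations: it is constructed with the tokenizer, seeded from the generation config's rng_seed in the pipeline constructor, and reseeded during generate() only when the request supplies a different seed. Below is a minimal, self-contained C++ sketch of that pattern; PipelineImpl and the trimmed-down Sampler are illustrative stand-ins, not the actual OpenVINO GenAI classes.

#include <cstdint>
#include <random>

// Illustrative, trimmed-down Sampler: caches the last seed so callers can
// avoid reseeding (and rewinding) the RNG stream when the seed is unchanged.
class Sampler {
    std::mt19937 rng_engine;
    size_t seed = std::mt19937::default_seed;
public:
    void set_seed(size_t new_seed) {
        rng_engine.seed(new_seed);
        seed = new_seed;
    }
    size_t get_seed() const { return seed; }
};

// Hypothetical pipeline skeleton showing the refactor: the sampler is owned
// by the pipeline instead of being a local created on every call.
class PipelineImpl {
    Sampler m_sampler;
public:
    explicit PipelineImpl(size_t rng_seed) { m_sampler.set_seed(rng_seed); }

    void generate(size_t rng_seed) {
        // Reseed only when the requested seed differs from the cached one,
        // so repeated calls with the same seed keep advancing the RNG stream.
        if (m_sampler.get_seed() != rng_seed) {
            m_sampler.set_seed(rng_seed);
        }
        // ... run decoding with m_sampler ...
    }
};
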
12 changes: 9 additions & 3 deletions src/cpp/src/llm_pipeline.cpp
@@ -45,6 +45,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
     size_t m_to_remove_from_hist = 0;
     size_t m_kv_cache_seq_length_axis = 2;
+    Sampler m_sampler;
 
     StatefulLLMPipeline(
         const ov::InferRequest& request,
@@ -75,7 +76,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         const std::string& device,
         const ov::AnyMap& config,
         const ov::genai::GenerationConfig& generation_config
-    ) : LLMPipelineImplBase(tokenizer, generation_config) {
+    ) : LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
         ov::Core core;
         ov::CompiledModel compiled_model;
         auto [core_plugin_config, plugin_config] = ov::genai::utils::split_core_compile_config(config);
@@ -96,6 +97,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         // If eos_token_id was not provided, take value
         if (m_generation_config.eos_token_id == -1)
             m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
+
+        m_sampler.set_seed(m_generation_config.rng_seed);
     }
 
     StatefulLLMPipeline(
@@ -358,9 +361,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             requests.push_back(sequence_group);
         }
 
-        Sampler sampler = Sampler(m_tokenizer);
+        if (m_sampler.get_seed() != config.rng_seed) {
+            m_sampler.set_seed(config.rng_seed);
+        }
+
         std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr,
-            sampler, requests, position_ids, std::nullopt, m_selected_beam);
+            m_sampler, requests, position_ids, std::nullopt, m_selected_beam);
     }
 
     if (is_chat_conversation) {
3 changes: 3 additions & 0 deletions src/cpp/src/lm_encoding.cpp
@@ -247,6 +247,9 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
         // next_selected_beam = sampler.last_selected_beam(request);
     }
 
+    for (SequenceGroup::Ptr sequence_group : sequence_groups)
+        sampler.clear_request_info(sequence_group->get_request_id());
+
     return {results, next_selected_beam};
 }
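
Because the sampler now outlives a single generate() call, the per-request state it keeps (beam-search bookkeeping, logit processors) has to be dropped explicitly once a request finishes, which is what the added clear_request_info loop does. A hedged sketch of that bookkeeping pattern, with hypothetical state tables standing in for the real ones:

#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical per-request state tables, mirroring the kind of bookkeeping a
// long-lived sampler keeps keyed by request id.
struct RequestScopedSampler {
    std::map<uint64_t, std::string> m_beam_search_info;
    std::map<uint64_t, std::string> m_logit_processors;

    void clear_request_info(uint64_t request_id) {
        m_beam_search_info.erase(request_id);
        m_logit_processors.erase(request_id);
    }
};

// Sketch of the cleanup step: once the sampler persists across calls, stale
// entries for finished requests must be erased so they cannot leak into a
// later call that reuses the same request ids.
void finish_requests(RequestScopedSampler& sampler, const std::vector<uint64_t>& request_ids) {
    for (uint64_t id : request_ids)
        sampler.clear_request_info(id);
}
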
7 changes: 6 additions & 1 deletion src/cpp/src/sampler.hpp
@@ -55,6 +55,7 @@ class Sampler {
     std::map<uint64_t, GroupBeamSearcher> m_beam_search_info;
 
     std::mt19937 rng_engine;
+    size_t seed = rng_engine.default_seed;
     // { request_id, logit_processor }
     std::map<uint64_t, LogitProcessor> m_logit_processors;
 
@@ -65,7 +66,11 @@
     Sampler(Tokenizer & tokenizer) : m_tokenizer(tokenizer) {};
 
     SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
-    void set_seed(size_t seed) { rng_engine.seed(seed); }
+    void set_seed(size_t new_seed) {
+        rng_engine.seed(new_seed);
+        seed = new_seed;
+    }
+    size_t get_seed() { return seed; }
 
     void clear_request_info(uint64_t request_id);
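
Caching the seed matters because reseeding a std::mt19937 rewinds it to the start of its stream: if generate() reseeded unconditionally with an unchanged seed, consecutive sampling calls would replay the same random numbers. The small standalone example below (not part of the patch) demonstrates that behaviour:

#include <iostream>
#include <random>

int main() {
    std::mt19937 rng(42);

    auto first  = rng();  // first draw from the stream seeded with 42
    auto second = rng();  // the stream has advanced, so this is a later draw

    rng.seed(42);         // reseeding with the same value rewinds the stream
    auto replay = rng();  // identical to `first` again

    std::cout << first << ' ' << second << ' ' << replay << '\n';
}
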
14 changes: 12 additions & 2 deletions src/cpp/src/visual_language/pipeline.cpp
@@ -67,6 +67,8 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
     float m_load_time_ms = 0;
     // Axis num in kv cache from m_language model, which contains information about history len
     size_t m_kv_cache_seq_length_axis = 2;
+    // Component for applying sampling to lm outputs
+    Sampler m_sampler;
 
     VLMPipelineImpl(
         const std::filesystem::path& models_dir,
@@ -105,6 +107,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         if (m_generation_config.eos_token_id == -1) {
             m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
         }
+
+        m_sampler = Sampler(m_tokenizer);
+        m_sampler.set_seed(m_generation_config.rng_seed);
     }
 
     VLMPipelineImpl(
@@ -140,6 +145,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         if (m_generation_config.eos_token_id == -1) {
             m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
         }
+
+        m_sampler = Sampler(m_tokenizer);
+        m_sampler.set_seed(m_generation_config.rng_seed);
     }
 
     VLMDecodedResults generate(
@@ -204,11 +212,13 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         ov::Tensor position_ids = ov::Tensor{ov::element::i64, { 1, inputs_embeds_size }};
         std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), history_size);
 
-        Sampler sampler = Sampler(m_tokenizer);
+        if (m_sampler.get_seed() != generation_config.rng_seed) {
+            m_sampler.set_seed(generation_config.rng_seed);
+        }
 
         ov::genai::EncodedResults encoded_result;
         int32_t m_selected_beam = 0;
-        std::tie(encoded_result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, sampler, requests,
+        std::tie(encoded_result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests,
             position_ids, m_embedding, std::nullopt);
 
         auto decode_start_time = std::chrono::steady_clock::now();
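
For reference, a hedged usage sketch of how a caller pins the seed that now drives the pipeline-owned sampler, following the public LLMPipeline API as used in the OpenVINO GenAI samples; the model directory is a placeholder.

#include <iostream>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Hypothetical model directory; point this at an exported OpenVINO model.
    ov::genai::LLMPipeline pipe("./TinyLlama-1.1B-Chat-v1.0", "CPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 32;
    config.do_sample = true;
    config.rng_seed = 42;  // applied to the member sampler when it differs from the cached seed

    std::cout << pipe.generate("What is OpenVINO?", config) << '\n';
    return 0;
}
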
7 changes: 5 additions & 2 deletions tests/python_tests/test_chat_generate_api.py
@@ -187,10 +187,13 @@ def test_set_chat_template():
     model_descr = get_chat_models_list()[0]
     model_id, path, tokenizer, model_opt, pipe = read_model((model_descr[0], model_descr[1] / '_test_chat'))
     pipe.get_tokenizer().set_chat_template("{% for message in messages %}{{ message['content'] }}{% endfor %}")
+    config = ov_genai.GenerationConfig()
+    config.max_new_tokens = 1
+    config.do_sample = False
     pipe.start_chat()
-    generated = pipe.generate("a", max_new_tokens=1)
+    generated = pipe.generate("a", config)
     pipe.finish_chat()
-    reference = pipe.generate("a", max_new_tokens=1)
+    reference = pipe.generate("a", config)
     assert generated == reference
 
     prompts = [
