Make Sampler a member of the class for llm/vlm pipelines #1412

Merged: 1 commit, Dec 20, 2024

src/cpp/src/llm_pipeline.cpp (12 changes: 9 additions & 3 deletions)

@@ -45,6 +45,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
     size_t m_to_remove_from_hist = 0;
     size_t m_kv_cache_seq_length_axis = 2;
+    Sampler m_sampler;

     StatefulLLMPipeline(
         const ov::InferRequest& request,
@@ -75,7 +76,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         const std::string& device,
         const ov::AnyMap& config,
         const ov::genai::GenerationConfig& generation_config
-    ) : LLMPipelineImplBase(tokenizer, generation_config) {
+    ) : LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
         ov::Core core;
         ov::CompiledModel compiled_model;
         auto [core_plugin_config, plugin_config] = ov::genai::utils::split_core_compile_config(config);
@@ -96,6 +97,8 @@
         // If eos_token_id was not provided, take value
         if (m_generation_config.eos_token_id == -1)
             m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
+
+        m_sampler.set_seed(m_generation_config.rng_seed);
     }

     StatefulLLMPipeline(
@@ -358,9 +361,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                 requests.push_back(sequence_group);
             }

-            Sampler sampler = Sampler(m_tokenizer);
+            if (m_sampler.get_seed() != config.rng_seed) {
+                m_sampler.set_seed(config.rng_seed);
+            }
+
             std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr,
-                sampler, requests, position_ids, std::nullopt, m_selected_beam);
+                m_sampler, requests, position_ids, std::nullopt, m_selected_beam);
         }

         if (is_chat_conversation) {
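The net effect in StatefulLLMPipeline is that the sampler is no longer a throwaway local: it is constructed with the tokenizer once, seeded from the pipeline's generation config, and re-seeded inside generate only when the caller asks for a different rng_seed. Below is a minimal standalone sketch of that ownership pattern; the Tokenizer, GenerationConfig, and Sampler types here are simplified stand-ins, not the real ov::genai classes.

```cpp
#include <cstddef>
#include <cstdint>
#include <random>

// Simplified stand-ins for the real ov::genai types; only the parts needed
// to show the ownership pattern are modeled here.
struct Tokenizer {};
struct GenerationConfig { size_t rng_seed = 0; };

class Sampler {
    std::mt19937 m_rng;
    size_t m_seed = std::mt19937::default_seed;
public:
    Sampler() = default;
    explicit Sampler(const Tokenizer&) {}
    void set_seed(size_t seed) {
        m_rng.seed(static_cast<std::mt19937::result_type>(seed));
        m_seed = seed;
    }
    size_t get_seed() const { return m_seed; }
    std::uint32_t sample_token() { return m_rng(); }  // placeholder for real sampling
};

class StatefulPipeline {
    Tokenizer m_tokenizer;
    Sampler m_sampler;  // member: survives across generate() calls
public:
    explicit StatefulPipeline(size_t default_seed) : m_sampler(m_tokenizer) {
        m_sampler.set_seed(default_seed);  // seeded once at construction
    }
    std::uint32_t generate(const GenerationConfig& config) {
        // Re-seed only when the requested seed differs from the cached one,
        // so repeated calls with the same config continue one RNG stream.
        if (m_sampler.get_seed() != config.rng_seed)
            m_sampler.set_seed(config.rng_seed);
        return m_sampler.sample_token();
    }
};
```

With the previous per-call `Sampler sampler = Sampler(m_tokenizer);`, every generate started from a fresh engine; with the member, sampling state carries over between calls unless the seed is explicitly changed.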
src/cpp/src/lm_encoding.cpp (3 changes: 3 additions & 0 deletions)

@@ -247,6 +247,9 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
         // next_selected_beam = sampler.last_selected_beam(request);
     }

+    for (SequenceGroup::Ptr sequence_group : sequence_groups)
+        sampler.clear_request_info(sequence_group->get_request_id());
+
     return {results, next_selected_beam};
 }

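With the sampler now owned by the pipeline rather than created per call, get_lm_encoded_results clears the per-request bookkeeping it accumulated once the results are collected; otherwise entries for finished requests would stay in the sampler's maps for the pipeline's lifetime. A rough sketch of that cleanup idea follows; the map mirrors the m_logit_processors member declared in sampler.hpp, but this is a simplified model, not the real class.

```cpp
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

struct LogitProcessor {};  // stand-in for per-request sampling state

class Sampler {
    std::map<uint64_t, LogitProcessor> m_logit_processors;  // { request_id, logit_processor }
public:
    void register_request(uint64_t request_id) { m_logit_processors[request_id] = LogitProcessor{}; }
    void clear_request_info(uint64_t request_id) { m_logit_processors.erase(request_id); }
    std::size_t tracked_requests() const { return m_logit_processors.size(); }
};

int main() {
    Sampler sampler;  // imagine this living inside a pipeline object
    std::vector<uint64_t> request_ids = {0, 1, 2};
    for (uint64_t id : request_ids)
        sampler.register_request(id);

    // ... the generation loop would run here ...

    // Cleanup added by the patch: forget finished requests so a long-lived
    // sampler does not accumulate stale state across generate() calls.
    for (uint64_t id : request_ids)
        sampler.clear_request_info(id);

    return sampler.tracked_requests() == 0 ? 0 : 1;
}
```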
src/cpp/src/sampler.hpp (7 changes: 6 additions & 1 deletion)

@@ -55,6 +55,7 @@ class Sampler {
     std::map<uint64_t, GroupBeamSearcher> m_beam_search_info;

     std::mt19937 rng_engine;
+    size_t seed = rng_engine.default_seed;
     // { request_id, logit_processor }
     std::map<uint64_t, LogitProcessor> m_logit_processors;

@@ -65,7 +66,11 @@
     Sampler(Tokenizer & tokenizer) : m_tokenizer(tokenizer) {};

     SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
-    void set_seed(size_t seed) { rng_engine.seed(seed); }
+    void set_seed(size_t new_seed) {
+        rng_engine.seed(new_seed);
+        seed = new_seed;
+    }
+    size_t get_seed() { return seed; }

     void clear_request_info(uint64_t request_id);

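This header change is what makes the conditional re-seeding in the pipelines possible: std::mt19937 does not expose the seed it was initialized with, so Sampler now caches the last seed next to the engine and reports it through get_seed(). A minimal sketch of just that caching behaviour, independent of the rest of the class:

```cpp
#include <cstddef>
#include <random>

// Minimal sketch: cache the seed alongside the engine, because the engine
// itself cannot report what it was seeded with.
class SeededEngine {
    std::mt19937 rng_engine;
    std::size_t seed = std::mt19937::default_seed;  // same default the engine starts from
public:
    void set_seed(std::size_t new_seed) {
        rng_engine.seed(static_cast<std::mt19937::result_type>(new_seed));
        seed = new_seed;
    }
    std::size_t get_seed() const { return seed; }
    std::mt19937::result_type next() { return rng_engine(); }
};
```

Callers such as StatefulLLMPipeline::generate compare get_seed() against the config's rng_seed and only call set_seed on a mismatch, which avoids resetting the random stream on every request.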
src/cpp/src/visual_language/pipeline.cpp (14 changes: 12 additions & 2 deletions)

@@ -67,6 +67,8 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
     float m_load_time_ms = 0;
     // Axis num in kv cache from m_language model, which contains information about history len
     size_t m_kv_cache_seq_length_axis = 2;
+    // Component for applying sampling to lm outputs
+    Sampler m_sampler;

     VLMPipelineImpl(
         const std::filesystem::path& models_dir,
@@ -105,6 +107,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         if (m_generation_config.eos_token_id == -1) {
             m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
         }
+
+        m_sampler = Sampler(m_tokenizer);
+        m_sampler.set_seed(m_generation_config.rng_seed);
     }

     VLMPipelineImpl(
@@ -140,6 +145,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         if (m_generation_config.eos_token_id == -1) {
             m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
         }
+
+        m_sampler = Sampler(m_tokenizer);
+        m_sampler.set_seed(m_generation_config.rng_seed);
     }

     VLMDecodedResults generate(
@@ -204,11 +212,13 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         ov::Tensor position_ids = ov::Tensor{ov::element::i64, { 1, inputs_embeds_size }};
         std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), history_size);

-        Sampler sampler = Sampler(m_tokenizer);
+        if (m_sampler.get_seed() != generation_config.rng_seed) {
+            m_sampler.set_seed(generation_config.rng_seed);
+        }

         ov::genai::EncodedResults encoded_result;
         int32_t m_selected_beam = 0;
-        std::tie(encoded_result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, sampler, requests,
+        std::tie(encoded_result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests,
             position_ids, m_embedding, std::nullopt);

         auto decode_start_time = std::chrono::steady_clock::now();
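The VLM pipeline mirrors the LLM change: the sampler is created and seeded in the constructors, and generate only re-seeds it when generation_config.rng_seed differs from the cached value. The practical consequence is that repeated sampled generations from one pipeline advance a single random stream instead of replaying it from the same seed each time. A small illustration of that difference using a bare std::mt19937 (not the ov::genai sampler):

```cpp
#include <iostream>
#include <random>

int main() {
    const std::mt19937::result_type seed = 42;

    // Recreating the engine for every call with the same seed replays the
    // same sequence, so two "calls" produce identical first draws.
    std::mt19937 per_call_a(seed);
    std::mt19937 per_call_b(seed);
    std::cout << "recreated per call: " << per_call_a() << " " << per_call_b() << "\n";

    // One long-lived engine, seeded once, keeps advancing, so consecutive
    // calls produce different draws from the same stream.
    std::mt19937 long_lived(seed);
    std::cout << "long-lived member:  " << long_lived() << " " << long_lived() << "\n";
    return 0;
}
```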
tests/python_tests/test_chat_generate_api.py (7 changes: 5 additions & 2 deletions)

@@ -187,10 +187,13 @@ def test_set_chat_template():
     model_descr = get_chat_models_list()[0]
     model_id, path, tokenizer, model_opt, pipe = read_model((model_descr[0], model_descr[1] / '_test_chat'))
     pipe.get_tokenizer().set_chat_template("{% for message in messages %}{{ message['content'] }}{% endfor %}")
+    config = ov_genai.GenerationConfig()
+    config.max_new_tokens = 1
+    config.do_sample = False
     pipe.start_chat()
-    generated = pipe.generate("a", max_new_tokens=1)
+    generated = pipe.generate("a", config)
     pipe.finish_chat()
-    reference = pipe.generate("a", max_new_tokens=1)
+    reference = pipe.generate("a", config)
     assert generated == reference

 prompts = [