diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index 6d9aae30fa..6fdb8ac1cd 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -45,6 +45,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
     size_t m_to_remove_from_hist = 0;
     size_t m_kv_cache_seq_length_axis = 2;
+    Sampler m_sampler;
 
     StatefulLLMPipeline(
         const ov::InferRequest& request,
@@ -75,7 +76,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         const std::string& device,
         const ov::AnyMap& config,
         const ov::genai::GenerationConfig& generation_config
-    ) : LLMPipelineImplBase(tokenizer, generation_config) {
+    ) : LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
         ov::Core core;
         ov::CompiledModel compiled_model;
         auto [core_plugin_config, plugin_config] = ov::genai::utils::split_core_compile_config(config);
@@ -96,6 +97,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         // If eos_token_id was not provided, take value
        if (m_generation_config.eos_token_id == -1)
            m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
+
+        m_sampler.set_seed(m_generation_config.rng_seed);
     }
 
     StatefulLLMPipeline(
@@ -358,9 +361,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             requests.push_back(sequence_group);
         }
 
-        Sampler sampler = Sampler(m_tokenizer);
+        if (m_sampler.get_seed() != config.rng_seed) {
+            m_sampler.set_seed(config.rng_seed);
+        }
+
         std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr,
-                                                                              sampler, requests, position_ids, std::nullopt, m_selected_beam);
+                                                                              m_sampler, requests, position_ids, std::nullopt, m_selected_beam);
     }
 
     if (is_chat_conversation) {
diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp
index 3ab041fa58..62c53cace4 100644
--- a/src/cpp/src/lm_encoding.cpp
+++ b/src/cpp/src/lm_encoding.cpp
@@ -247,6 +247,9 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
         // next_selected_beam = sampler.last_selected_beam(request);
     }
 
+    for (SequenceGroup::Ptr sequence_group : sequence_groups)
+        sampler.clear_request_info(sequence_group->get_request_id());
+
     return {results, next_selected_beam};
 }
 
diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp
index 0f7876cbf9..08a9863e0a 100644
--- a/src/cpp/src/sampler.hpp
+++ b/src/cpp/src/sampler.hpp
@@ -55,6 +55,7 @@ class Sampler {
    std::map<uint64_t, GroupBeamSearcher> m_beam_search_info;
 
    std::mt19937 rng_engine;
+    size_t seed = rng_engine.default_seed;
 
    // { request_id, logit_processor }
    std::map<uint64_t, LogitProcessor> m_logit_processors;
@@ -65,7 +66,11 @@
     Sampler(Tokenizer & tokenizer) : m_tokenizer(tokenizer) {};
 
     SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
-    void set_seed(size_t seed) { rng_engine.seed(seed); }
+    void set_seed(size_t new_seed) {
+        rng_engine.seed(new_seed);
+        seed = new_seed;
+    }
+    size_t get_seed() { return seed; }
 
     void clear_request_info(uint64_t request_id);
 
diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp
index 0d7aebc506..7bf1c1070a 100644
--- a/src/cpp/src/visual_language/pipeline.cpp
+++ b/src/cpp/src/visual_language/pipeline.cpp
@@ -67,6 +67,8 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
     float m_load_time_ms = 0;
     // Axis num in kv cache from m_language model, which contains information about history len
     size_t m_kv_cache_seq_length_axis = 2;
+    // Component for applying sampling to lm outputs
+    Sampler m_sampler;
 
     VLMPipelineImpl(
         const std::filesystem::path& models_dir,
@@ -105,6 +107,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         if (m_generation_config.eos_token_id == -1) {
             m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
         }
+
+        m_sampler = Sampler(m_tokenizer);
+        m_sampler.set_seed(m_generation_config.rng_seed);
     }
 
     VLMPipelineImpl(
@@ -140,6 +145,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         if (m_generation_config.eos_token_id == -1) {
             m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
         }
+
+        m_sampler = Sampler(m_tokenizer);
+        m_sampler.set_seed(m_generation_config.rng_seed);
     }
 
     VLMDecodedResults generate(
@@ -204,11 +212,13 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         ov::Tensor position_ids = ov::Tensor{ov::element::i64, { 1, inputs_embeds_size }};
         std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), history_size);
 
-        Sampler sampler = Sampler(m_tokenizer);
+        if (m_sampler.get_seed() != generation_config.rng_seed) {
+            m_sampler.set_seed(generation_config.rng_seed);
+        }
 
         ov::genai::EncodedResults encoded_result;
         int32_t m_selected_beam = 0;
-        std::tie(encoded_result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, sampler, requests,
+        std::tie(encoded_result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests,
                                                                                       position_ids, m_embedding, std::nullopt);
 
         auto decode_start_time = std::chrono::steady_clock::now();
diff --git a/tests/python_tests/test_chat_generate_api.py b/tests/python_tests/test_chat_generate_api.py
index 9260e671d6..d9661e538b 100644
--- a/tests/python_tests/test_chat_generate_api.py
+++ b/tests/python_tests/test_chat_generate_api.py
@@ -187,10 +187,13 @@ def test_set_chat_template():
     model_descr = get_chat_models_list()[0]
     model_id, path, tokenizer, model_opt, pipe = read_model((model_descr[0], model_descr[1] / '_test_chat'))
     pipe.get_tokenizer().set_chat_template("{% for message in messages %}{{ message['content'] }}{% endfor %}")
+    config = ov_genai.GenerationConfig()
+    config.max_new_tokens = 1
+    config.do_sample = False
     pipe.start_chat()
-    generated = pipe.generate("a", max_new_tokens=1)
+    generated = pipe.generate("a", config)
     pipe.finish_chat()
-    reference = pipe.generate("a", max_new_tokens=1)
+    reference = pipe.generate("a", config)
     assert generated == reference
 
     prompts = [
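
The pattern introduced above, caching the seed inside the Sampler and reseeding only when a different rng_seed is requested, can be illustrated outside the library. Below is a minimal, self-contained sketch of that idea; ToySampler, ToyPipeline, and ToyGenerationConfig are invented names for illustration and are not the actual ov::genai classes, whose real declarations appear in the diff.

// Standalone sketch of the seed-caching / sampler-reuse pattern from this diff.
// A single sampler is a member of the pipeline and survives across generate()
// calls; it is reseeded only when the requested seed differs from the cached one.
#include <cstddef>
#include <iostream>
#include <random>

class ToySampler {
public:
    // Mirror the diff: remember the seed next to the engine so callers can query it.
    void set_seed(std::size_t new_seed) {
        m_rng.seed(new_seed);
        m_seed = new_seed;
    }
    std::size_t get_seed() const { return m_seed; }

    // Draw one token id in [0, vocab_size) as a stand-in for multinomial sampling.
    std::size_t sample_token(std::size_t vocab_size) {
        std::uniform_int_distribution<std::size_t> dist(0, vocab_size - 1);
        return dist(m_rng);
    }

private:
    std::mt19937 m_rng;
    std::size_t m_seed = std::mt19937::default_seed;  // matches an unseeded engine
};

struct ToyGenerationConfig {
    std::size_t rng_seed = 0;
};

class ToyPipeline {
public:
    explicit ToyPipeline(std::size_t default_seed) { m_sampler.set_seed(default_seed); }

    std::size_t generate(const ToyGenerationConfig& config) {
        // Reseed only when the caller asks for a different seed; otherwise the
        // engine keeps its state, so consecutive sampled calls stay varied.
        if (m_sampler.get_seed() != config.rng_seed) {
            m_sampler.set_seed(config.rng_seed);
        }
        return m_sampler.sample_token(32000);
    }

private:
    ToySampler m_sampler;  // constructed once, reused by every generate() call
};

int main() {
    ToyPipeline pipeline(/*default_seed=*/0);
    ToyGenerationConfig config;

    // Same seed on both calls: no reseed happens, so the second call continues
    // the RNG stream instead of replaying the first token.
    std::cout << pipeline.generate(config) << ' ' << pipeline.generate(config) << '\n';

    // A different seed triggers a reseed before sampling.
    config.rng_seed = 42;
    std::cout << pipeline.generate(config) << '\n';
    return 0;
}

The get_seed()/set_seed() comparison is the same check the diff adds in StatefulLLMPipeline and VLMPipelineImpl before calling get_lm_encoded_results with the member m_sampler instead of a freshly constructed local Sampler.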