diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 1842d205cd..5958dafe33 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -259,7 +259,7 @@ jobs: run: | source ${OV_INSTALL_DIR}/setupvars.sh python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${OV_INSTALL_DIR}/wheels --upgrade-strategy eager - python -m pytest ./tests/python_tests/test_chat_generate_api.py::test_set_chat_template + python -m pytest -v ./tests/python_tests/test_chat_generate_api.py::test_set_chat_template env: PYTHONPATH: "./build/:$PYTHONPATH" @@ -267,11 +267,11 @@ jobs: run: | source ${OV_INSTALL_DIR}/setupvars.sh python -m pip install . --verbose --find-links ${OV_INSTALL_DIR}/wheels - python -m pytest ./tests/python_tests --ignore ./tests/python_tests/test_whisper_generate_api.py --ignore ./tests/python_tests/test_vlm_api.py -k "not test_set_chat_template" + python -m pytest -v ./tests/python_tests --ignore ./tests/python_tests/test_whisper_generate_api.py --ignore ./tests/python_tests/test_vlm_api.py -k "not test_set_chat_template" - run: > source ${OV_INSTALL_DIR}/setupvars.sh - && python -m pytest ./tests/python_tests/test_vlm_api.py + && python -m pytest -v ./tests/python_tests/test_vlm_api.py genai_python_lib_whisper: name: OpenVINO genai extension whisper tests (cmake + wheel) @@ -350,7 +350,7 @@ jobs: run: | source ${OV_INSTALL_DIR}/setupvars.sh python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${OV_INSTALL_DIR}/wheels --upgrade-strategy eager - python -m pytest ./tests/python_tests/test_whisper_generate_api.py -k test_smoke + python -m pytest -v ./tests/python_tests/test_whisper_generate_api.py -k test_smoke env: PYTHONPATH: "./build/:$PYTHONPATH" @@ -358,7 +358,7 @@ jobs: run: | source ${OV_INSTALL_DIR}/setupvars.sh python -m pip install . --verbose --find-links ${OV_INSTALL_DIR}/wheels - python -m pytest ./tests/python_tests/test_whisper_generate_api.py -k "not test_smoke" + python -m pytest -v ./tests/python_tests/test_whisper_generate_api.py -k "not test_smoke" genai_package: name: OpenVINO genai extension (install to OpenVINO package) diff --git a/.github/workflows/llm_bench-python.yml b/.github/workflows/llm_bench-python.yml index f8940d67bb..0be7be8153 100644 --- a/.github/workflows/llm_bench-python.yml +++ b/.github/workflows/llm_bench-python.yml @@ -82,6 +82,12 @@ jobs: run: | wget -O ./ov_models/soulcard.safetensors https://civitai.com/api/download/models/72591 python ./tools/llm_bench/benchmark.py -m ./ov_models/dreamlike-art-dreamlike-anime-1.0/FP16/ -pf ./tools/llm_bench/prompts/stable-diffusion.jsonl -d cpu -n 1 --genai --lora ./ov_models/soulcard.safetensors --lora_alphas 0.7 + - name: Test TinyLlama-1.1B-Chat-v1.0 in Speculative Deconding mode on Linux + run: | + optimum-cli export openvino --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --trust-remote-code --weight-format fp16 ov_models/TinyLlama-1.1B-Chat-v1.0/FP16 + optimum-cli export openvino --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --trust-remote-code --weight-format int8 ov_models/TinyLlama-1.1B-Chat-v1.0/INT8 + python ./tools/llm_bench/benchmark.py -m ./ov_models/TinyLlama-1.1B-Chat-v1.0/FP16/ --draft_model ./ov_models/TinyLlama-1.1B-Chat-v1.0/INT8/ -p "Why is the Sun yellow?" 
-d cpu --draft_device cpu -n 1 --genai --assistant_confidence_threshold 0.4 + python ./tools/llm_bench/benchmark.py -m ./ov_models/TinyLlama-1.1B-Chat-v1.0/FP16/ --draft_model ./ov_models/TinyLlama-1.1B-Chat-v1.0/INT8/ -p "Why is the Sun yellow?" -d cpu --draft_device cpu -n 1 --genai --num_assistant_tokens 5 - name: Test whisper-tiny on Linux run: | GIT_LFS_SKIP_SMUDGE=1 git clone --depth 1 --branch main --single-branch https://huggingface.co/datasets/facebook/multilingual_librispeech @@ -99,7 +105,7 @@ jobs: pip install git+https://github.com/huggingface/optimum.git GIT_CLONE_PROTECTION_ACTIVE=false pip install ${{ env.WWB_PATH }} python -m pip install -U --pre openvino openvino-tokenizers openvino-genai --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly --force-reinstall - python -m pytest tools/who_what_benchmark/tests + python -m pytest -v tools/who_what_benchmark/tests stateful: runs-on: ubuntu-20.04 steps: @@ -121,4 +127,4 @@ jobs: GIT_CLONE_PROTECTION_ACTIVE=false pip install tools/who_what_benchmark/ pip install pytest python -m pip install -U --pre openvino openvino-tokenizers openvino-genai --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly --force-reinstall - python -m pytest tools/who_what_benchmark/tests + python -m pytest -v tools/who_what_benchmark/tests diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 980c689e19..25da21b209 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -225,7 +225,7 @@ jobs: run: | source ${OV_INSTALL_DIR}/setupvars.sh python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${OV_INSTALL_DIR}/wheels --upgrade-strategy eager - python -m pytest ./tests/python_tests/test_chat_generate_api.py::test_set_chat_template + python -m pytest -v ./tests/python_tests/test_chat_generate_api.py::test_set_chat_template env: PYTHONPATH: "./build/:$PYTHONPATH" @@ -234,7 +234,7 @@ jobs: source ${OV_INSTALL_DIR}/setupvars.sh python -m pip install . --verbose --find-links ${OV_INSTALL_DIR}/wheels python -c "from openvino_genai import LLMPipeline" - python -m pytest ./tests/python_tests/ --ignore ./tests/python_tests/test_whisper_generate_api.py --ignore ./tests/python_tests/test_vlm_api.py -k "not test_set_chat_template" + python -m pytest -v ./tests/python_tests/ --ignore ./tests/python_tests/test_whisper_generate_api.py --ignore ./tests/python_tests/test_vlm_api.py -k "not test_set_chat_template" genai_python_lib_whisper: name: OpenVINO genai extension whisper tests (cmake + wheel) @@ -289,7 +289,7 @@ jobs: run: | source ${OV_INSTALL_DIR}/setupvars.sh python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${OV_INSTALL_DIR}/wheels --upgrade-strategy eager - python -m pytest ./tests/python_tests/test_whisper_generate_api.py -k test_smoke + python -m pytest -v ./tests/python_tests/test_whisper_generate_api.py -k test_smoke env: PYTHONPATH: "./build/:$PYTHONPATH" @@ -298,7 +298,7 @@ jobs: source ${OV_INSTALL_DIR}/setupvars.sh python -m pip install . 
--verbose --find-links ${OV_INSTALL_DIR}/wheels python -c "from openvino_genai import LLMPipeline" - python -m pytest ./tests/python_tests/test_whisper_generate_api.py -k "not test_smoke" + python -m pytest -v ./tests/python_tests/test_whisper_generate_api.py -k "not test_smoke" genai_package: name: OpenVINO genai extension (install to OpenVINO package) diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index 5a86965e3f..33096d6d7b 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -236,7 +236,7 @@ jobs: run: | . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1" python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${env:OV_INSTALL_DIR}/wheels --upgrade-strategy eager - python -m pytest ./tests/python_tests/test_chat_generate_api.py::test_set_chat_template + python -m pytest -v ./tests/python_tests/test_chat_generate_api.py::test_set_chat_template env: PYTHONPATH: "./build/" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that. @@ -244,7 +244,7 @@ jobs: run: | . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1" python -m pip install . --verbose - python -m pytest ./tests/python_tests/ --ignore ./tests/python_tests/test_whisper_generate_api.py --ignore ./tests/python_tests/test_vlm_api.py -k "not test_set_chat_template" + python -m pytest -v ./tests/python_tests/ --ignore ./tests/python_tests/test_whisper_generate_api.py --ignore ./tests/python_tests/test_vlm_api.py -k "not test_set_chat_template" genai_python_lib_whisper: name: OpenVINO genai extension whisper tests (cmake + wheel) @@ -300,7 +300,7 @@ jobs: run: | . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1" python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${env:OV_INSTALL_DIR}/wheels --upgrade-strategy eager - python -m pytest ./tests/python_tests/test_whisper_generate_api.py -k test_smoke + python -m pytest -v ./tests/python_tests/test_whisper_generate_api.py -k test_smoke env: PYTHONPATH: "./build/" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that. @@ -308,7 +308,7 @@ jobs: run: | . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1" python -m pip install . --verbose - python -m pytest ./tests/python_tests/test_whisper_generate_api.py -k "not test_smoke" + python -m pytest -v ./tests/python_tests/test_whisper_generate_api.py -k "not test_smoke" genai_python_lib_vlm: name: OpenVINO genai VLM tests (cmake + wheel) @@ -364,7 +364,7 @@ jobs: run: | . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1" python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${env:OV_INSTALL_DIR}/wheels --upgrade-strategy eager - python -m pytest ./tests/python_tests/test_vlm_api.py + python -m pytest -v ./tests/python_tests/test_vlm_api.py env: PYTHONPATH: "./build/" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that. 
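Note on the new llm_bench workflow step above: it exercises speculative decoding in both supported modes, dynamic drafting via `--assistant_confidence_threshold` and a fixed draft length via `--num_assistant_tokens` (the two options are mutually exclusive, as the updated assert in `generation_config.cpp` further down spells out). A minimal, hedged sketch of the equivalent openvino_genai setup the benchmark drives; paths and values are illustrative, not part of the patch:

```python
import openvino_genai as ov_genai

# Draft and main model both on CPU, mirroring the CI step; paths are illustrative.
scheduler_config = ov_genai.SchedulerConfig()
scheduler_config.cache_size = 1  # llm_bench's default when no CB config is supplied

pipe = ov_genai.LLMPipeline(
    "ov_models/TinyLlama-1.1B-Chat-v1.0/FP16",
    "CPU",
    scheduler_config=scheduler_config,
    draft_model=ov_genai.draft_model("ov_models/TinyLlama-1.1B-Chat-v1.0/INT8", "CPU"),
)

config = pipe.get_generation_config()
config.max_new_tokens = 128
# Set exactly one of the two speculative-decoding knobs:
config.num_assistant_tokens = 5                 # constant number of draft tokens per step
# config.assistant_confidence_threshold = 0.4   # or: drop draft tokens below this confidence

print(pipe.generate("Why is the Sun yellow?", config))
```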
diff --git a/.gitignore b/.gitignore index 729767a485..c3e87cbb7d 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ temp/ .repo/ CMakeLists.txt.user CMakeUserPresets.json +.env *.project *.cproject diff --git a/README.md b/README.md index 3073c4f9f4..fe18205028 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,11 @@ For more examples check out our [LLM Inference Guide](https://docs.openvino.ai/2 ### Converting and compressing the model from Hugging Face library ```sh -optimum-cli export openvino --model openbmb/MiniCPM-V-2_6 --trust-remote-code MiniCPM-V-2_6 +#(Basic) download and convert to OpenVINO MiniCPM-V-2_6 model +optimum-cli export openvino --model openbmb/MiniCPM-V-2_6 --trust-remote-code --weight-format fp16 MiniCPM-V-2_6 + +#(Recommended) Same as above but with compression: language model is compressed to int4, other model components are compressed to int8 +optimum-cli export openvino --model openbmb/MiniCPM-V-2_6 --trust-remote-code --weight-format int4 MiniCPM-V-2_6 ``` ### Run generation using VLMPipeline API in Python @@ -159,6 +163,9 @@ For more examples check out our [LLM Inference Guide](https://docs.openvino.ai/2 ```sh #Download and convert to OpenVINO dreamlike-anime-1.0 model optimum-cli export openvino --model dreamlike-art/dreamlike-anime-1.0 --task stable-diffusion --weight-format fp16 dreamlike_anime_1_0_ov/FP16 + +#You can also use INT8 hybrid quantization to further optimize the model and reduce inference latency +optimum-cli export openvino --model dreamlike-art/dreamlike-anime-1.0 --task stable-diffusion --weight-format int8 --dataset conceptual_captions dreamlike_anime_1_0_ov/INT8 ``` ### Run generation using Text2Image API in Python diff --git a/samples/python/text2image/README.md b/samples/python/text2image/README.md index d8dc23d0fa..675d39d9a5 100644 --- a/samples/python/text2image/README.md +++ b/samples/python/text2image/README.md @@ -63,6 +63,6 @@ With adapter | Without adapter ![](./lora.bmp) | ![](./baseline.bmp) -# Fuse LoRA adapters into model weights +## Fuse LoRA adapters into model weights -To maximize inference performance using a LoRA adapter, refer to `lora_fuse.py`, which demonstrates fusing the adapter into the model weights. This approach achieves the same performance as the base model without a LoRA adapter but removes the flexibility to switch adapters between generate calls. This mode is ideal when performing multiple generations with the same LoRA adapters and blending alpha parameters, and when model recompilation on adapter changes is feasible. The example outputs the resulting image as `lora.bmp`. \ No newline at end of file +To maximize inference performance using a LoRA adapter, refer to `lora_fuse.py`, which demonstrates fusing the adapter into the model weights. This approach achieves the same performance as the base model without a LoRA adapter but removes the flexibility to switch adapters between generate calls. This mode is ideal when performing multiple generations with the same LoRA adapters and blending alpha parameters, and when model recompilation on adapter changes is feasible. The example outputs the resulting image as `lora.bmp`. 
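The exported folders produced by the README commands above are consumed unchanged by the GenAI pipelines; for example, the INT8 hybrid-quantized `dreamlike_anime_1_0_ov/INT8` directory can be passed to the Text2Image API referenced in the following README subsection. A rough sketch, assuming the standard Python binding; the prompt and output file name are arbitrary:

```python
import openvino_genai as ov_genai
from PIL import Image

# Point at either export from the README; the INT8 hybrid-quantized folder is used here.
pipe = ov_genai.Text2ImagePipeline("dreamlike_anime_1_0_ov/INT8", "CPU")

image_tensor = pipe.generate(
    "anime girl in a cyberpunk city at dusk",  # illustrative prompt
    width=512,
    height=512,
    num_inference_steps=20,
)
Image.fromarray(image_tensor.data[0]).save("image.bmp")
```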
diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index 8b8903d0f3..916167b63b 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -17,10 +17,11 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl( const std::string& device, const ov::AnyMap& properties) { m_tokenizer = tokenizer; + m_generation_config = utils::from_config_json_if_exists(models_path); ov::Core core; - auto [core_properties, compile_properties] = ov::genai::utils::split_core_complile_config(properties); + auto [core_properties, compile_properties] = utils::split_core_complile_config(properties); core.set_property(core_properties); // The model can be compiled for GPU as well @@ -74,6 +75,10 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::init( m_model_runner = std::make_shared(infer_request, m_scheduler->get_block_size(), device_config.get_num_layers(), is_use_cache_eviction); m_sampler = std::make_shared(m_tokenizer); m_sampler->set_seed(m_generation_config.rng_seed); + + // If eos_token_id was not provided, take value + if (m_generation_config.eos_token_id == -1) + m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id()); }; @@ -81,8 +86,11 @@ GenerationHandle ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request_id, const ov::Tensor& input_ids, ov::genai::GenerationConfig sampling_params) { - sampling_params.set_eos_token_id(m_tokenizer.get_eos_token_id()); + // If eos_token_id was not provided, take value from default m_generation_config + if (sampling_params.eos_token_id == -1) + sampling_params.set_eos_token_id(m_generation_config.eos_token_id); sampling_params.validate(); + SequenceGroup::Ptr sequence_group = std::make_shared(request_id, input_ids, sampling_params, m_scheduler->get_block_size(), @@ -262,6 +270,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector generations; for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); @@ -283,7 +294,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector 0, "Parameters `assistant_confidence_threshold` and `num_assistant_tokens` are mutually excluded in `GenerationConfig`"); + OPENVINO_ASSERT(num_assistant_tokens > 0, "Parameters `assistant_confidence_threshold` and `num_assistant_tokens` are mutually exclusive in `GenerationConfig`"); }; } } @@ -208,6 +208,5 @@ GenerationConfig multinomial() { return multinomial_config; } - } // namespace genai } // namespace ov diff --git a/src/cpp/src/image_generation/diffusion_pipeline.hpp b/src/cpp/src/image_generation/diffusion_pipeline.hpp index 2213f3711e..af7459a2da 100644 --- a/src/cpp/src/image_generation/diffusion_pipeline.hpp +++ b/src/cpp/src/image_generation/diffusion_pipeline.hpp @@ -36,6 +36,27 @@ const std::string get_class_name(const std::filesystem::path& root_dir) { return data["_class_name"].get(); } +ov::Tensor get_guidance_scale_embedding(float guidance_scale, uint32_t embedding_dim) { + float w = guidance_scale * 1000; + uint32_t half_dim = embedding_dim / 2; + float emb = std::log(10000) / (half_dim - 1); + + ov::Shape embedding_shape = {1, embedding_dim}; + ov::Tensor w_embedding(ov::element::f32, embedding_shape); + float* w_embedding_data = w_embedding.data(); + + for (size_t i = 0; i < half_dim; ++i) { + float temp = std::exp((i * (-emb))) * w; + 
w_embedding_data[i] = std::sin(temp); + w_embedding_data[i + half_dim] = std::cos(temp); + } + + if (embedding_dim % 2 == 1) + w_embedding_data[embedding_dim - 1] = 0; + + return w_embedding; +} + } // namespace diff --git a/src/cpp/src/image_generation/schedulers/lcm.cpp b/src/cpp/src/image_generation/schedulers/lcm.cpp index 7d22520639..d3afcd6300 100644 --- a/src/cpp/src/image_generation/schedulers/lcm.cpp +++ b/src/cpp/src/image_generation/schedulers/lcm.cpp @@ -99,7 +99,7 @@ void LCMScheduler::set_timesteps(size_t num_inference_steps, float strength) { assert(skipping_step >= 1 && "The combination of `original_steps x strength` is smaller than `num_inference_steps`"); // LCM Inference Steps Schedule - std::reverse(lcm_origin_timesteps.begin(),lcm_origin_timesteps.end()); + std::reverse(lcm_origin_timesteps.begin(), lcm_origin_timesteps.end()); using numpy_utils::linspace; // v1. based on https://github.com/huggingface/diffusers/blame/2a7f43a73bda387385a47a15d7b6fe9be9c65eb2/src/diffusers/schedulers/scheduling_lcm.py#L387 diff --git a/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp b/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp index 92367d82a2..b8517a1476 100644 --- a/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp +++ b/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp @@ -18,31 +18,6 @@ namespace ov { namespace genai { -namespace { - -ov::Tensor get_guidance_scale_embedding(float guidance_scale, uint32_t embedding_dim) { - float w = guidance_scale * 1000; - uint32_t half_dim = embedding_dim / 2; - float emb = std::log(10000) / (half_dim - 1); - - ov::Shape embedding_shape = {1, embedding_dim}; - ov::Tensor w_embedding(ov::element::f32, embedding_shape); - float* w_embedding_data = w_embedding.data(); - - for (size_t i = 0; i < half_dim; ++i) { - float temp = std::exp((i * (-emb))) * w; - w_embedding_data[i] = std::sin(temp); - w_embedding_data[i + half_dim] = std::cos(temp); - } - - if (embedding_dim % 2 == 1) - w_embedding_data[embedding_dim - 1] = 0; - - return w_embedding; -} - -} // namespace - class StableDiffusionPipeline : public DiffusionPipeline { public: StableDiffusionPipeline(PipelineType pipeline_type, const std::filesystem::path& root_dir) : @@ -148,7 +123,7 @@ class StableDiffusionPipeline : public DiffusionPipeline { void reshape(const int num_images_per_prompt, const int height, const int width, const float guidance_scale) override { check_image_size(height, width); - const size_t batch_size_multiplier = do_classifier_free_guidance(guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG + const size_t batch_size_multiplier = m_unet->do_classifier_free_guidance(guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG m_clip_text_encoder->reshape(batch_size_multiplier); m_unet->reshape(num_images_per_prompt * batch_size_multiplier, height, width, m_clip_text_encoder->get_config().max_position_embeddings); m_vae->reshape(num_images_per_prompt, height, width); @@ -203,7 +178,7 @@ class StableDiffusionPipeline : public DiffusionPipeline { // see https://huggingface.co/docs/diffusers/using-diffusers/write_own_pipeline#deconstruct-the-stable-diffusion-pipeline const auto& unet_config = m_unet->get_config(); - const size_t batch_size_multiplier = do_classifier_free_guidance(generation_config.guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG + const size_t batch_size_multiplier = m_unet->do_classifier_free_guidance(generation_config.guidance_scale) ? 
2 : 1; // Unet accepts 2x batch in case of CFG const size_t vae_scale_factor = m_vae->get_vae_scale_factor(); if (generation_config.height < 0) @@ -245,8 +220,8 @@ class StableDiffusionPipeline : public DiffusionPipeline { } if (unet_config.time_cond_proj_dim >= 0) { // LCM - ov::Tensor guidance_scale_embedding = get_guidance_scale_embedding(generation_config.guidance_scale, unet_config.time_cond_proj_dim); - m_unet->set_hidden_states("timestep_cond", guidance_scale_embedding); + ov::Tensor timestep_cond = get_guidance_scale_embedding(generation_config.guidance_scale - 1.0f, unet_config.time_cond_proj_dim); + m_unet->set_hidden_states("timestep_cond", timestep_cond); } m_scheduler->set_timesteps(generation_config.num_inference_steps, generation_config.strength); @@ -304,10 +279,6 @@ class StableDiffusionPipeline : public DiffusionPipeline { } private: - bool do_classifier_free_guidance(float guidance_scale) const { - return m_unet->do_classifier_free_guidance(guidance_scale); - } - void initialize_generation_config(const std::string& class_name) override { assert(m_unet != nullptr); assert(m_vae != nullptr); @@ -341,7 +312,7 @@ class StableDiffusionPipeline : public DiffusionPipeline { void check_inputs(const ImageGenerationConfig& generation_config, ov::Tensor initial_image) const override { check_image_size(generation_config.width, generation_config.height); - const bool is_classifier_free_guidance = do_classifier_free_guidance(generation_config.guidance_scale); + const bool is_classifier_free_guidance = m_unet->do_classifier_free_guidance(generation_config.guidance_scale); const bool is_lcm = m_unet->get_config().time_cond_proj_dim > 0; const char * const pipeline_name = is_lcm ? "Latent Consistency Model" : "Stable Diffusion"; diff --git a/src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp b/src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp index 42ee49a19d..b709c58f47 100644 --- a/src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp +++ b/src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp @@ -152,7 +152,7 @@ class StableDiffusionXLPipeline : public DiffusionPipeline { void reshape(const int num_images_per_prompt, const int height, const int width, const float guidance_scale) override { check_image_size(height, width); - const size_t batch_size_multiplier = do_classifier_free_guidance(guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG + const size_t batch_size_multiplier = m_unet->do_classifier_free_guidance(guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG m_clip_text_encoder->reshape(batch_size_multiplier); m_clip_text_encoder_with_projection->reshape(batch_size_multiplier); m_unet->reshape(num_images_per_prompt * batch_size_multiplier, height, width, m_clip_text_encoder->get_config().max_position_embeddings); @@ -201,7 +201,7 @@ class StableDiffusionXLPipeline : public DiffusionPipeline { // see https://huggingface.co/docs/diffusers/using-diffusers/write_own_pipeline#deconstruct-the-stable-diffusion-pipeline const auto& unet_config = m_unet->get_config(); - const size_t batch_size_multiplier = do_classifier_free_guidance(generation_config.guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG + const size_t batch_size_multiplier = m_unet->do_classifier_free_guidance(generation_config.guidance_scale) ? 
2 : 1; // Unet accepts 2x batch in case of CFG const size_t vae_scale_factor = m_vae->get_vae_scale_factor(); if (generation_config.height < 0) @@ -376,6 +376,11 @@ class StableDiffusionXLPipeline : public DiffusionPipeline { m_unet->set_hidden_states("time_ids", add_time_ids_repeated); } + if (unet_config.time_cond_proj_dim >= 0) { // LCM + ov::Tensor timestep_cond = get_guidance_scale_embedding(generation_config.guidance_scale - 1.0f, unet_config.time_cond_proj_dim); + m_unet->set_hidden_states("timestep_cond", timestep_cond); + } + m_scheduler->set_timesteps(generation_config.num_inference_steps, generation_config.strength); std::vector timesteps = m_scheduler->get_timesteps(); @@ -430,10 +435,6 @@ class StableDiffusionXLPipeline : public DiffusionPipeline { } private: - bool do_classifier_free_guidance(float guidance_scale) const { - return guidance_scale > 1.0f && m_unet->get_config().time_cond_proj_dim < 0; - } - void initialize_generation_config(const std::string& class_name) override { assert(m_unet != nullptr); assert(m_vae != nullptr); @@ -463,7 +464,7 @@ class StableDiffusionXLPipeline : public DiffusionPipeline { void check_inputs(const ImageGenerationConfig& generation_config, ov::Tensor initial_image) const override { check_image_size(generation_config.width, generation_config.height); - const bool is_classifier_free_guidance = do_classifier_free_guidance(generation_config.guidance_scale); + const bool is_classifier_free_guidance = m_unet->do_classifier_free_guidance(generation_config.guidance_scale); const char * const pipeline_name = "Stable Diffusion XL"; OPENVINO_ASSERT(generation_config.prompt_3 == std::nullopt, "Prompt 3 is not used by ", pipeline_name); diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index 26221fd5c7..62a72b1cbd 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -195,7 +195,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { // If eos_token_id was not provided, take value from default m_generation_config if (config.eos_token_id == -1) - config.eos_token_id = m_generation_config.eos_token_id; + config.set_eos_token_id(m_generation_config.eos_token_id); config.validate(); // Stateful pipeline does not provide logprobs for prompt tokens @@ -269,18 +269,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { SequenceGroup::Ptr sequence_group; if (is_chat_conversation && !m_is_cache_empty) { sequence_group = std::make_shared(request_id, m_tokenized_chat_history.input_ids, config, block_size, enable_prefix_caching); - sequence_group->update_processed_tokens_num(m_tokenized_chat_history.input_ids.get_shape().at(1) - 1); } else { size_t seq_len = input_ids.get_shape().at(1); size_t batch_offset = request_id * seq_len; const int64_t* prompt_start = input_ids.data() + batch_offset; std::vector tokenized_prompt(prompt_start, prompt_start + seq_len); - // in case of multi batch scenario, remove eos_token_id at start of prompt - auto real_prompt_start = std::find_if(tokenized_prompt.begin(), tokenized_prompt.end(), [&config](int64_t token) { return token != config.eos_token_id; }); - tokenized_prompt.erase(tokenized_prompt.begin(), real_prompt_start); sequence_group = std::make_shared(request_id, tokenized_prompt, config, block_size, enable_prefix_caching); - sequence_group->update_processed_tokens_num(tokenized_prompt.size() - 1); } sequence_group->set_sequence_group_ptr(sequence_group); @@ -433,8 +428,9 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase { 
tokenizer, scheduler_config, device, - plugin_config - } {} + plugin_config} { + m_generation_config = m_impl.get_config(); + } ContinuousBatchingAdapter( const std::filesystem::path& models_path, @@ -446,8 +442,9 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase { m_tokenizer, scheduler_config, device, - plugin_config - } {} + plugin_config} { + m_generation_config = m_impl.get_config(); + } DecodedResults generate( StringInputs inputs, @@ -622,7 +619,7 @@ void ov::genai::LLMPipeline::set_generation_config(const GenerationConfig& confi m_pimpl->m_generation_config = config; // if eos_token_id was not provided in config forward from default config if (config.eos_token_id == -1) - m_pimpl->m_generation_config.eos_token_id = default_eos_token_id; + m_pimpl->m_generation_config.set_eos_token_id(default_eos_token_id); m_pimpl->m_generation_config.validate(); } diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp index 8989be3006..40089384a8 100644 --- a/src/cpp/src/llm_pipeline_static.cpp +++ b/src/cpp/src/llm_pipeline_static.cpp @@ -633,6 +633,11 @@ StaticLLMPipeline::StaticLLMPipeline( } // Initialize tensors prepare_for_new_conversation(); + + // If eos_token_id was not provided, take value + if (m_generation_config.eos_token_id == -1) { + m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id()); + } }; StaticLLMPipeline::StaticLLMPipeline( diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp index 644aa369c6..c76d9f7edf 100644 --- a/src/cpp/src/lm_encoding.cpp +++ b/src/cpp/src/lm_encoding.cpp @@ -106,9 +106,12 @@ std::pair get_lm_encoded_results( auto logits = m_llm.get_tensor("logits"); int64_t sequence_len = logits.get_shape().at(1); - for (auto& sequence_group : sequence_groups) + for (auto& sequence_group : sequence_groups) { + sequence_group->update_processed_tokens_num(sequence_group->get_prompt_len() - sequence_len); sequence_group->schedule_tokens(sequence_len); + } + std::map beam_offets; for (size_t i = 0; i < sequence_groups.size(); i++) beam_offets.insert({sequence_groups.at(i)->get_request_id(), i}); diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp index 4efb0e6c9a..3febadf112 100644 --- a/src/cpp/src/sampler.cpp +++ b/src/cpp/src/sampler.cpp @@ -566,20 +566,20 @@ std::vector Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen running_sequence->set_status(SequenceStatus::FINISHED); if (is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) { - running_sequence->set_finish_reason(GenerationFinishReason::STOP); - } else if (sampling_params.max_new_tokens == generated_len) { - running_sequence->set_finish_reason(GenerationFinishReason::LENGTH); + running_sequence->set_finish_reason(GenerationFinishReason::STOP); + } else if (sampling_params.max_new_tokens == generated_len) { + running_sequence->set_finish_reason(GenerationFinishReason::LENGTH); + } + + dropped_seq_ids.push_back(running_sequence->get_id()); + continue; } - - dropped_seq_ids.push_back(running_sequence->get_id()); - continue; - } - if (!sampling_params.stop_strings.empty()) { - int num_matched_last_tokens = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), sampling_params.stop_strings); - if (num_matched_last_tokens) { - if (!sampling_params.include_stop_str_in_output) - running_sequence->remove_last_tokens(num_matched_last_tokens); + if (!sampling_params.stop_strings.empty()) { + int num_matched_last_tokens = 
match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), sampling_params.stop_strings); + if (num_matched_last_tokens) { + if (!sampling_params.include_stop_str_in_output) + running_sequence->remove_last_tokens(num_matched_last_tokens); running_sequence->set_status(SequenceStatus::FINISHED); running_sequence->set_finish_reason(GenerationFinishReason::STOP); dropped_seq_ids.push_back(running_sequence->get_id()); diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp index 3ac020ccab..9e4840ecfd 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp @@ -8,14 +8,16 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::Contin ov::Core& core, const std::shared_ptr& model, const Tokenizer& tokenizer, + const GenerationConfig& generation_config, const DeviceConfig& device_config, const SchedulerConfig& scheduler_config, const std::string& device, const ov::AnyMap& plugin_config, bool is_validation_mode_enabled) { m_tokenizer = tokenizer; + m_generation_config = generation_config; m_is_validation_mode_enabled = is_validation_mode_enabled; - init(model, scheduler_config, plugin_config, device_config, core); + init(model, scheduler_config, plugin_config, device_config, core); } void diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp index 0040708b4b..682448ed16 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp @@ -16,6 +16,7 @@ class ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl : ContinuousBatchingForSpeculativeDecodingImpl(ov::Core& core, const std::shared_ptr& model, const Tokenizer& tokenizer, + const GenerationConfig& generation_config, const DeviceConfig& device_config, const SchedulerConfig& scheduler_config, const std::string& device, diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index d6d485237c..0f43555a5f 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -85,8 +85,12 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl( m_tokenizer = main_model_tokenizer; // to create `main_pipeline` with enabled validation_mode and `draft_pipeline` with disabled validation mode - m_main_pipeline = std::make_shared(core, main_model, main_model_tokenizer, main_device_config, main_scheduler_config, main_device, compile_properties, true); - m_draft_pipeline = std::make_shared(core, draft_model, draft_model_tokenizer, draft_device_config, draft_scheduler_config, draft_device, draft_properties, false); + m_main_pipeline = std::make_shared(core, + main_model, main_model_tokenizer, utils::from_config_json_if_exists(main_models_path), + main_device_config, main_scheduler_config, main_device, compile_properties, true); + m_draft_pipeline = std::make_shared(core, + draft_model, draft_model_tokenizer, utils::from_config_json_if_exists(draft_models_path), + draft_device_config, 
draft_scheduler_config, draft_device, draft_properties, false); } GenerationHandle @@ -182,7 +186,7 @@ void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() { std::vector ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector& input_ids, const std::vector& sampling_params, - const StreamerVariant& streamer) { + const StreamerVariant& streamer) { ManualTimer generate_timer("speculative_decoding: generate()"); generate_timer.start(); OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request"); @@ -199,6 +203,9 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< } }, streamer); + OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()), + "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); + std::vector main_generations; for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); diff --git a/src/cpp/src/tokenizer.cpp b/src/cpp/src/tokenizer.cpp index 9851d6c0cb..f52417a94e 100644 --- a/src/cpp/src/tokenizer.cpp +++ b/src/cpp/src/tokenizer.cpp @@ -396,7 +396,7 @@ class Tokenizer::TokenizerImpl { return std::vector(res_data, res_data + res.get_shape()[0]); } - std::string patch_chat_template(std::string template_str) { + std::string patch_chat_template(std::string template_str) const { // Replace what jinja2cpp doesn't support std::pair replace_str_map[] = { {"'}", "' }"}, @@ -430,10 +430,8 @@ class Tokenizer::TokenizerImpl { if (!file.is_open()) return ""; - std::string res = ""; + std::string res; ov::genai::utils::read_json_param(nlohmann::json::parse(file), "chat_template", res); - if (res.empty()) - return res; return patch_chat_template(res); } @@ -441,7 +439,7 @@ class Tokenizer::TokenizerImpl { std::string apply_chat_template(ChatHistory history, bool add_generation_prompt, const std::string& chat_template) const { - auto chat_tpl = chat_template.empty() ? m_chat_template : chat_template; + std::string chat_tpl = chat_template.empty() ? m_chat_template : patch_chat_template(chat_template); OPENVINO_ASSERT(!chat_tpl.empty(), "Chat template wasn't found. This may indicate that the model wasn't trained for chat scenario." " Please add 'chat_template' to tokenizer_config.json to use the model in chat scenario." 
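On the image-generation changes above: `get_guidance_scale_embedding` now lives in `diffusion_pipeline.hpp` so both the SD and SDXL pipelines can build the LCM `timestep_cond` input, and its callers pass `guidance_scale - 1.0f`, which follows the `w = guidance_scale - 1` convention of the diffusers LCM pipeline. A small NumPy mirror of the helper, assuming an f32 output of shape `(1, embedding_dim)` as in the C++ code:

```python
import numpy as np

def get_guidance_scale_embedding(guidance_scale: float, embedding_dim: int) -> np.ndarray:
    """Sinusoidal embedding of w = guidance_scale * 1000: first half sin, second half cos."""
    w = guidance_scale * 1000.0
    half_dim = embedding_dim // 2
    emb = np.log(10000.0) / (half_dim - 1)
    freqs = np.exp(np.arange(half_dim) * -emb) * w
    out = np.zeros((1, embedding_dim), dtype=np.float32)
    out[0, :half_dim] = np.sin(freqs)
    out[0, half_dim:2 * half_dim] = np.cos(freqs)
    # An odd embedding_dim leaves the last element at zero, as in the C++ helper.
    return out

# LCM callers pass guidance_scale - 1.0, e.g. guidance_scale=8.0 with time_cond_proj_dim=256:
timestep_cond = get_guidance_scale_embedding(8.0 - 1.0, 256)
assert timestep_cond.shape == (1, 256)
```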
diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp index 559692866c..28077f3ece 100644 --- a/src/cpp/src/visual_language/pipeline.cpp +++ b/src/cpp/src/visual_language/pipeline.cpp @@ -57,7 +57,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { const ov::AnyMap& properties ) : m_vlm_config{ - utils::from_config_json_if_exists( + utils::from_config_json_if_exists( models_dir, "config.json" ) }, @@ -73,6 +73,11 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { ).create_infer_request(); m_language.get_tensor("attention_mask").set_shape({1, 0}); + + // If eos_token_id was not provided, take value + if (m_generation_config.eos_token_id == -1) { + m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id()); + } } DecodedResults generate( @@ -81,10 +86,10 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { GenerationConfig generation_config, const StreamerVariant& streamer ) { - // If eos_token_id was not provided, take value - if (generation_config.eos_token_id == -1) { - generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id()); - } + // If eos_token_id was not provided, take value from default m_generation_config + if (generation_config.eos_token_id == -1) + generation_config.set_eos_token_id(m_generation_config.eos_token_id); + generation_config.validate(); ov::Tensor inputs_embeds = m_inputs_embedder->get_inputs_embeds(prompt, rgbs); @@ -100,7 +105,6 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { std::fill_n(prompt_ids.data(), prompt_ids.get_size(), 0); SequenceGroup::Ptr sequence_group = std::make_shared(request_id, prompt_ids, generation_config, block_size, enable_prefix_caching); - sequence_group->update_processed_tokens_num(history_size); sequence_group->set_sequence_group_ptr(sequence_group); requests.push_back(sequence_group); diff --git a/src/cpp/src/visual_language/vision_encoder.cpp b/src/cpp/src/visual_language/vision_encoder.cpp index 5c10f6df2a..dde2b89291 100644 --- a/src/cpp/src/visual_language/vision_encoder.cpp +++ b/src/cpp/src/visual_language/vision_encoder.cpp @@ -613,7 +613,7 @@ ov::Tensor get_pixel_values_internvl(const ov::Tensor& image, const ProcessorCon VisionEncoder::VisionEncoder(const std::filesystem::path& model_dir, const VLMModelType model_type, const std::string& device, const ov::AnyMap device_config, ov::Core core) : model_type(model_type) { m_vision_encoder = core.compile_model(model_dir / "openvino_vision_embeddings_model.xml", device, device_config).create_infer_request(); - m_processor_config = ov::genai::utils::from_config_json_if_exists( + m_processor_config = utils::from_config_json_if_exists( model_dir, "preprocessor_config.json" ); } diff --git a/src/cpp/src/whisper_pipeline.cpp b/src/cpp/src/whisper_pipeline.cpp index a8e34b9952..c0e486018a 100644 --- a/src/cpp/src/whisper_pipeline.cpp +++ b/src/cpp/src/whisper_pipeline.cpp @@ -158,7 +158,7 @@ void ov::genai::WhisperPipeline::set_generation_config(const WhisperGenerationCo m_impl->m_generation_config = config; // if eos_token_id was not provided in config forward from default config if (config.eos_token_id == -1) - m_impl->m_generation_config.eos_token_id = default_eos_token_id; + m_impl->m_generation_config.set_eos_token_id(default_eos_token_id); m_impl->m_generation_config.validate(); } diff --git a/src/cpp/src/whisper_pipeline_base.hpp b/src/cpp/src/whisper_pipeline_base.hpp index 46a3a13e65..03e6be6e29 100644 --- a/src/cpp/src/whisper_pipeline_base.hpp +++ b/src/cpp/src/whisper_pipeline_base.hpp @@ -7,16 +7,8 @@ #include 
"whisper/whisper_config.hpp" #include "whisper/whisper_feature_extractor.hpp" -namespace { -ov::genai::WhisperGenerationConfig from_config_json_if_exists(const std::filesystem::path& model_path) { - auto config_file_path = model_path / "generation_config.json"; - if (std::filesystem::exists(config_file_path)) { - return ov::genai::WhisperGenerationConfig((config_file_path).string()); - } else { - return ov::genai::WhisperGenerationConfig{}; - } -} -} // namespace +#include "utils.hpp" + namespace ov { namespace genai { @@ -31,10 +23,10 @@ class WhisperPipeline::WhisperPipelineImplBase { float m_load_time_ms = 0; WhisperPipelineImplBase(const std::filesystem::path& models_path) - : m_generation_config(from_config_json_if_exists(models_path)), + : m_generation_config(utils::from_config_json_if_exists(models_path)), m_tokenizer{models_path}, - m_feature_extractor{(models_path / "preprocessor_config.json").string()}, - m_model_config{(models_path / "config.json").string()} {} + m_feature_extractor{models_path / "preprocessor_config.json"}, + m_model_config{models_path / "config.json"} {} virtual WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input, OptionalWhisperGenerationConfig generation_config, diff --git a/tests/python_tests/data/short_prompts.txt b/tests/python_tests/data/short_prompts.txt index d919f62474..ac2bc0009f 100644 --- a/tests/python_tests/data/short_prompts.txt +++ b/tests/python_tests/data/short_prompts.txt @@ -14,17 +14,3 @@ Bananas are berries, while strawberries are not. The Pacific Ocean is the largest ocean on Earth. Sound travels faster in water than in air. The Eiffel Tower can be 15 cm taller during the summer. -Cheetahs are the fastest land animals, reaching speeds up to 75 mph. -The longest river in the world is the Nile River. -Penguins are flightless birds that live in the Southern Hemisphere. -Mars has the largest volcano in the solar system, Olympus Mons. -Diamonds are made of carbon atoms arranged in a crystal structure. -A day on Venus is longer than a year on Venus. -The heart beats about 100,000 times a day. -Octopuses have three hearts and blue blood. -Avocados are toxic to some animals, including dogs. -The mitochondria are known as the powerhouse of the cell. -An octet rule states that atoms are most stable when they have eight electrons in their outer shell. -The Sahara Desert is the largest hot desert in the world. -Lightning is hotter than the surface of the sun. -Honeybees communicate through a dance known as the waggle dance. 
diff --git a/tests/python_tests/test_cache_optimizations.py b/tests/python_tests/test_cache_optimizations.py index 49cb04ca1f..a34e604382 100644 --- a/tests/python_tests/test_cache_optimizations.py +++ b/tests/python_tests/test_cache_optimizations.py @@ -68,36 +68,36 @@ class CacheOptTestStruct: max_cache_usage_optimization_ratio: float -SHORT_CACHE_EVICTION_CONFIG = CacheEvictionConfig(start_size=32, recent_size=32, max_cache_size=128, aggregation_mode=AggregationMode.NORM_SUM) +SHORT_CACHE_EVICTION_CONFIG = CacheEvictionConfig(start_size=32, recent_size=32, max_cache_size=96, aggregation_mode=AggregationMode.NORM_SUM) @pytest.mark.precommit @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="doesn't work on win due to optimum-intel export bug, segfault on mac") @pytest.mark.parametrize("test_struct", [ # prompts + generation length are longer than the eviction arena, eviction expected w/ impact to similarity - CacheOptTestStruct(prompt_file="long_prompts.txt", max_new_tokens=128, num_kv_blocks=100, use_cache_eviction=True, + CacheOptTestStruct(prompt_file="long_prompts.txt", max_new_tokens=128, num_kv_blocks=1000, use_cache_eviction=True, cache_eviction_config=SHORT_CACHE_EVICTION_CONFIG, similarity_threshold=0.8, - max_cache_usage_optimization_ratio=1.8, - avg_cache_usage_optimization_ratio=1.35), + max_cache_usage_optimization_ratio=2.0, + avg_cache_usage_optimization_ratio=1.7), # prompts + generation length are shorter than the eviction arena, no eviction expected - CacheOptTestStruct(prompt_file="short_prompts.txt", max_new_tokens=32, num_kv_blocks=100, use_cache_eviction=True, + CacheOptTestStruct(prompt_file="short_prompts.txt", max_new_tokens=32, num_kv_blocks=1000, use_cache_eviction=True, cache_eviction_config=SHORT_CACHE_EVICTION_CONFIG, similarity_threshold=0.98, max_cache_usage_optimization_ratio=0.95, # no improvement expected avg_cache_usage_optimization_ratio=0.95), # short prompts, long generation - eviction expected - CacheOptTestStruct(prompt_file="short_prompts.txt", max_new_tokens=384, num_kv_blocks=100, use_cache_eviction=True, + CacheOptTestStruct(prompt_file="short_prompts.txt", max_new_tokens=160, num_kv_blocks=1000, use_cache_eviction=True, cache_eviction_config=SHORT_CACHE_EVICTION_CONFIG, similarity_threshold=0.94, - max_cache_usage_optimization_ratio=1.75, - avg_cache_usage_optimization_ratio=1.35), + max_cache_usage_optimization_ratio=1.4, + avg_cache_usage_optimization_ratio=1.1), ]) @pytest.mark.parametrize("enable_prefix_caching", [True, False]) # prefix caching shouldn't impact similarity def test_cache_optimized_generation_is_similar_to_unoptimized(converted_model, test_struct, enable_prefix_caching): - seqs_per_request = 5 + seqs_per_request = 32 scheduler_config = get_scheduler_config(test_struct.num_kv_blocks) generation_config = GenerationConfig() # expecting default greedy sampling diff --git a/tests/python_tests/test_generate_api.py b/tests/python_tests/test_generate_api.py index 2f80857359..ba934e3bda 100644 --- a/tests/python_tests/test_generate_api.py +++ b/tests/python_tests/test_generate_api.py @@ -85,9 +85,10 @@ def run_hf_ov_genai_comparison(model_descr, generation_config: Dict, prompt: str generation_config_hf['early_stopping'] = STOP_CRITERIA_MAP[generation_config_hf.pop('stop_criteria')] generation_config_hf.pop('ignore_eos', None) - encoded_prompt = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=True) - hf_encoded_output = model.generate(encoded_prompt, **generation_config_hf) - hf_output = 
tokenizer.decode(hf_encoded_output[0, encoded_prompt.shape[1]:], skip_special_tokens=True) + encoded_prompt = tokenizer([prompt], return_tensors='pt', add_special_tokens=True) + prompt_ids, attention_mask = encoded_prompt['input_ids'], encoded_prompt['attention_mask'] + hf_encoded_output = model.generate(prompt_ids, attention_mask=attention_mask, **generation_config_hf) + hf_output = tokenizer.decode(hf_encoded_output[0, prompt_ids.shape[1]:], skip_special_tokens=True) ov_output = pipe.generate(prompt, **config) if config.get('num_return_sequences', 1) > 1: @@ -179,12 +180,6 @@ def test_ov_tensors(model_descr, inputs): @pytest.mark.parametrize("prompt", prompts) @pytest.mark.precommit @pytest.mark.nightly -@pytest.mark.xfail( - raises=TypeError, - reason="pybind was unable to find ov::Tensor from openvino yet", - strict=False, - condition=sys.platform in ["linux", "win32"] -) def test_genai_tokenizer_encode(model_descr, prompt): model_id, path, tokenizer, model, pipe = read_model(model_descr) tok = pipe.get_tokenizer() diff --git a/tools/llm_bench/benchmark.py b/tools/llm_bench/benchmark.py index 2176373370..d652c8b48f 100644 --- a/tools/llm_bench/benchmark.py +++ b/tools/llm_bench/benchmark.py @@ -140,6 +140,14 @@ def get_argprser(): parser.add_argument('--lora_alphas', nargs='*', help='Alphas params for LoRA adapters.', required=False, default=[]) parser.add_argument("--use_cb", action="store_true", help="Use Continuous Batching inference mode") parser.add_argument("--cb_config", required=False, default=None, help="Path to file with Continuous Batching Scheduler settings or dict") + parser.add_argument("--draft_model", required=False, default=None, + help="Path to draft model folder including IR files for Speculative decoding generation") + parser.add_argument("--draft_device", required=False, default=None, help="Inference device for Speculative decoding of draft model") + parser.add_argument("--draft_cb_config", required=False, default=None, + help="Path to file with Continuous Batching Scheduler settings or dict for Speculative decoding of draft model") + parser.add_argument("--num_assistant_tokens", required=False, default=None, help="Config option num_assistant_tokens for Speculative decoding") + parser.add_argument("--assistant_confidence_threshold", required=False, default=None, + help="Config option assistant_confidence_threshold for Speculative decoding") parser.add_argument( '--end_token_stopping', action='store_true', diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py index 64313e13ab..6539bef232 100644 --- a/tools/llm_bench/llm_bench_utils/model_utils.py +++ b/tools/llm_bench/llm_bench_utils/model_utils.py @@ -135,12 +135,20 @@ def analyze_args(args): model_args['model_type'] = get_model_type(model_name, use_case, model_framework) model_args['model_name'] = model_name - if args.use_cb and not args.genai: + if (args.use_cb or args.draft_model) and not args.genai: raise RuntimeError("Continuous batching mode supported only via OpenVINO GenAI") cb_config = None if args.cb_config: cb_config = get_config(args.cb_config) model_args["cb_config"] = cb_config + model_args['draft_model'] = args.draft_model + model_args['draft_device'] = args.draft_device + draft_cb_config = None + if args.draft_cb_config: + draft_cb_config = get_config(args.draft_cb_config) + model_args["draft_cb_config"] = draft_cb_config + model_args['num_assistant_tokens'] = args.num_assistant_tokens + model_args['assistant_confidence_threshold'] = 
args.assistant_confidence_threshold return model_path, model_framework, model_args, model_name diff --git a/tools/llm_bench/llm_bench_utils/ov_utils.py b/tools/llm_bench/llm_bench_utils/ov_utils.py index f5d4452e30..023d4864a8 100644 --- a/tools/llm_bench/llm_bench_utils/ov_utils.py +++ b/tools/llm_bench/llm_bench_utils/ov_utils.py @@ -204,6 +204,21 @@ def create_text_gen_model(model_path, device, **kwargs): return ov_model, tokenizer, from_pretrained_time, bench_hook, False +def get_scheduler_config_genai(user_config, config_name="CB config"): + import openvino_genai + + default_cb_config = {"cache_size": 1} + scheduler_config = openvino_genai.SchedulerConfig() + scheduler_params = user_config or default_cb_config + if scheduler_params: + log.info(f"Scheduler parameters for {config_name}:\n{scheduler_params}") + + for param, value in scheduler_params.items(): + setattr(scheduler_config, param, value) + + return scheduler_config + + def create_genai_text_gen_model(model_path, device, ov_config, **kwargs): import openvino_tokenizers # noqa: F401 import openvino_genai @@ -214,18 +229,20 @@ def create_genai_text_gen_model(model_path, device, ov_config, **kwargs): tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + draft_model_path = kwargs.get("draft_model", '') cb = kwargs.get("use_cb", False) - if cb: + if cb or draft_model_path: log.info("Continuous Batching mode activated") - default_cb_config = {"cache_size": 1} - scheduler_config = openvino_genai.SchedulerConfig() - scheduler_params = kwargs.get("cb_config") or default_cb_config - if scheduler_params: - log.info(f"Scheduler parameters:\n{scheduler_params}") + ov_config["scheduler_config"] = get_scheduler_config_genai(kwargs.get("cb_config")) - for param, value in scheduler_params.items(): - setattr(scheduler_config, param, value) - ov_config["scheduler_config"] = scheduler_config + if draft_model_path: + if not Path(draft_model_path).exists(): + raise RuntimeError(f'==Failure ==: draft model by path:{draft_model_path} is not exists') + log.info("Speculative Decoding is activated") + draft_device = kwargs.get('draft_device', None) or device + draft_model_load_kwargs = {'scheduler_config': get_scheduler_config_genai(kwargs.get("draft_cb_config"), "draft CB config")}\ + if kwargs.get("draft_cb_config") is not None else {} + ov_config['draft_model'] = openvino_genai.draft_model(draft_model_path, draft_device.upper(), **draft_model_load_kwargs) adapter_config = get_lora_config(kwargs.get("lora", None), kwargs.get("lora_alphas", [])) if adapter_config: @@ -263,7 +280,7 @@ def get_tokens(self): def get_time_list(self): return self.token_generation_time - streamer = TokenStreamer(llm_pipe.get_tokenizer()) if cb else None + streamer = TokenStreamer(llm_pipe.get_tokenizer()) if cb or draft_model_path else None return llm_pipe, tokenizer, end - start, streamer, True diff --git a/tools/llm_bench/task/text_generation.py b/tools/llm_bench/task/text_generation.py index d936721344..029bcdf16d 100644 --- a/tools/llm_bench/task/text_generation.py +++ b/tools/llm_bench/task/text_generation.py @@ -194,12 +194,24 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data if args['infer_count'] is not None: out_str += 'all max_output_token_size: {} * {}'.format(args['infer_count'], args['batch_size']) log.info(out_str) + gen_config = model.get_generation_config() + gen_config.max_new_tokens = max_gen_tokens + gen_config.num_beams = args["num_beams"] + gen_config.do_sample = False + if 
args.get('draft_model', ''): + config_info = "Speculative decoding config: " + if args.get('num_assistant_tokens', None): + gen_config.num_assistant_tokens = args['num_assistant_tokens'] + config_info += f" num_assistant_tokens {gen_config.num_assistant_tokens}" + if args.get('assistant_confidence_threshold', None): + gen_config.assistant_confidence_threshold = args['assistant_confidence_threshold'] + config_info += f" assistant_confidence_threshold {gen_config.assistant_confidence_threshold}" + log.info(config_info) start = time.perf_counter() - generation_result = model.generate(input_text_list, max_new_tokens=max_gen_tokens, num_beams=args["num_beams"], do_sample=False) + generation_result = model.generate(input_text_list, gen_config) end = time.perf_counter() generated_text = generation_result.texts perf_metrics = generation_result.perf_metrics - if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2: mem_consumption.end_collect_momory_consumption() max_rss_mem_consumption, max_shared_mem_consumption, max_uss_mem_consumption = mem_consumption.get_max_memory_consumption() @@ -314,8 +326,21 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg mem_consumption.start_collect_memory_consumption() max_gen_tokens = DEFAULT_OUTPUT_TOKEN_SIZE if args['infer_count'] is None else args['infer_count'] streamer.reset() + gen_config = model.get_generation_config() + gen_config.max_new_tokens = max_gen_tokens + gen_config.num_beams = args["num_beams"] + gen_config.do_sample = False + if args.get('draft_model', ''): + config_info = "Speculative decoding config: " + if args.get("num_assistant_tokens", None): + gen_config.num_assistant_tokens = int(args["num_assistant_tokens"]) + config_info += f'num_assistant_tokens {args["num_assistant_tokens"]}' + if args.get("assistant_confidence_threshold", None): + gen_config.assistant_confidence_threshold = float(args["assistant_confidence_threshold"]) + config_info += f'assistant_confidence_threshold {args["assistant_confidence_threshold"]}' + log.info(config_info) start = time.perf_counter() - generated_tokens = model.generate(input_data, max_new_tokens=max_gen_tokens, num_beams=args["num_beams"], streamer=streamer, do_sample=False).tokens + generated_tokens = model.generate(input_data, gen_config, streamer=streamer).tokens end = time.perf_counter() if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2: mem_consumption.end_collect_momory_consumption()
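Both llm_bench text-generation paths above now build an explicit `GenerationConfig` instead of passing keyword arguments to `generate()`. A condensed sketch of that setup as a standalone helper; argparse delivers the new speculative options as strings, so they are cast to the numeric types the config expects, as the streaming path does:

```python
import openvino_genai as ov_genai

def build_gen_config(pipe: ov_genai.LLMPipeline, args: dict, max_gen_tokens: int) -> ov_genai.GenerationConfig:
    """Roughly what run_text_generation_genai* now do before calling generate()."""
    gen_config = pipe.get_generation_config()
    gen_config.max_new_tokens = max_gen_tokens
    gen_config.num_beams = args["num_beams"]
    gen_config.do_sample = False
    if args.get("draft_model"):
        if args.get("num_assistant_tokens"):
            gen_config.num_assistant_tokens = int(args["num_assistant_tokens"])
        if args.get("assistant_confidence_threshold"):
            gen_config.assistant_confidence_threshold = float(args["assistant_confidence_threshold"])
    return gen_config

# Usage inside the benchmark loop (pipe and streamer come from create_genai_text_gen_model):
#   result = pipe.generate(input_text_list, build_gen_config(pipe, args, max_gen_tokens), streamer=streamer)
```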