From cd3d2c2233d98b0783de945da2ffe0cbe778fbe1 Mon Sep 17 00:00:00 2001 From: TolyaTalamanov Date: Fri, 26 Jul 2024 14:56:43 +0100 Subject: [PATCH] LLMStaticPipeline hotfix --- src/cpp/src/llm_pipeline_static.cpp | 18 +++++++++--------- src/cpp/src/llm_pipeline_static.hpp | 4 ++++ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp index 3f50d30ec9..351e10b523 100644 --- a/src/cpp/src/llm_pipeline_static.cpp +++ b/src/cpp/src/llm_pipeline_static.cpp @@ -144,26 +144,26 @@ StaticLLMPipeline::StaticLLMPipeline( */ ov::Core core; // (1) Read the template model - this will be kvcache model - auto kvcache_model = core.read_model(path / "openvino_model.xml"); + m_kvcache_model = core.read_model(path / "openvino_model.xml"); // (2) Expose KV-cache input and output layers from kvcache model - ov::pass::StatefulToStateless().run_on_model(kvcache_model); + ov::pass::StatefulToStateless().run_on_model(m_kvcache_model); // (3) Clone the model - this will be prefill - auto prefill_model = kvcache_model->clone(); - prefill_model->set_friendly_name(kvcache_model->get_friendly_name() + "_prefill"); + m_prefill_model = m_kvcache_model->clone(); + m_prefill_model->set_friendly_name(m_kvcache_model->get_friendly_name() + "_prefill"); // (4) Reshape both models to static shape m_kvcache_desc = KVCacheDesc { 1024u, 0u }; const uint32_t max_prompt_size = m_kvcache_desc.total_size; const uint32_t max_kvcache_size = m_kvcache_desc.total_size; - reshape_to_static(prefill_model, max_prompt_size, max_kvcache_size); - reshape_to_static(kvcache_model, 1u, max_kvcache_size); + reshape_to_static(m_prefill_model, max_prompt_size, max_kvcache_size); + reshape_to_static(m_kvcache_model, 1u, max_kvcache_size); // (5) Add slices to kvcache model - kvcache_model = add_slices_to_kvcache_inputs(kvcache_model); + m_kvcache_model = add_slices_to_kvcache_inputs(m_kvcache_model); // (6) Compile both model m_prefill_request = core.compile_model( - prefill_model, device, extract_config_or_default(config, "PREFILL_CONFIG") + m_prefill_model, device, extract_config_or_default(config, "PREFILL_CONFIG") ).create_infer_request(); m_kvcache_request = core.compile_model( - kvcache_model, device, extract_config_or_default(config, "GENERATE_CONFIG") + m_kvcache_model, device, extract_config_or_default(config, "GENERATE_CONFIG") ).create_infer_request(); // (7) Initialize tensors prepare_for_new_conversation(); diff --git a/src/cpp/src/llm_pipeline_static.hpp b/src/cpp/src/llm_pipeline_static.hpp index 85488e1880..7560b7e336 100644 --- a/src/cpp/src/llm_pipeline_static.hpp +++ b/src/cpp/src/llm_pipeline_static.hpp @@ -46,6 +46,10 @@ class StaticLLMPipeline final : public LLMPipelineImplBase { uint32_t num_stored_tokens; }; + // FIXME: Ideally, we don't need to keep those + std::shared_ptr m_kvcache_model; + std::shared_ptr m_prefill_model; + KVCacheDesc m_kvcache_desc; ov::InferRequest m_kvcache_request; ov::InferRequest m_prefill_request;