Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support chat conversation for StaticLLMPipeline #580

Merged
2 changes: 1 addition & 1 deletion samples/cpp/chat_sample/chat_sample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ int main(int argc, char* argv[]) try {
std::string prompt;
std::string model_path = argv[1];

std::string device = "CPU"; // GPU can be used as well
std::string device = "CPU"; // GPU, NPU can be used as well
ov::genai::LLMPipeline pipe(model_path, "CPU");

ov::genai::GenerationConfig config;
Expand Down
59 changes: 42 additions & 17 deletions src/cpp/src/llm_pipeline_static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,15 @@ void reshape_to_static(std::shared_ptr<ov::Model> model,
model->reshape(new_shapes);
}

void fill_tensor(ov::Tensor tensor, int64_t fill_val) {
void fill_tensor(ov::Tensor tensor, int64_t fill_val, size_t offset = 0u) {
int64_t* tensor_data = tensor.data<int64_t>();
std::fill(tensor_data, tensor_data + tensor.get_size(), fill_val);
std::fill(tensor_data + offset, tensor_data + tensor.get_size(), fill_val);
}

void copy_with_left_offset(const ov::Tensor& orig, ov::Tensor& padded) {
const auto orig_size = orig.get_size();
const auto padded_size = padded.get_size();
const auto kLeftOffset = padded_size - orig_size;
void copy_with_offset(const ov::Tensor& orig, const int32_t offset, ov::Tensor& padded) {
int64_t* orig_data = orig.data<int64_t>();
int64_t* padded_data = padded.data<int64_t>();
std::copy(orig_data, orig_data + orig_size, padded_data + kLeftOffset);
std::copy(orig_data, orig_data + orig.get_size(), padded_data + offset);
}

ov::AnyMap extract_config_or_default(const ov::AnyMap& config, const std::string& config_name) {
Expand All @@ -111,7 +108,7 @@ ov::AnyMap extract_config_or_default(const ov::AnyMap& config, const std::string
{ "NPUW_FOLD", "YES" },
{ "NPUW_DCOFF_TYPE", "f16" },
{ "NPUW_DCOFF_SCALE", "YES" },
{ "NPU_COMPILATION_MODE_PARAMS", "compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add_RMSNorm" },
{ "NPU_COMPILATION_MODE_PARAMS", "compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add" },
{ "NPUW_PARALLEL_COMPILE", "YES" },
{ "NPUW_FUNCALL_ASYNC", "YES" }
};
Expand Down Expand Up @@ -179,6 +176,18 @@ StaticLLMPipeline::StaticLLMPipeline(
) : StaticLLMPipeline(path, path.string(), device, config) {
}

void StaticLLMPipeline::start_chat(const std::string& system_message) {
Wovchena marked this conversation as resolved.
Show resolved Hide resolved
if (!system_message.empty()) {
m_history.push_back({{"role", "system"}, {"content", system_message}});
}
m_is_chat_conversation = true;
};

void StaticLLMPipeline::finish_chat() {
TolyaTalamanov marked this conversation as resolved.
Show resolved Hide resolved
m_is_chat_conversation = false;
m_history.clear();
};

void StaticLLMPipeline::prepare_for_new_conversation() {
fill_tensor(m_prefill_request.get_tensor("input_ids"), m_tokenizer.get_pad_token_id());
fill_tensor(m_prefill_request.get_tensor("position_ids"), 0u);
Expand All @@ -198,9 +207,23 @@ DecodedResults StaticLLMPipeline::generate(
}

OPENVINO_ASSERT(std::holds_alternative<std::string>(inputs));
auto tokenized_input = m_tokenizer.encode(std::get<std::string>(inputs));
auto& prompt = std::get<std::string>(inputs);

if (m_is_chat_conversation) {
m_history.push_back({{"role", "user"}, {"content", prompt}});
constexpr bool add_generation_prompt = true;
prompt = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
}

auto tokenized_input = m_tokenizer.encode(prompt);
auto encoded_results = generate(tokenized_input, config, streamer);
return {m_tokenizer.decode(encoded_results.tokens), encoded_results.scores};
DecodedResults decoded_results = {m_tokenizer.decode(encoded_results.tokens), encoded_results.scores};
pavel-esir marked this conversation as resolved.
Show resolved Hide resolved

if (m_is_chat_conversation) {
auto answer = decoded_results.texts[0];
m_history.push_back({{"role", "assistant"}, {"content", answer}});
}
return decoded_results;
}

EncodedResults StaticLLMPipeline::generate(
Expand Down Expand Up @@ -245,22 +268,25 @@ EncodedResults StaticLLMPipeline::generate(
ov::genai::EncodedResults results;
// NB: Only batch=1 is supported now
results.scores.resize(1u);
results.scores[0] = 0u;
TolyaTalamanov marked this conversation as resolved.
Show resolved Hide resolved
results.tokens.resize(1u);

// NB: Check if input prompt less than maximum size
// NB: Check if there is enough space in KV-cache to process input prompt
auto prompt_len = input_ids.get_size();
if (prompt_len > m_kvcache_desc.total_size) {
OPENVINO_THROW("Currently static pipeline only process up to " + std::to_string(m_kvcache_desc.total_size) + " tokens");
}

// NB: Reset tensors on every generate call - chat conversation isn't supported yet!
// NB: From the "generate" perspective, every call is treated as start of new conversation,
// but if continuation is needed, prompt contains information about the entire conversation.
prepare_for_new_conversation();

auto padded_input_ids = m_prefill_request.get_tensor("input_ids");
copy_with_left_offset(input_ids, padded_input_ids);
const size_t offset = padded_input_ids.get_size() - input_ids.get_size();
copy_with_offset(input_ids, offset, padded_input_ids);

auto padded_attention_mask = m_prefill_request.get_tensor("attention_mask");
copy_with_left_offset(attention_mask, padded_attention_mask);
fill_tensor(padded_attention_mask, 1u, offset);

auto padded_position_ids = m_prefill_request.get_tensor("position_ids");
auto* padded_pos_data = padded_position_ids.data<int64_t>();
Expand All @@ -271,13 +297,13 @@ EncodedResults StaticLLMPipeline::generate(
// NB: Now there are prompt_len tokens in KV-cache
m_kvcache_desc.num_stored_tokens += prompt_len;
int64_t last_token = utils::argmax(m_prefill_request.get_tensor("logits"), 0);
results.tokens[0].push_back(last_token);
TolyaTalamanov marked this conversation as resolved.
Show resolved Hide resolved
if (streamer_ptr && streamer_ptr->put(last_token)) {
return results;
}

padded_attention_mask.copy_to(m_kvcache_request.get_tensor("attention_mask"));


// Inputs: input_ids, attention_mask, position_ids, ...
// Outputs: logits, ...
const auto kStartInputKVCacheLayers = 3u;
Expand Down Expand Up @@ -309,13 +335,12 @@ EncodedResults StaticLLMPipeline::generate(

last_token = utils::argmax(m_kvcache_request.get_tensor("logits"), 0);
results.tokens[0].push_back(last_token);
results.scores[0] = 0u;
TolyaTalamanov marked this conversation as resolved.
Show resolved Hide resolved

if (streamer_ptr && streamer_ptr->put(last_token)) {
break;
}

if (last_token == m_generation_config.eos_token_id) {
if (last_token == config.eos_token_id && !config.ignore_eos) {
break;
}

Expand Down
12 changes: 5 additions & 7 deletions src/cpp/src/llm_pipeline_static.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,8 @@ class StaticLLMPipeline final : public LLMPipelineImplBase {
StreamerVariant streamer
) override;

void start_chat(const std::string& system_message) override {
OPENVINO_THROW("Currently chat conversation mode isn't supported");
};
void finish_chat() override {
OPENVINO_THROW("Currently chat conversation mode isn't supported");
};

void start_chat(const std::string& system_message) override;
void finish_chat() override;
private:
void prepare_for_new_conversation();

Expand All @@ -54,6 +49,9 @@ class StaticLLMPipeline final : public LLMPipelineImplBase {
KVCacheDesc m_kvcache_desc;
ov::InferRequest m_kvcache_request;
ov::InferRequest m_prefill_request;

bool m_is_chat_conversation = false;
ChatHistory m_history;
};

} // namespace genai
Expand Down
Loading