From e309c6a47fc3334a9aa4c86a57d29127b242ef85 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 25 Sep 2024 16:11:58 +0100 Subject: [PATCH] `tool-call`: integrate minja & tool-call to server when --jinja is set --- common/arg.cpp | 12 +- common/common.cpp | 26 +- common/common.h | 23 +- examples/server/server.cpp | 4 +- examples/server/tests/features/steps/steps.py | 43 ++- examples/server/utils.hpp | 146 +++++++-- include/llama.h | 15 +- src/CMakeLists.txt | 2 +- src/llama.cpp | 110 ++++++- tests/test-chat-template.cpp | 296 +++++++++++------- 10 files changed, 514 insertions(+), 163 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index c1ec3c4f99c37..f0d236fd38ad3 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1844,13 +1844,21 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, } } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(llama_arg( + {"--jinja"}, + "use jinja template for chat (default: disabled)", + [](gpt_params & params) { + params.use_jinja = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(llama_arg( {"--chat-template"}, "JINJA_TEMPLATE", "set custom jinja chat template (default: template taken from model's metadata)\n" "if suffix/prefix are specified, template will be disabled\n" - "only commonly used templates are accepted:\nhttps://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template", + "only commonly used templates are accepted (unless --jinja is set before this flag):\n" + "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template", [](gpt_params & params, const std::string & value) { - if (!llama_chat_verify_template(value)) { + if (!llama_chat_verify_template(value, params.use_jinja)) { throw std::runtime_error(format( "error: the supplied chat template is not supported: %s\n" "note: llama.cpp does not use jinja parser, we only support commonly used templates\n", diff --git a/common/common.cpp b/common/common.cpp index 8d0ed4f95a737..bcf49f186acc8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1510,16 +1510,20 @@ std::string llama_detokenize(llama_context * ctx, const std::vector // Chat template utils // -bool llama_chat_verify_template(const std::string & tmpl) { +bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) { llama_chat_message chat[] = {{"user", "test"}}; - int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); + int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0, use_jinja); return res >= 0; } std::string llama_chat_apply_template(const struct llama_model * model, const std::string & tmpl, const std::vector & msgs, - bool add_ass) { + bool add_ass, + bool use_jinja, + const std::string & tools, + const char * bos_token, + const char * eos_token) { int alloc_size = 0; bool fallback = false; // indicate if we must fallback to default chatml std::vector chat; @@ -1532,7 +1536,7 @@ std::string llama_chat_apply_template(const struct llama_model * model, std::vector buf(alloc_size); // run the first time to get the total output length - int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools.empty() ? 
nullptr : tools.data(), bos_token, eos_token); // error: chat template is not supported if (res < 0) { @@ -1542,7 +1546,7 @@ std::string llama_chat_apply_template(const struct llama_model * model, throw std::runtime_error("this custom template is not supported"); } else { // If the built-in template is not supported, we default to chatml - res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, bos_token, eos_token); fallback = true; } } @@ -1553,7 +1557,7 @@ std::string llama_chat_apply_template(const struct llama_model * model, res = llama_chat_apply_template( fallback ? nullptr : model, fallback ? "chatml" : ptr_tmpl, - chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, bos_token, eos_token); } std::string formatted_chat(buf.data(), res); @@ -1564,9 +1568,13 @@ std::string llama_chat_format_single(const struct llama_model * model, const std::string & tmpl, const std::vector & past_msg, const llama_chat_msg & new_msg, - bool add_ass) { + bool add_ass, + bool use_jinja, + const std::string & tools, + const char * bos_token, + const char * eos_token) { std::ostringstream ss; - auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false); + auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, bos_token, eos_token); std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { @@ -1574,7 +1582,7 @@ std::string llama_chat_format_single(const struct llama_model * model, }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass); + auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, bos_token, eos_token); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); diff --git a/common/common.h b/common/common.h index 1a5cfe7b1173b..a42c675cc5b86 100644 --- a/common/common.h +++ b/common/common.h @@ -285,6 +285,7 @@ struct gpt_params { std::string public_path = ""; // NOLINT std::string chat_template = ""; // NOLINT std::string system_prompt = ""; // NOLINT + bool use_jinja = false; // NOLINT bool enable_chat_template = true; std::vector api_keys; @@ -469,14 +470,20 @@ std::string llama_detokenize( // Chat template utils // -// same with llama_chat_message, but uses std::string +// same as llama_chat_message, but uses std::string and std::vector struct llama_chat_msg { std::string role; std::string content; + std::string tool; + struct llama_tool_call { + std::string name; + std::string arguments; + }; + std::vector tool_calls; }; // Check if the template supplied via "--chat-template" is supported or not. 
Returns true if it's valid -bool llama_chat_verify_template(const std::string & tmpl); +bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false); // CPP wrapper for llama_chat_apply_template // If the built-in template is not supported, we default to chatml @@ -484,14 +491,22 @@ bool llama_chat_verify_template(const std::string & tmpl); std::string llama_chat_apply_template(const struct llama_model * model, const std::string & tmpl, const std::vector & chat, - bool add_ass); + bool add_ass, + bool use_jinja = false, + const std::string & tools = "", + const char * bos_token = nullptr, + const char * eos_token = nullptr); // Format single message, while taking into account the position of that message in chat history std::string llama_chat_format_single(const struct llama_model * model, const std::string & tmpl, const std::vector & past_msg, const llama_chat_msg & new_msg, - bool add_ass); + bool add_ass, + bool use_jinja = false, + const std::string & tools = "", + const char * bos_token = nullptr, + const char * eos_token = nullptr); // Returns an example of formatted chat std::string llama_chat_format_example(const struct llama_model * model, diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 9ac064748ead0..71ffc97cfd6ff 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2781,6 +2781,8 @@ int main(int argc, char ** argv) { { "system_prompt", ctx_server.system_prompt.c_str() }, { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params.n_parallel }, + { "bos_token", llama_token_to_piece(ctx_server.ctx, llama_token_bos(ctx_server.model), true) }, + { "eos_token", llama_token_to_piece(ctx_server.ctx, llama_token_eos(ctx_server.model), true) }, { "chat_template", curr_tmpl.c_str() }, }; @@ -2854,7 +2856,7 @@ int main(int argc, char ** argv) { return; } - json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); + json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template, params.use_jinja); std::vector tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL); ctx_server.queue_results.add_waiting_tasks(tasks); diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 0fea0fe87b799..43241b26ca29f 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -75,6 +75,8 @@ def step_server_config(context, server_fqdn: str, server_port: str): context.server_seed = None context.user_api_key = None context.response_format = None + context.tools = None + context.tool_choice = None context.temperature = None context.lora_file = None context.disable_ctx_shift = False @@ -363,6 +365,13 @@ def step_max_tokens(context, max_tokens): def step_response_format(context, response_format): context.response_format = json.loads(response_format) +@step('tools {tools}') +def step_tools(context, tools): + context.tools = json.loads(tools) + +@step('tool choice {tool_choice}') +def step_tool_choice(context, tool_choice): + context.tool_choice = tool_choice @step('{temperature:f} temperature') def step_temperature(context, temperature): @@ -497,6 +506,11 @@ async def step_oai_chat_completions(context, api_error): response_format=context.response_format if hasattr(context, 'response_format') else None, + tools=context.tools + if hasattr(context, 'tools') else None, 
+ + tool_choice=context.tool_choice, + user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None, @@ -567,6 +581,9 @@ async def step_oai_chat_completions(context): if hasattr(context, 'enable_streaming') else None, response_format=context.response_format if hasattr(context, 'response_format') else None, + tools=context.tools + if hasattr(context, 'tools') else None, + tool_choice=context.tool_choice, user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None) @@ -580,16 +597,18 @@ async def step_oai_chat_completions(context): context.base_url, '/chat/completions', True, # async_client - model=context.model - if hasattr(context, 'model') else None, - n_predict=context.n_predict - if hasattr(context, 'n_predict') else None, + model=context.model, + # if hasattr(context, 'model') else None, + n_predict=context.n_predict, + # if hasattr(context, 'n_predict') else None, enable_streaming=context.enable_streaming if hasattr(context, 'enable_streaming') else None, - response_format=context.response_format - if hasattr(context, 'response_format') else None, - user_api_key=context.user_api_key - if hasattr(context, 'user_api_key') else None) + response_format=context.response_format, + # if hasattr(context, 'response_format') else None, + tools=context.tools,# if hasattr(context, 'tools') else None, + tool_choice=context.tool_choice, # if hasattr(context, 'tool_choice') else None, + user_api_key=context.user_api_key) + # if hasattr(context, 'user_api_key') else None) @step('all prompts are predicted') @@ -974,6 +993,8 @@ async def oai_chat_completions(user_prompt, n_predict=None, enable_streaming=None, response_format=None, + tools=None, + tool_choice=None, user_api_key=None, expect_api_error=None) -> int | dict[str, Any]: if debug: @@ -1001,6 +1022,10 @@ async def oai_chat_completions(user_prompt, } if response_format is not None: payload['response_format'] = response_format + if tools is not None: + payload['tools'] = tools + if tool_choice is not None: + payload['tool_choice'] = tool_choice completion_response = { 'content': '', 'timings': { @@ -1065,6 +1090,8 @@ async def oai_chat_completions(user_prompt, max_tokens=n_predict, stream=enable_streaming, response_format=payload.get('response_format') or openai.NOT_GIVEN, + tools=payload.get('tools'), + tool_choice=payload.get('tool_choice'), seed=seed, temperature=payload['temperature'] ) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 8cab665014f8c..a80a1b5dde155 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -15,6 +15,8 @@ // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" +#include "minja.hpp" +#include "tool-call.h" #include #include @@ -56,22 +58,23 @@ static T json_value(const json & body, const std::string & key, const T & defaul // // Format given chat. 
If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
+inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages, const json & tools, bool use_jinja) {
     std::vector<llama_chat_msg> chat;
 
     for (size_t i = 0; i < messages.size(); ++i) {
         const auto & curr_msg = messages[i];
 
-        std::string role = json_value(curr_msg, "role", std::string(""));
+        llama_chat_msg msg;
+        msg.role = json_value(curr_msg, "role", std::string(""));
+        msg.tool = json_value(curr_msg, "tool", std::string(""));
 
-        std::string content;
         if (curr_msg.contains("content")) {
             if (curr_msg["content"].is_string()) {
-                content = curr_msg["content"].get<std::string>();
+                msg.content = curr_msg["content"].get<std::string>();
             } else if (curr_msg["content"].is_array()) {
                 for (const auto & part : curr_msg["content"]) {
                     if (part.contains("text")) {
-                        content += "\n" + part["text"].get<std::string>();
+                        msg.content += "\n" + part["text"].get<std::string>();
                     }
                 }
             } else {
@@ -80,11 +83,21 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
         } else {
             throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
         }
-
-        chat.push_back({role, content});
+        if (curr_msg.contains("tool_calls") && curr_msg["tool_calls"].is_array()) {
+            for (const auto & tool_call : curr_msg["tool_calls"]) {
+                if (json_value(tool_call, "type", std::string("")) == "function"
+                    && tool_call.contains("function") && tool_call["function"].is_object()) {
+                    msg.tool_calls.push_back({
+                        json_value(tool_call["function"], "name", std::string("")),
+                        json_value(tool_call["function"], "arguments", std::string(""))
+                    });
+                }
+            }
+        }
+        chat.emplace_back(std::move(msg));
    }
 
-    const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
+    const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true, use_jinja, tools.is_null() ? 
"" : tools.dump()); LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); return formatted_chat; @@ -302,16 +315,56 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons // OAI utils // +static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) { + std::string piece; + piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' + const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); + GGML_ASSERT(check == -n_chars); + } + else { + piece.resize(n_chars); + } + + return piece; +} + +std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) { + int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0); + if (tlen > 0) { + std::vector curr_tmpl_buf(tlen + 1, 0); + if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) { + return std::string(curr_tmpl_buf.data(), tlen); + } + } + return ""; +} + static json oaicompat_completion_params_parse( const struct llama_model * model, const json & body, /* openai api json semantics */ - const std::string & chat_template) { + const std::string & chat_template_src, + bool use_jinja) { json llama_params; llama_params["__oaicompat"] = true; + auto tools = json_value(body, "tools", json()); + auto has_tools = tools.is_array() && !tools.empty(); + // Apply chat template to the list of messages - llama_params["prompt"] = format_chat(model, chat_template, body.at("messages")); + auto chat_template = chat_template_src.empty() ? llama_model_meta_val_str(model, "tokenizer.chat_template") : chat_template_src; + llama_params["chat_template"] = chat_template; + if (use_jinja) { + if (has_tools && chat_template.find("tools") == std::string::npos) { + throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template."); + } + } else if (has_tools) { + throw std::runtime_error("Tools are only supported in --jinja mode"); + } + llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"), tools, use_jinja); // Handle "stop" field if (body.contains("stop") && body.at("stop").is_string()) { @@ -320,20 +373,54 @@ static json oaicompat_completion_params_parse( llama_params["stop"] = json_value(body, "stop", json::array()); } - // Handle "response_format" field + // Handle "response_format" field (https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format) + auto tool_choice = json_value(body, "tool_choice", std::string("auto")); if (body.contains("response_format")) { json response_format = json_value(body, "response_format", json::object()); std::string response_type = json_value(response_format, "type", std::string()); if (response_type == "json_object") { + // Legacy llama.cpp, llama-cpp-python and Together.ai format. llama_params["json_schema"] = json_value(response_format, "schema", json::object()); } else if (response_type == "json_schema") { - json json_schema = json_value(response_format, "json_schema", json::object()); - llama_params["json_schema"] = json_value(json_schema, "schema", json::object()); + // OpenAI JSON schema format. 
+ auto json_schema = json_value(response_format, "json_schema", json::object()); + json schema = json_value(json_schema, "schema", json::object()); + std::string description = json_value(json_schema, "description", std::string()); + if (!description.empty()) { + if (schema.contains("description")) { + throw std::runtime_error("Cannot have both a description in the json_schema object and inside its schema."); + } + schema["description"] = description; + } + bool strict = json_value(json_schema, "strict", false); + if (strict) { + llama_params["json_schema"] = schema; + } } else if (!response_type.empty() && response_type != "text") { throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); } - } + } else if (use_jinja && tool_choice != "none" && has_tools) { + bool parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + bool allow_content = tool_choice != "required"; + auto handler = llama_tool_call_handler_init(chat_template, allow_content, parallel_tool_calls, tools); + + for (const auto & stop : handler.additional_stop_words) { + llama_params["stop"].push_back(stop); + } + if (!handler.grammar_trigger_words.empty()) { + auto triggers = json::array(); + for (const auto & word : handler.grammar_trigger_words) { + triggers.push_back(word); + } + llama_params["grammar_trigger_words"] = triggers; + } + + llama_params["grammar"] = handler.grammar; + llama_params["parse_tool_calls"] = true; + llama_params["parallel_tool_calls"] = parallel_tool_calls; + } + // Handle "n" field int n_choices = json_value(body, "n", 1); if (n_choices != 1) { @@ -349,10 +436,12 @@ static json oaicompat_completion_params_parse( } // Params supported by OAI but unsupported by llama.cpp - static const std::vector unsupported_params { "tools", "tool_choice" }; - for (const auto & param : unsupported_params) { - if (body.contains(param)) { - throw std::runtime_error("Unsupported param: " + param); + if (!use_jinja) { + static const std::vector unsupported_params { "tools", "tool_choice" }; + for (const auto & param : unsupported_params) { + if (body.contains(param)) { + throw std::runtime_error("Unsupported param: " + param); + } } } @@ -380,6 +469,24 @@ static json format_final_response_oaicompat(const json & request, const json & r if (stopped_word || stopped_eos) { finish_reason = "stop"; } + auto chat_template = json_value(request, "chat_template", std::string()); + llama_tool_calls parsed_tool_calls; + auto tools = json_value(request, "tools", json::array()); + json tool_calls; + json message_content; + if (json_value(request, "parse_tool_calls", false) + && !(parsed_tool_calls = parse_tool_calls(tools, chat_template, content)).tool_calls.empty()) { + finish_reason = "tool"; + if (!parsed_tool_calls.content.empty()) { + message_content = parsed_tool_calls.content; + } + tool_calls = json::array(); + for (const auto & tc : parsed_tool_calls.tool_calls) { + tool_calls.push_back({{"name", tc.name}, {"arguments", tc.arguments}}); + } + } else { + message_content = content; + } json choices = streaming ? 
json::array({json{{"finish_reason", finish_reason}, @@ -387,7 +494,8 @@ static json format_final_response_oaicompat(const json & request, const json & r {"delta", json::object()}}}) : json::array({json{{"finish_reason", finish_reason}, {"index", 0}, - {"message", json{{"content", content}, + {"message", json{{"content", message_content}, + {"tool_calls", tool_calls}, {"role", "assistant"}}}}}); std::time_t t = std::time(0); diff --git a/include/llama.h b/include/llama.h index 132937a0700e7..e3d7b7c6bd7d5 100644 --- a/include/llama.h +++ b/include/llama.h @@ -380,6 +380,13 @@ extern "C" { typedef struct llama_chat_message { const char * role; const char * content; + const char * tool; + struct llama_tool_call { + const char * name; + const char * arguments; + }; + const llama_tool_call * tool_calls; + uint32_t n_tool_calls; } llama_chat_message; // lora adapter @@ -976,7 +983,11 @@ extern "C" { size_t n_msg, bool add_ass, char * buf, - int32_t length); + int32_t length, + bool use_jinja = false, + const char * tools = nullptr, + const char * bos_token = nullptr, + const char * eos_token = nullptr); // // Sampling API @@ -1024,6 +1035,7 @@ extern "C" { struct llama_sampler_i { const char * (*name) (const struct llama_sampler * smpl); // can be NULL void (*accept)( struct llama_sampler * smpl, llama_token token); // can be NULL + void (*accept_str)( struct llama_sampler * smpl, const char * text); // can be NULL void (*apply) ( struct llama_sampler * smpl, llama_token_data_array * cur_p); // required void (*reset) ( struct llama_sampler * smpl); // can be NULL struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL @@ -1041,6 +1053,7 @@ extern "C" { // mirror of llama_sampler_i: LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); + LLAMA_API void llama_sampler_accept_str( struct llama_sampler * smpl, const char * piece); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl); LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 46a6ad56202f7..04a5640127b5c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,7 +22,7 @@ add_library(llama unicode-data.cpp ) -target_include_directories(llama PUBLIC . ../include) +target_include_directories(llama PUBLIC . ../include ../common) target_compile_features (llama PUBLIC cxx_std_11) # don't bump target_link_libraries(llama PUBLIC ggml) diff --git a/src/llama.cpp b/src/llama.cpp index a718de054f934..424bae69cfbf1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2,6 +2,8 @@ #include "llama-vocab.h" #include "llama-sampling.h" +#include "minja.hpp" + #include "unicode.h" #include "ggml.h" @@ -20976,7 +20978,95 @@ int32_t llama_detokenize( static int32_t llama_chat_apply_template_internal( const std::string & tmpl, const std::vector & chat, - std::string & dest, bool add_ass) { + std::string & dest, bool add_ass, + bool use_jinja, + const std::string & tools, + const std::string & bos_token, const std::string & eos_token) { + + if (use_jinja) { + auto system_not_supported = tmpl.find("System role not supported") != std::string::npos; + + // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object. 
+        // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
+        auto tool_call_args_must_be_objects = tmpl.find("tool_call.arguments | items") != std::string::npos;
+
+        auto messages = json::array();
+
+        std::string pending_system;
+        auto flush_sys = [&]() {
+            if (!pending_system.empty()) {
+                messages.push_back({
+                    {"role", "user"},
+                    {"content", pending_system},
+                });
+                pending_system.clear();
+            }
+        };
+        for (const auto * msg : chat) {
+            std::string role(msg->role);
+            std::string content(msg->content);
+            if (system_not_supported) {
+                if (role == "system") {
+                    if (!pending_system.empty()) pending_system += "\n";
+                    pending_system += content;
+                    continue;
+                } else {
+                    if (role == "user") {
+                        if (!pending_system.empty()) {
+                            content = pending_system + (content.empty() ? "" : "\n" + content);
+                            pending_system.clear();
+                        }
+                    } else {
+                        flush_sys();
+                    }
+                }
+            }
+            auto message = json({
+                {"role", role},
+                {"content", content},
+            });
+            if (msg->tool) message["tool"] = msg->tool;
+            if (msg->n_tool_calls) {
+                auto tool_calls = json::array();
+                for (uint32_t i = 0; i < msg->n_tool_calls; i++) {
+                    auto args = msg->tool_calls[i].arguments;
+                    tool_calls.push_back(json({
+                        {"type", "function"},
+                        {"function", {
+                            {"name", msg->tool_calls[i].name},
+                            {"arguments", tool_call_args_must_be_objects ? json::parse(args) : args},
+                        }}
+                    }));
+                }
+                message["tool_calls"] = tool_calls;
+            }
+            messages.push_back(message);
+        }
+        flush_sys();
+
+        auto context = minja::Context::make(json({
+            {"messages", messages},
+            {"add_generation_prompt", add_ass},
+            {"bos_token", bos_token},
+            {"eos_token", eos_token},
+        }));
+        if (!tools.empty()) {
+            auto tools_val = minja::Value(json::parse(tools));
+            context->set("tools", tools_val);
+        }
+        auto tmpl_root = minja::Parser::parse(tmpl, {
+            .lstrip_blocks = true,
+            .trim_blocks = true,
+        });
+        try {
+            dest = tmpl_root->render(context);
+            return dest.size();
+        } catch (const std::runtime_error & err) {
+            LLAMA_LOG_ERROR("Error in jinja template: %s\n", err.what());
+            return -1;
+        }
+    }
+
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
     std::stringstream ss;
     auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
@@ -21243,7 +21333,11 @@ int32_t llama_chat_apply_template(
                               size_t n_msg,
                                 bool add_ass,
                                 char * buf,
-                             int32_t length) {
+                             int32_t length,
+                                bool use_jinja,
+                        const char * tools,
+                        const char * bos_token,
+                        const char * eos_token) {
     std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
 
     if (tmpl == nullptr) {
         GGML_ASSERT(model != nullptr);
@@ -21258,6 +21352,16 @@ int32_t llama_chat_apply_template(
             curr_tmpl = std::string(model_template.data(), model_template.size());
         }
     }
+    std::string curr_bos_token(bos_token ? bos_token : "");
+    std::string curr_eos_token(eos_token ? eos_token : "");
+    if (bos_token == nullptr) {
+        GGML_ASSERT(model != nullptr);
+        curr_bos_token = llama_token_to_piece(model, llama_token_bos(model), true);
+    }
+    if (eos_token == nullptr) {
+        GGML_ASSERT(model != nullptr);
+        curr_eos_token = llama_token_to_piece(model, llama_token_eos(model), true);
+    }
 
     // format the chat to string
     std::vector<const llama_chat_message *> chat_vec;
@@ -21267,7 +21371,7 @@ int32_t llama_chat_apply_template(
     }
 
     std::string formatted_chat;
-    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
+    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass, use_jinja, tools == nullptr ? 
"" : tools, curr_bos_token, curr_eos_token); if (res < 0) { return res; } diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index a8222caeefb88..114ce592846a4 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -9,7 +9,16 @@ #include "common.h" int main(void) { - llama_chat_message conversation[] = { + struct test_template { + std::string name; + std::string tmpl; + std::string bos; + std::string eos; + std::string expected_output; + std::string jinja_expected_output; + }; + + std::vector conversation { {"system", "You are a helpful assistant"}, {"user", "Hello"}, {"assistant", "Hi there"}, @@ -17,134 +26,191 @@ int main(void) { {"assistant", " I am an assistant "}, {"user", "Another question"}, }; - size_t message_count = 6; - std::vector templates = { - // teknium/OpenHermes-2.5-Mistral-7B - "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", - // mistralai/Mistral-7B-Instruct-v0.2 - "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", - // TheBloke/FusionNet_34Bx2_MoE-AWQ - "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <>\\\\n' + messages[idx]['content'] + '\\\\n<>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}", - // bofenghuang/vigogne-2-70b-chat - "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\\\n' + system_message + '\\\\n<>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\\\n' + content.strip() + '\\\\n<>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", - // mlabonne/AlphaMonarch-7B - "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", - // google/gemma-7b-it - "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}", - // OrionStarAI/Orion-14B-Chat - "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", - // openchat/openchat-3.5-0106 - // The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d - // So we match against the included template but implement the suggested version. - "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", - // deepseek-ai/deepseek-coder-33b-instruct - "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}", - // eachadea/vicuna-13b-1.1 - // No template included in tokenizer_config.json, so this template likely needs to be manually set. - "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}", - // Orca-Vicuna - // No template included in tokenizer_config.json, so this template likely needs to be manually set. - "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}", - // CohereForAI/c4ai-command-r-plus - "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", - // Llama-3 - "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}", - //Phi-3-mini - "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", - //Phi-3-small - "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", - //Phi-3-medium - "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", - //Phi-3-vision - "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}", - // ChatGLM3 - "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", - // ChatGLM4 - u8"[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", - // MiniCPM-3B-OpenHermes-2.5-v2-GGUF - u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + ''}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}", - // DeepSeek-V2 - "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' 
}}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", - }; - std::vector expected_output = { - // teknium/OpenHermes-2.5-Mistral-7B - "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n", - // mistralai/Mistral-7B-Instruct-v0.2 - "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", - // TheBloke/FusionNet_34Bx2_MoE-AWQ - "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", - // bofenghuang/vigogne-2-70b-chat - "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", - // mlabonne/AlphaMonarch-7B - "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", - // google/gemma-7b-it - "user\nYou are a helpful assistant\n\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n", - // OrionStarAI/Orion-14B-Chat - "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", - // openchat/openchat-3.5-0106 - "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:", - // deepseek-ai/deepseek-coder-33b-instruct - "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n", - // eachadea/vicuna-13b-1.1 - "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", - // Orca-Vicuna - "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", - // CohereForAI/c4ai-command-r-plus - "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", - // Llama 3 - "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are 
you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", - //Phi-3-mini - "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", - //Phi-3-small - "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", - //Phi-3-medium - "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", - //Phi-3-vision - "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", - // ChatGLM3 - "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>", - // ChatGLM4 - "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>", - // MiniCPM-3B-OpenHermes-2.5-v2-GGUF - u8"You are a helpful assistant<用户>HelloHi there<用户>Who are youI am an assistant<用户>Another question", - // DeepSeek-V2 - u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:", + + std::vector templates { + { + .name = "teknium/OpenHermes-2.5-Mistral-7B", + .tmpl = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", + .expected_output = "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n", + .bos = "<|im_start|>", + .eos = "<|im_end|>", + }, + { + .name = "mistralai/Mistral-7B-Instruct-v0.2", + .tmpl = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + .expected_output = "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + .bos = "<|startoftext|>", + .eos = "<|endoftext|>", + }, + { + .name = "TheBloke/FusionNet_34Bx2_MoE-AWQ", + .tmpl = "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if 
idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <>\\n' + messages[idx]['content'] + '\\n<>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}", + .expected_output = "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + .bos = "", + .eos = "", + }, + { + .name = "bofenghuang/vigogne-2-70b-chat", + .tmpl = "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\\\n' + system_message + '\\\\n<>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\\\n' + content.strip() + '\\\\n<>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", + .expected_output = "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + .bos = "", + .eos = "", + }, + { + .name = "mlabonne/AlphaMonarch-7B", + .tmpl = "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", + .expected_output = "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + .jinja_expected_output = "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + .bos = "", + .eos = "", + }, + { + .name = "google/gemma-7b-it", + .tmpl = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}", + .expected_output = "user\nYou are a helpful assistant\n\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n", + .bos = "", + .eos = "", + }, + { + .name = "OrionStarAI/Orion-14B-Chat", + .tmpl = "{% for message in 
messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", + .expected_output = "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + .jinja_expected_output = "Human: Hello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + .bos = "", + .eos = "", + }, + { + // The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d, + // So we match against the included template but implement the suggested version. + .name = "openchat/openchat-3.5-0106", + .tmpl = "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", + .expected_output = "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:", + .eos = "<|end_of_turn|>", + }, + { + .name = "deepseek-ai/deepseek-coder-33b-instruct", + .tmpl = "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}", + .expected_output = "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n", + }, + { + // No template included in tokenizer_config.json, so this template likely needs to be manually set., + .name = "eachadea/vicuna-13b-1.1", + .tmpl = "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}", + .expected_output = "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", + }, + { + // No template included in tokenizer_config.json, so this template likely needs to be manually set. + .name = "Orca-Vicuna", + .tmpl = "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}", + .expected_output = "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", + }, + { + .name = "CohereForAI/c4ai-command-r-plus", + .tmpl = "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", + .expected_output = "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + }, + { + .name = "Llama-3", + .tmpl = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}", + .expected_output = "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", + }, + { + .name = "Phi-3-mini", + .tmpl = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + .expected_output = "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + }, + { + .name = "Phi-3-small", + .tmpl = "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", + .expected_output = "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + }, + { + .name = "Phi-3-medium", + .tmpl = "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 
'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + .expected_output = "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + }, + { + .name = "Phi-3-vision", + .tmpl = "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}", + .expected_output = "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + }, + { + .name = "ChatGLM3", + .tmpl = "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", + .expected_output = "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>", + }, + { + .name = "ChatGLM4", + .tmpl = u8"[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", + .expected_output = "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>", + }, + { + .name = "MiniCPM-3B-OpenHermes-2.5-v2-GGUF", + .tmpl = u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + ''}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}", + .expected_output = u8"You are a helpful assistant<用户>HelloHi there<用户>Who are youI am an assistant<用户>Another question", + }, + { + .name = "DeepSeek-V2", + .tmpl = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", + .expected_output = u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:", + } }; + std::vector<char> formatted_chat(1024); int32_t res; // test invalid chat template - res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size()); + res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE",
conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size(), false, "<|im_start|>", "<|im_end|>"); assert(res < 0); - for (size_t i = 0; i < templates.size(); i++) { - std::string custom_template = templates[i]; - std::string expected = expected_output[i]; - formatted_chat.resize(1024); - res = llama_chat_apply_template( - nullptr, - custom_template.c_str(), - conversation, - message_count, - true, - formatted_chat.data(), - formatted_chat.size() - ); - formatted_chat.resize(res); - std::string output(formatted_chat.data(), formatted_chat.size()); - printf("%s\n", output.c_str()); - printf("-------------------------\n"); - assert(output == expected); + for (auto use_jinja : std::vector<bool> { false, true }) { + printf("\n\n=== Using Jinja: %s ===\n\n", use_jinja ? "true" : "false"); + for (const auto & tmpl : templates) { + printf("=== %s ===\n", tmpl.name.c_str()); + const auto & custom_template = tmpl.tmpl; + const auto & expected = + use_jinja && !tmpl.jinja_expected_output.empty() + ? tmpl.jinja_expected_output + : tmpl.expected_output; + formatted_chat.resize(1024); + res = llama_chat_apply_template( + nullptr, + custom_template.c_str(), + conversation.data(), + conversation.size(), + true, + formatted_chat.data(), + formatted_chat.size(), + use_jinja, + tmpl.bos.c_str(), + tmpl.eos.c_str() + ); + if (res < 0) { + printf("Error: %d\n", res); + continue; + } + formatted_chat.resize(res); + std::string output(formatted_chat.data(), formatted_chat.size()); + if (output != expected) { + printf("# Failure!\n"); + printf("Template: %s\n", custom_template.c_str()); + printf("Expected:\n"); + printf("%s\n", expected.c_str()); + printf("-------------------------\n"); + printf("Actual:\n"); + printf("%s\n", output.c_str()); + // assert(output == expected); + } + } } - // test llama_chat_format_single for system message printf("\n\n=== llama_chat_format_single (system message) ===\n\n"); std::vector<llama_chat_msg> chat2; llama_chat_msg sys_msg{"system", "You are a helpful assistant"}; auto fmt_sys = [&](std::string tmpl) { - auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false); + auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false, false, "<|im_start|>", "<|im_end|>"); printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str()); printf("-------------------------\n"); return output; @@ -163,7 +229,7 @@ int main(void) { llama_chat_msg new_msg{"user", "How are you"}; auto fmt_single = [&](std::string tmpl) { - auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true); + auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true, false, "<|im_start|>", "<|im_end|>"); printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str()); printf("-------------------------\n"); return output;
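The updated test loop above also serves as a usage reference for calling the extended llama_chat_apply_template() directly. The sketch below is illustrative only: it assumes the ten-argument prototype exercised by the test (whether a separate tools parameter also appears in the final C API is determined by the updated include/llama.h), and the ChatML-style template and the <|im_start|>/<|im_end|> token strings are placeholder values copied from the test rather than anything the API mandates.

    // Minimal sketch, mirroring the call pattern used in tests/test-chat-template.cpp above.
    #include "llama.h"
    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        // ChatML-style Jinja template; it also contains "<|im_start|>", so the legacy
        // built-in detection matches it when use_jinja is false.
        const char * tmpl =
            "{% for message in messages %}"
            "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}"
            "{% endfor %}"
            "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";

        std::vector<llama_chat_message> msgs = {
            {"system", "You are a helpful assistant"},
            {"user",   "Hello"},
        };

        std::vector<char> buf(1024);
        int32_t res = llama_chat_apply_template(
            /* model     */ nullptr,          // nullptr: render the template string passed below
            /* tmpl      */ tmpl,
            /* chat      */ msgs.data(),
            /* n_msg     */ msgs.size(),
            /* add_ass   */ true,             // append the assistant generation prompt
            /* buf       */ buf.data(),
            /* length    */ buf.size(),
            /* use_jinja */ true,             // route through the minja-based renderer
            /* bos_token */ "<|im_start|>",   // placeholder; a real caller would pass the model's tokens
            /* eos_token */ "<|im_end|>");
        if (res < 0) {
            fprintf(stderr, "template rendering failed: %d\n", res);
            return 1;
        }
        // As in common.cpp, res is the full output length; if it exceeded buf.size(),
        // the buffer would need to be grown and the call repeated. 1024 is ample here.
        printf("%s\n", std::string(buf.data(), res).c_str());
        return 0;
    }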