diff --git a/.editorconfig b/.editorconfig
index f88f8da67cd78..19eb504346045 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -30,3 +30,11 @@ indent_style = tab
 [examples/cvector-generator/*.txt]
 trim_trailing_whitespace = unset
 insert_final_newline = unset
+
+[{tests/chat/templates/*.jinja,tests/chat/goldens/*.txt}]
+indent_style = unset
+indent_size = unset
+end_of_line = unset
+charset = unset
+trim_trailing_whitespace = unset
+insert_final_newline = unset
diff --git a/common/common.cpp b/common/common.cpp
index 7c5b810ecd117..e6254ef3b1aae 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1516,7 +1516,7 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
         nullptr,
         tmpl.c_str(),
         chat,
-        1, 
+        1,
         /* add_ass= */ true,
         /* buffer= */ nullptr,
         /* length= */ 0,
diff --git a/common/common.h b/common/common.h
index 1b5683c007837..64a20f6a0786a 100644
--- a/common/common.h
+++ b/common/common.h
@@ -624,7 +624,7 @@ class llama_antiprompts {
                     f = f->fail;
                 }
-                child.fail = (f == &root && f->children.find(c) == f->children.end()) 
+                child.fail = (f == &root && f->children.find(c) == f->children.end())
                     ? &root : &f->children[c];
 
                 if (child.fail->output != -1) {
@@ -654,7 +654,7 @@ class llama_antiprompts {
             },
             stop_words,
             grammar_trigger_words
-        ); 
+        );
     }
 
     void build(const std::function<std::vector<llama_token>(const std::string)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
@@ -708,7 +708,7 @@ class llama_antiprompts {
     MatchResult findFirstMatch(const std::string& text, size_t offset = 0) {
         TrieNode* current = &root;
         MatchResult partialMatch{std::string::npos, "", true, 0, false};
-        
+
         for (size_t i = offset; i < text.length(); ++i) {
             char c = text[i];
             while (current != &root && current->children.find(c) == current->children.end()) {
@@ -736,12 +736,12 @@ class llama_antiprompts {
                 partialMatch.is_grammar_trigger = false;
             }
         }
-        
+
         // If we've found a partial match and haven't returned a full match, return the partial match
         if (partialMatch.pos != std::string::npos) {
             return partialMatch;
         }
-        
+
         return {std::string::npos, "", false, 0, false};
     }
 };
diff --git a/common/minja.hpp b/common/minja.hpp
index 4a9d32ad1516a..3e0b95d0aaae5 100644
--- a/common/minja.hpp
+++ b/common/minja.hpp
@@ -48,7 +48,7 @@ class Value : public std::enable_shared_from_this<Value> {
             }
             return Value();
         }
-        
+
         bool empty() {
             return args.empty() && kwargs.empty();
         }
@@ -61,7 +61,7 @@ class Value : public std::enable_shared_from_this<Value> {
             }
         }
     };
-    
+
     using CallableType = std::function<Value(const std::shared_ptr<Context> &, Arguments &)>;
     using FilterType = std::function<Value(const std::shared_ptr<Context> &, Arguments &)>;
@@ -143,7 +143,7 @@ class Value : public std::enable_shared_from_this<Value> {
         } else if (is_boolean()) {
             out << (this->to_bool() ? "True" : "False");
"True" : "False"); } else if (is_string()) { - dump_string(primitive_, out, string_quote); + dump_string(primitive_, out, string_quote); } else { out << primitive_.dump(); } @@ -175,7 +175,7 @@ class Value : public std::enable_shared_from_this { primitive_ = v; } } - + std::vector keys() { if (!object_) throw std::runtime_error("Value is not an object: " + dump()); std::vector res; @@ -267,7 +267,7 @@ class Value : public std::enable_shared_from_this { if (is_string()) return !get().empty(); if (is_array()) return !empty(); return true; - } + } bool operator<(const Value & other) const { if (is_null()) @@ -369,7 +369,7 @@ class Value : public std::enable_shared_from_this { if (!contains(key)) return default_value; return at(key).get(); } - + template T get() const { if (is_primitive()) return primitive_.get(); @@ -730,7 +730,7 @@ class TemplateNode { Location location_; protected: virtual void do_render(std::ostringstream & out, const std::shared_ptr & context) const = 0; - + public: TemplateNode(const Location & location) : location_(location) {} void render(std::ostringstream & out, const std::shared_ptr & context) const { @@ -817,7 +817,7 @@ class ForNode : public TemplateNode { ForNode(const Location & location, std::vector && var_names, std::unique_ptr && iterable, std::unique_ptr && condition, std::unique_ptr && body, bool recursive, std::unique_ptr && else_body) : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} - + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { // https://jinja.palletsprojects.com/en/3.0.x/templates/#for @@ -920,7 +920,7 @@ class MacroNode : public TemplateNode { auto & arg_name = arg.first; auto it = named_param_positions.find(arg_name); if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name); - + call_context->set(arg_name, arg.second); param_set[it->second] = true; } @@ -1098,7 +1098,7 @@ class BinaryOpExpr : public Expression { : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {} Value do_evaluate(const std::shared_ptr & context) const override { auto l = left->evaluate(context); - + auto do_eval = [&](const Value & l) -> Value { if (op == Op::Is || op == Op::IsNot) { auto t = dynamic_cast(right.get()); @@ -1297,7 +1297,7 @@ class Parser { std::shared_ptr template_str; CharIterator start, end, it; Options options; - + Parser(const std::shared_ptr& template_str, const Options & options) : template_str(template_str), options(options) { if (!template_str) throw std::runtime_error("Template string is null"); start = it = this->template_str->begin(); @@ -1326,7 +1326,7 @@ class Parser { case 'b': result += '\b'; break; case 'f': result += '\f'; break; case '\\': result += '\\'; break; - default: + default: if (*it == quote) { result += quote; } else { @@ -1562,7 +1562,7 @@ class Parser { if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword"); return nonstd_make_unique( - left->location, + left->location, std::move(left), std::move(identifier), negated ? 
     }
@@ -1588,7 +1588,7 @@ class Parser {
         if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in param list");
 
         Expression::Parameters result;
-        
+
         while (it != end) {
             if (!consumeToken(")").empty()) {
                 return result;
@@ -1622,7 +1622,7 @@ class Parser {
         if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in call args");
 
         Expression::Arguments result;
-        
+
         while (it != end) {
             if (!consumeToken(")").empty()) {
                 return result;
@@ -1655,7 +1655,7 @@ class Parser {
         static std::regex ident_regex(R"((?!not|is|and|or|del)[a-zA-Z_]\w*)");
         auto location = get_location();
         auto ident = consumeToken(ident_regex);
-        if (ident.empty()) 
+        if (ident.empty())
             return nullptr;
         return nonstd_make_unique<VariableExpr>(location, ident);
     }
@@ -1699,7 +1699,7 @@ class Parser {
         }
         return left;
     }
-    
+
     std::unique_ptr<Expression> parseMathMulDiv() {
         auto left = parseMathUnaryPlusMinus();
         if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression");
@@ -1709,9 +1709,9 @@ class Parser {
         while (!(op_str = consumeToken(mul_div_tok)).empty()) {
             auto right = parseMathUnaryPlusMinus();
             if (!right) throw std::runtime_error("Expected right side of 'math mul/div' expression");
-            auto op = op_str == "*" ? BinaryOpExpr::Op::Mul 
-                : op_str == "**" ? BinaryOpExpr::Op::MulMul 
-                : op_str == "/" ? BinaryOpExpr::Op::Div 
+            auto op = op_str == "*" ? BinaryOpExpr::Op::Mul
+                : op_str == "**" ? BinaryOpExpr::Op::MulMul
+                : op_str == "/" ? BinaryOpExpr::Op::Div
                 : op_str == "//" ? BinaryOpExpr::Op::DivDiv
                 : BinaryOpExpr::Op::Mod;
             left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
@@ -1741,14 +1741,14 @@ class Parser {
         auto op_str = consumeToken(unary_plus_minus_tok);
         auto expr = parseValueExpression();
         if (!expr) throw std::runtime_error("Expected expr of 'unary plus/minus' expression");
-        
+
         if (!op_str.empty()) {
             auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus;
             return nonstd_make_unique<UnaryOpExpr>(get_location(), std::move(expr), op);
         }
         return expr;
     }
-    
+
     std::unique_ptr<Expression> parseValueExpression() {
         auto parseValue = [&]() -> std::unique_ptr<Expression> {
             auto location = get_location();
@@ -1774,7 +1774,7 @@ class Parser {
         };
 
         auto value = parseValue();
-        
+
         while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
             if (!consumeToken("[").empty()) {
                 std::unique_ptr<Expression> index;
@@ -1797,7 +1797,7 @@ class Parser {
                 }
                 if (!index) throw std::runtime_error("Empty index in subscript");
                 if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
-                
+
                 value = nonstd_make_unique<SubscriptExpr>(value->location, std::move(value), std::move(index));
             } else if (!consumeToken(".").empty()) {
                 auto identifier = parseIdentifier();
@@ -1825,10 +1825,10 @@ class Parser {
 
     std::unique_ptr<Expression> parseBracedExpressionOrArray() {
         if (consumeToken("(").empty()) return nullptr;
-        
+
         auto expr = parseExpression();
         if (!expr) throw std::runtime_error("Expected expression in braced expression");
-        
+
         if (!consumeToken(")").empty()) {
             return expr; // Drop the parentheses
         }
@@ -1851,7 +1851,7 @@ class Parser {
 
     std::unique_ptr<Expression> parseArray() {
         if (consumeToken("[").empty()) return nullptr;
-        
+
         std::vector<std::unique_ptr<Expression>> elements;
         if (!consumeToken("]").empty()) {
             return nonstd_make_unique<ArrayExpr>(get_location(), std::move(elements));
@@ -1876,7 +1876,7 @@ class Parser {
 
     std::unique_ptr<Expression> parseDictionary() {
         if (consumeToken("{").empty()) return nullptr;
-        
+
         std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> elements;
         if (!consumeToken("}").empty()) {
             return nonstd_make_unique<DictExpr>(get_location(), std::move(elements));
@@ -1892,7 +1892,7 @@ class Parser {
         };
 
         parseKeyValuePair();
-        
+
         while (it != end) {
             if (!consumeToken(",").empty()) {
                 parseKeyValuePair();
@@ -1950,15 +1950,15 @@ class Parser {
         static std::regex text_regex(R"([\s\S\n]*?($|(?=\{\{|\{%|\{#)))");
         static std::regex expr_close_regex(R"([\s\n]*([-~])?\}\})");
         static std::regex block_close_regex(R"([\s\n]*([-~])?%\})");
-        
+
         TemplateTokenVector tokens;
         std::vector<std::string> group;
         std::string text;
-        
+
         try {
             while (it != end) {
                 auto location = get_location();
-                
+
                 if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) {
                     auto pre_space = parsePreSpace(group[1]);
                     auto content = group[2];
@@ -1985,7 +1985,7 @@ class Parser {
                     };
 
                     if ((keyword = consumeToken(block_keyword_tok)).empty()) throw std::runtime_error("Expected block keyword");
-                    
+
                     if (keyword == "if") {
                         auto condition = parseExpression();
                         if (!condition) throw std::runtime_error("Expected condition in if block");
@@ -2019,7 +2019,7 @@ class Parser {
                             condition = parseExpression();
                         }
                         auto recursive = !consumeToken(recursive_tok).empty();
-                        
+
                         auto post_space = parseBlockClose();
                         tokens.push_back(nonstd_make_unique<ForTemplateToken>(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive));
                     } else if (keyword == "endfor") {
@@ -2034,7 +2034,7 @@ class Parser {
                         if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) {
                             ns = group[1];
                             var_names.push_back(group[2]);
-                            
+
                             if (consumeToken("=").empty()) throw std::runtime_error("Expected equals sign in set block");
 
                             value = parseExpression();
@@ -2115,7 +2115,7 @@ class Parser {
             } else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) {
                 SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep;
                 SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;
-                
+
                 auto text = text_token->text;
                 if (pre_space == SpaceHandling::Strip) {
                     static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
@@ -2131,7 +2131,7 @@ class Parser {
                     static std::regex trailing_last_line_space_regex(R"((^|\n)[ \t]*$)");
                     text = std::regex_replace(text, trailing_last_line_space_regex, "$1");
                 }
-                
+
                 if (it == end && !options.keep_trailing_newline) {
                     static std::regex r(R"([\n\r]$)");
                     text = std::regex_replace(text, r, ""); // Strip one trailing newline
@@ -2473,7 +2473,7 @@ inline std::shared_ptr<Context> Context::builtins() {
             int64_t start = param_set[0] ? startEndStep[0] : 0;
             int64_t end = startEndStep[1];
             int64_t step = param_set[2] ? startEndStep[2] : 1;
-            
+
             auto res = Value::array();
             if (step > 0) {
                 for (int64_t i = start; i < end; i += step) {
diff --git a/common/sampling.cpp b/common/sampling.cpp
index ac1f8b174f23b..bbe2f81e6e2c5 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -147,7 +147,7 @@ bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler *
     llama_sampler_accept_str(gsmpl->grmr, trigger.c_str());
     return true;
 }
-    
+
 struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
diff --git a/common/tool-call.cpp b/common/tool-call.cpp
index d7e3ba85a37bf..cb9ee2ecf4124 100644
--- a/common/tool-call.cpp
+++ b/common/tool-call.cpp
@@ -84,7 +84,7 @@ static llama_tool_calls parse_hermes_tool_calls(const std::string& input) {
     std::regex start_pattern(R"([\n\s]*<tool_call>)");
     std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
     std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
-    
+
     auto end = input.end();
     std::sregex_iterator rend;
     std::sregex_iterator rit(input.begin(), end, start_pattern);
@@ -176,7 +176,7 @@ static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const std::str
         it = rit->suffix().first;
 
         auto name = rit->str(1);
-        
+
         json arguments;
         if (!parse_json(it, end, arguments)) {
             throw std::runtime_error("Failed to parse json tool call arguments");
@@ -229,7 +229,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
     const nlohmann::ordered_json & tools)
 {
     llama_tool_call_handler handler;
-    
+
     if (needs_functionary_v3_tool_call(chat_template)) {
         // MeetKaiFunctionary_3_2
         // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
@@ -312,7 +312,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
                 handler.grammar_trigger_words.push_back("<|python_tag|>");
             }
         } else {
-            //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " + 
+            //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
             tool_rules.push_back(
                 builder.add_rule(
                     name + "-call",
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index cbd8b00355c4d..aea498f967011 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -182,7 +182,7 @@ struct server_slot {
     std::string stopping_word;
 
     llama_antiprompts antiprompts;
-    
+
     // sampling
     json json_schema;
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index b0db9953b0597..480b85c23c0c6 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -654,7 +654,7 @@ async def step_tool_called(context, expected_name, expected_arguments):
     expected_name = expected_name if expected_name else None
     expected_arguments = json.loads(expected_arguments) if expected_arguments else None
-    
+
     def check(tool_calls):
         if tool_calls is None:
             assert expected_name is None and expected_arguments is None, f'expected_name = {expected_name}, expected_arguments = {expected_arguments}'
@@ -1055,7 +1055,7 @@ async def oai_chat_completions(user_prompt,
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     assert isinstance(seed, int), f'seed: {seed}'
     seed = seed if seed is not None else 42
-    
+
     enable_streaming = enable_streaming if enable_streaming is not None else False
     messages = []
     if system_prompt:
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index b124f07710aef..fff4a78bc5541 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -353,7 +353,7 @@ static json oaicompat_completion_params_parse(
     auto tools = json_value(body, "tools", json());
     auto has_tools = tools.is_array() && !tools.empty();
-    
+
     // Apply chat template to the list of messages
     auto chat_template = chat_template_src.empty() ? llama_model_meta_val_str(model, "tokenizer.chat_template") : chat_template_src;
     llama_params["chat_template"] = chat_template;
@@ -420,7 +420,7 @@ static json oaicompat_completion_params_parse(
         llama_params["parse_tool_calls"] = true;
         llama_params["parallel_tool_calls"] = parallel_tool_calls;
     }
-    
+
     // Handle "n" field
     int n_choices = json_value(body, "n", 1);
     if (n_choices != 1) {
diff --git a/tests/chat/contexts/simple.json b/tests/chat/contexts/simple.json
index fa4877616dcef..560f92f7300ca 100644
--- a/tests/chat/contexts/simple.json
+++ b/tests/chat/contexts/simple.json
@@ -12,4 +12,4 @@
   "add_generation_prompt": true,
   "bos_token": "<|startoftext|>",
   "eos_token": "<|endoftext|>"
-}
\ No newline at end of file
+}
diff --git a/tests/chat/contexts/system.json b/tests/chat/contexts/system.json
index 9c016f36910c6..4d72972add3ee 100644
--- a/tests/chat/contexts/system.json
+++ b/tests/chat/contexts/system.json
@@ -16,4 +16,4 @@
   "add_generation_prompt": true,
   "bos_token": "<|startoftext|>",
   "eos_token": "<|endoftext|>"
-}
\ No newline at end of file
+}
diff --git a/tests/chat/contexts/tool_use.json b/tests/chat/contexts/tool_use.json
index 6345ef24b7876..0d037d2f6494d 100644
--- a/tests/chat/contexts/tool_use.json
+++ b/tests/chat/contexts/tool_use.json
@@ -161,4 +161,4 @@
       }
     }
   ]
-}
\ No newline at end of file
+}
diff --git a/tests/test-antiprompts.cpp b/tests/test-antiprompts.cpp
index 226c7d24f4f30..fc09f98eb9d21 100644
--- a/tests/test-antiprompts.cpp
+++ b/tests/test-antiprompts.cpp
@@ -26,12 +26,12 @@ int main()
     };
     const std::vector<std::string> stop_words { };
     const std::vector<std::string> grammar_trigger_words { };
-    
+
     printf("Testing antiprompts\n");
 
     llama_antiprompts antiprompts;
     antiprompts.build(tokenizer, {"abc", "bcd"}, {"bca", "x"});
-    
+
     assert_equal(antiprompts.findSingleTokenMatch('x'), {
         .pos = 0,
         .pattern = "x",
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 68fe6c381713a..faa95ceaa29be 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -17,7 +17,7 @@ int main(void) {
         std::string expected_output;
         std::string jinja_expected_output;
     };
-    
+
     std::vector<llama_chat_message> conversation {
         {"system", "You are a helpful assistant"},
         {"user", "Hello"},
@@ -100,7 +100,7 @@ int main(void) {
            .tmpl = "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
            .expected_output = "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:",
         },
-        { 
+        {
             // No template included in tokenizer_config.json, so this template likely needs to be manually set.
.name = "Orca-Vicuna", .tmpl = "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}", @@ -157,7 +157,7 @@ int main(void) { .expected_output = u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:", } }; - + std::vector formatted_chat(1024); int32_t res; diff --git a/tests/test-minja.cpp b/tests/test-minja.cpp index ad835e0362e8e..25a8e9e3c69dc 100644 --- a/tests/test-minja.cpp +++ b/tests/test-minja.cpp @@ -1,6 +1,6 @@ /* Minimalistic Jinja templating engine for llama.cpp. C++11, no deps (single-header), decent language support but very few functions (easy to extend), just what’s needed for actual prompt templates. - + Models have increasingly complex templates (e.g. Llama 3.1, Hermes 2 Pro w/ tool_use), so we need a proper template engine to get the best out of them. Supports: @@ -20,7 +20,7 @@ - No tuples (templates seem to rely on lists only) - No `if` expressions w/o `else` (but `if` statements are fine) - No `{% raw %}`, `{% block … %}`, `{% include … %}`, `{% extends … %}, - + Model templates verified to work: - Meta-Llama-3.1-8B-Instruct - Phi-3.5-mini-instruct @@ -160,7 +160,7 @@ static void test_template_features() { test_render(R"({{ {"a": "b"} | tojson }})", {}, {}, R"({"a": "b"})"); test_render(R"({{ {"a": "b"} }})", {}, {}, R"({'a': 'b'})"); - std::string trim_tmpl = + std::string trim_tmpl = "\n" " {% if true %}Hello{% endif %} \n" "...\n" @@ -228,7 +228,7 @@ static void test_template_features() { ({{ i }}, {{ loop.cycle('odd', 'even') }}), {%- endfor -%} )", {}, {}, "(0, odd),(1, even),(2, odd),(3, even),(4, odd),"); - + test_render( "{%- for i in range(5) if i % 2 == 0 -%}\n" "{{ i }}, first={{ loop.first }}, last={{ loop.last }}, index={{ loop.index }}, index0={{ loop.index0 }}, revindex={{ loop.revindex }}, revindex0={{ loop.revindex0 }}, prev={{ loop.previtem }}, next={{ loop.nextitem }},\n" @@ -237,7 +237,7 @@ static void test_template_features() { "0, first=True, last=False, index=1, index0=0, revindex=3, revindex0=2, prev=, next=2,\n" "2, first=False, last=False, index=2, index0=1, revindex=2, revindex0=1, prev=0, next=4,\n" "4, first=False, last=True, index=3, index0=2, revindex=1, revindex0=0, prev=2, next=,\n"); - + test_render( R"( {%- set res = [] -%} @@ -262,7 +262,7 @@ static void test_template_features() { {% macro input(name, value='', type='text', size=20) -%} {%- endmacro -%} - +

{{ input('username') }}

{{ input('password', type='password') }}

)", {}, {}, R"( @@ -314,14 +314,14 @@ static void test_template_features() { {{- x }},{{ y -}}; {%- endfor -%} )", {{"z", json({json({1, 10}), json({2, 20})})}}, {}, "1,10;2,20;"); - + test_render(" a {{ 'b' -}} c ", {}, {}, " a bc "); test_render(" a {{- 'b' }} c ", {}, {}, " ab c "); test_render("a\n{{- 'b' }}\nc", {}, {}, "ab\nc"); test_render("a\n{{ 'b' -}}\nc", {}, {}, "a\nbc"); test_error_contains("{{ raise_exception('hey') }}", {}, {}, "hey"); - + test_render("{{ [] is iterable }}", {}, {}, "True"); test_render("{{ [] is not number }}", {}, {}, "True"); test_render("{% set x = [0, 1, 2, 3] %}{{ x[1:] }}{{ x[:2] }}{{ x[1:3] }}", {}, {}, "[1, 2, 3][0, 1][1, 2]"); @@ -343,16 +343,16 @@ static void test_template_features() { test_error_contains("{% if 1 %}{% else %}{% elif 1 %}{% endif %}", {}, {}, "Unterminated if"); test_render("{% if 1 %}{% elif 1 %}{% else %}{% endif %}", {}, {}, ""); - + test_render( - "{% set x = [] %}{% set _ = x.append(1) %}{{ x | tojson(indent=2) }}", {}, {}, + "{% set x = [] %}{% set _ = x.append(1) %}{{ x | tojson(indent=2) }}", {}, {}, "[\n 1\n]"); test_render( - "{{ not [] }}", {}, {}, + "{{ not [] }}", {}, {}, "True"); - - test_render("{{ tool.function.name == 'ipython' }}", + + test_render("{{ tool.function.name == 'ipython' }}", json({{"tool", json({ {"function", {{"name", "ipython"}}} })}}), @@ -369,7 +369,7 @@ static void test_template_features() { static void test_chat_templates_with_common_contexts_against_goldens() { auto jinja_template_files = find_files("tests/chat/templates", ".jinja"); auto context_files = find_files("tests/chat/contexts", ".json"); - + auto get_golden_file = [&](const std::string & tmpl_file, const std::string & ctx_file) { auto tmpl_name = filename_without_extension(tmpl_file); auto ctx_name = filename_without_extension(ctx_file); @@ -431,4 +431,4 @@ int main() { } return 0; -} \ No newline at end of file +} diff --git a/tests/test-tool-call.cpp b/tests/test-tool-call.cpp index fd0eeed01f693..24ef8a589d093 100644 --- a/tests/test-tool-call.cpp +++ b/tests/test-tool-call.cpp @@ -58,7 +58,7 @@ int main() { json request = { {"tools", tools} }; - + std::string hermes_2_pro_like_tmpl = "Hermes 2 Pro template should have inside it"; test_parse_tool_call(tools, hermes_2_pro_like_tmpl, "{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}", @@ -71,7 +71,7 @@ int main() { }).dump()} }} }}); - + std::string functionary_v3_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it"; test_parse_tool_call(tools, functionary_v3_like_tmpl, ">>>ipython\nprint('Hello, world!')", @@ -84,7 +84,7 @@ int main() { }).dump()} }} }}); - + std::string functionary_v3_llama_3_1_like_tmpl = "Functionary 3.2 template for llama 3.1 should have <|start_header_id|> and then some {...} inside it"; test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl, "Hell{\"arg1\": 1}o, world{\"arg2\": 2}!", @@ -107,7 +107,7 @@ int main() { }} }, }); - + std::string llama_3_1_like_tmpl = "Llama 3.1 template should have <|start_header_id|> and <|python_tag|> inside it"; test_parse_tool_call(tools, llama_3_1_like_tmpl, "<|python_tag|>this could be anything", @@ -145,4 +145,4 @@ int main() { "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array()); return 0; -} \ No newline at end of file +} diff --git a/tests/update_jinja_goldens.py b/tests/update_jinja_goldens.py index 9c5d1db87b069..fafa6dee0715a 100644 --- a/tests/update_jinja_goldens.py +++ b/tests/update_jinja_goldens.py @@ -8,10 +8,10 @@ # /// ''' 
     Fetches the Jinja2 templates of a few known models and uses them to generate prompt goldens for a few predefined chat contexts.
-    
+
     Examples:
       python ./tests/update_jinja_goldens.py
-    
+
     https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
 '''
@@ -33,12 +33,12 @@
     "Qwen/Qwen2-7B-Instruct",
     "Qwen/Qwen2-VL-7B-Instruct",
     "Qwen/Qwen2.5-7B-Instruct",
-    "Qwen/Qwen2.5-Math-7B-Instruct", 
+    "Qwen/Qwen2.5-Math-7B-Instruct",
     "microsoft/Phi-3-mini-4k-instruct",
     "microsoft/Phi-3-small-8k-instruct",
     "microsoft/Phi-3-medium-4k-instruct",
     "microsoft/Phi-3.5-mini-instruct",
-    "indischepartij/MiniCPM-3B-OpenHermes-2.5-v2", 
+    "indischepartij/MiniCPM-3B-OpenHermes-2.5-v2",
     "teknium/OpenHermes-2.5-Mistral-7B",
     "TheBloke/FusionNet_34Bx2_MoE-AWQ",
     "bofenghuang/vigogne-2-70b-chat",
@@ -46,18 +46,18 @@
     "OrionStarAI/Orion-14B-Chat",
     "openchat/openchat-3.5-0106",
     "deepseek-ai/deepseek-coder-33b-instruct",
-    "abacusai/Fewshot-Metamath-OrcaVicuna-Mistral", 
+    "abacusai/Fewshot-Metamath-OrcaVicuna-Mistral",
     "CohereForAI/c4ai-command-r-plus",
-    "THUDM/chatglm3-6b", 
-    "derek33125/project-angel-chatglm4", 
-    "deepseek-ai/DeepSeek-Coder-V2-Instruct", 
+    "THUDM/chatglm3-6b",
+    "derek33125/project-angel-chatglm4",
+    "deepseek-ai/DeepSeek-Coder-V2-Instruct",
     "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
     "deepseek-ai/DeepSeek-V2.5",
-    
+
     # Needs debugging:
     # "eachadea/vicuna-13b-1.1",
     # "microsoft/Phi-3-vision-instruct",
-    
+
     # Gated models:
     "meta-llama/Meta-Llama-3.1-8B-Instruct",
     "google/gemma-7b-it",
@@ -83,9 +83,9 @@ def handle_chat_template(model_id, variant, template_src):
     print(f'template_file: {template_file}')
     with open(template_file, 'w') as f:
         f.write(template_src)
-    
+
     print(f"- {template_file}", flush=True)
-    
+
     env = jinja2.Environment(
         trim_blocks=True,
         lstrip_blocks=True,
@@ -99,25 +99,25 @@ def handle_chat_template(model_id, variant, template_src):
 
     template_handles_tools = 'tools' in template_src
     template_hates_the_system = 'System role not supported' in template_src
-    
+
     template = env.from_string(template_src)
-    
+
     context_files = glob.glob('tests/chat/contexts/*.json')
     for context_file in context_files:
         context_name = context_file.split("/")[-1].replace(".json", "")
         with open(context_file, 'r') as f:
             context = json.load(f)
-        
+
         if not template_handles_tools and 'tools' in context:
             continue
-        
+
         if template_hates_the_system and any(m['role'] == 'system' for m in context['messages']):
             continue
-        
+
         output_file = f'tests/chat/goldens/{base_name}-{context_name}.txt'
         print(f"- {output_file}", flush=True)
         try:
-            output = template.render(**context) 
+            output = template.render(**context)
         except:
             # Some templates (e.g. Phi-3-medium-128k's) expect a non-null "content" key in each message.
             for message in context["messages"]:
@@ -132,27 +132,27 @@ def handle_chat_template(model_id, variant, template_src):
 
     with open(output_file, 'w') as f:
         f.write(output)
-    
+
     print()
 
 
 def main():
     for dir in ['tests/chat/templates', 'tests/chat/goldens']:
         if not os.path.isdir(dir): os.mkdir(dir)
-    
+
     for model_id in model_ids:
         # response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json")
         # response.raise_for_status()
        # config_str = response.text
         with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
             config_str = f.read()
-        
-        try: 
+
+        try:
             config = json.loads(config_str)
         except json.JSONDecodeError as e:
             # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
             # (Remove extra '}' near the end of the file)
-            config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str)) 
+            config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
 
         chat_template = config['chat_template']
         if isinstance(chat_template, str):
@@ -162,4 +162,4 @@ def main():
             handle_chat_template(model_id, ct['name'], ct['template'])
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()