diff --git a/Makefile b/Makefile index 88234972f81f2..e5e7e62fa8c2a 100644 --- a/Makefile +++ b/Makefile @@ -54,6 +54,7 @@ TEST_TARGETS = \ tests/test-grammar-integration \ tests/test-grammar-parser \ tests/test-json-schema-to-grammar \ + tests/test-minja \ tests/test-llama-grammar \ tests/test-log \ tests/test-model-load-cancel \ @@ -1573,6 +1574,11 @@ tests/test-antiprompts: tests/test-antiprompts.cpp \ $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) +tests/test-minja: tests/test-minja.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + tests/test-grad0: tests/test-grad0.cpp \ $(OBJ_GGML) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 042e895add5e2..34c3620c27cde 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -62,6 +62,7 @@ add_library(${TARGET} STATIC json.hpp log.cpp log.h + minja.hpp ngram-cache.cpp ngram-cache.h sampling.cpp diff --git a/common/minja.hpp b/common/minja.hpp new file mode 100644 index 0000000000000..4a9d32ad1516a --- /dev/null +++ b/common/minja.hpp @@ -0,0 +1,2497 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +/* Backport make_unique from C++14. */ +template +typename std::unique_ptr nonstd_make_unique(Args &&...args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +namespace minja { + +class Context; + +struct Options { + bool trim_blocks; // removes the first newline after a block + bool lstrip_blocks; // removes leading whitespace on the line of the block + bool keep_trailing_newline; // don't remove last newline +}; + +/* Values that behave roughly like in Python. */ +class Value : public std::enable_shared_from_this { +public: + struct Arguments { + std::vector args; + std::vector> kwargs; + + bool has_named(const std::string & name) { + for (const auto & p : kwargs) { + if (p.first == name) return true; + } + return false; + } + + Value get_named(const std::string & name) { + for (const auto & p : kwargs) { + if (p.first == name) return p.second; + } + return Value(); + } + + bool empty() { + return args.empty() && kwargs.empty(); + } + + void expectArgs(const std::string & method_name, const std::pair & pos_count, const std::pair & kw_count) { + if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) { + std::ostringstream out; + out << method_name << " must have between " << pos_count.first << " and " << pos_count.second << " positional arguments and between " << kw_count.first << " and " << kw_count.second << " keyword arguments"; + throw std::runtime_error(out.str()); + } + } + }; + + using CallableType = std::function &, Arguments &)>; + using FilterType = std::function &, Arguments &)>; + +private: + using ObjectType = nlohmann::ordered_map; // Only contains primitive keys + using ArrayType = std::vector; + + std::shared_ptr array_; + std::shared_ptr object_; + std::shared_ptr callable_; + json primitive_; + + Value(const std::shared_ptr & array) : array_(array) {} + Value(const std::shared_ptr & object) : object_(object) {} + Value(const std::shared_ptr & callable) : object_(std::make_shared()), callable_(callable) {} + + /* Python-style string repr */ + static void dump_string(const json & primitive, std::ostringstream & out, char string_quote = '\'') { + if (!primitive.is_string()) throw std::runtime_error("Value is not a string: " + primitive.dump()); + auto s = primitive.dump(); + if (string_quote == '"' || s.find('\'') != std::string::npos) { + out << s; + return; + } + // Reuse json dump, just changing string quotes + out << string_quote; + for (size_t i = 1, n = s.size() - 1; i < n; ++i) { + if (s[i] == '\\' && s[i + 1] == '"') { + out << '"'; + i++; + } else if (s[i] == string_quote) { + out << '\\' << string_quote; + } else { + out << s[i]; + } + } + out << string_quote; + } + void dump(std::ostringstream & out, int indent = -1, int level = 0, char string_quote = '\'') const { + auto print_indent = [&](int level) { + if (indent > 0) { + out << "\n"; + for (int i = 0, n = level * indent; i < n; ++i) out << ' '; + } + }; + auto print_sub_sep = [&]() { + out << ','; + if (indent < 0) out << ' '; + else print_indent(level + 1); + }; + + if (is_null()) out << "null"; + else if (array_) { + out << "["; + print_indent(level + 1); + for (size_t i = 0; i < array_->size(); ++i) { + if (i) print_sub_sep(); + (*array_)[i].dump(out, indent, level + 1, string_quote); + } + print_indent(level); + out << "]"; + } else if (object_) { + out << "{"; + print_indent(level + 1); + for (auto begin = object_->begin(), it = begin; it != object_->end(); ++it) { + if (it != begin) print_sub_sep(); + if (it->first.is_string()) { + dump_string(it->first, out, string_quote); + } else { + out << string_quote << it->first.dump() << string_quote; + } + out << ": "; + it->second.dump(out, indent, level + 1, string_quote); + } + print_indent(level); + out << "}"; + } else if (callable_) { + throw std::runtime_error("Cannot dump callable to JSON"); + } else if (is_boolean()) { + out << (this->to_bool() ? "True" : "False"); + } else if (is_string()) { + dump_string(primitive_, out, string_quote); + } else { + out << primitive_.dump(); + } + } + +public: + Value() {} + Value(const bool& v) : primitive_(v) {} + Value(const int64_t & v) : primitive_(v) {} + Value(const double& v) : primitive_(v) {} + Value(const nullptr_t &) {} + Value(const std::string & v) : primitive_(v) {} + Value(const char * v) : primitive_(std::string(v)) {} + + Value(const json & v) { + if (v.is_object()) { + auto object = std::make_shared(); + for (auto it = v.begin(); it != v.end(); ++it) { + (*object)[it.key()] = it.value(); + } + object_ = std::move(object); + } else if (v.is_array()) { + auto array = std::make_shared(); + for (const auto& item : v) { + array->push_back(Value(item)); + } + array_ = array; + } else { + primitive_ = v; + } + } + + std::vector keys() { + if (!object_) throw std::runtime_error("Value is not an object: " + dump()); + std::vector res; + for (const auto& item : *object_) { + res.push_back(item.first); + } + return res; + } + + size_t size() const { + if (is_object()) return object_->size(); + if (is_array()) return array_->size(); + if (is_string()) return primitive_.get().length(); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + + static Value array(const std::vector values = {}) { + auto array = std::make_shared(); + for (const auto& item : values) { + array->push_back(item); + } + return Value(array); + } + static Value object(const std::shared_ptr object = std::make_shared()) { + return Value(object); + } + static Value callable(const CallableType & callable) { + return Value(std::make_shared(callable)); + } + + void insert(size_t index, const Value& v) { + if (!array_) + throw std::runtime_error("Value is not an array: " + dump()); + array_->insert(array_->begin() + index, v); + } + void push_back(const Value& v) { + if (!array_) + throw std::runtime_error("Value is not an array: " + dump()); + array_->push_back(v); + } + Value get(const Value& key) { + if (array_) { + auto index = key.get(); + return array_->at(index < 0 ? array_->size() + index : index); + } else if (object_) { + if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + auto it = object_->find(key.primitive_); + if (it == object_->end()) return Value(); + return it->second; + } + throw std::runtime_error("Value is not an array or object: " + dump()); + } + void set(const Value& key, const Value& value) { + if (!object_) throw std::runtime_error("Value is not an object: " + dump()); + if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + (*object_)[key.primitive_] = value; + } + Value call(const std::shared_ptr & context, Value::Arguments & args) const { + if (!callable_) throw std::runtime_error("Value is not callable: " + dump()); + return (*callable_)(context, args); + } + + bool is_object() const { return !!object_; } + bool is_array() const { return !!array_; } + bool is_callable() const { return !!callable_; } + bool is_null() const { return !object_ && !array_ && primitive_.is_null() && !callable_; } + bool is_boolean() const { return primitive_.is_boolean(); } + bool is_number_integer() const { return primitive_.is_number_integer(); } + bool is_number_float() const { return primitive_.is_number_float(); } + bool is_number() const { return primitive_.is_number(); } + bool is_string() const { return primitive_.is_string(); } + + bool is_primitive() const { return !array_ && !object_ && !callable_; } + bool is_hashable() const { return is_primitive(); } + + bool empty() const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_string()) return primitive_.empty(); + if (is_array()) return array_->empty(); + if (is_object()) return object_->empty(); + return false; + } + + bool to_bool() const { + if (is_null()) return false; + if (is_boolean()) return get(); + if (is_number()) return get() != 0; + if (is_string()) return !get().empty(); + if (is_array()) return !empty(); + return true; + } + + bool operator<(const Value & other) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_number() && other.is_number()) return get() < other.get(); + if (is_string() && other.is_string()) return get() < other.get(); + throw std::runtime_error("Cannot compare values: " + dump() + " < " + other.dump()); + } + bool operator>=(const Value & other) const { return !(*this < other); } + + bool operator>(const Value & other) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_number() && other.is_number()) return get() > other.get(); + if (is_string() && other.is_string()) return get() > other.get(); + throw std::runtime_error("Cannot compare values: " + dump() + " > " + other.dump()); + } + bool operator<=(const Value & other) const { return !(*this > other); } + + bool operator==(const Value & other) const { + if (callable_ || other.callable_) { + if (callable_.get() != other.callable_.get()) return false; + } + if (array_) { + if (!other.array_) return false; + if (array_->size() != other.array_->size()) return false; + for (size_t i = 0; i < array_->size(); ++i) { + if (!(*array_)[i].to_bool() || !(*other.array_)[i].to_bool() || (*array_)[i] != (*other.array_)[i]) return false; + } + return true; + } else if (object_) { + if (!other.object_) return false; + if (object_->size() != other.object_->size()) return false; + for (const auto& item : *object_) { + if (!item.second.to_bool() || !other.object_->count(item.first) || item.second != other.object_->at(item.first)) return false; + } + return true; + } else { + return primitive_ == other.primitive_; + } + } + bool operator!=(const Value & other) const { return !(*this == other); } + + bool contains(const char * key) const { return contains(std::string(key)); } + bool contains(const std::string & key) const { + if (array_) { + return false; + } else if (object_) { + return object_->find(key) != object_->end(); + } else { + throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); + } + } + bool contains(const Value & value) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (array_) { + for (const auto& item : *array_) { + if (item.to_bool() && item == value) return true; + } + return false; + } else if (object_) { + if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump()); + return object_->find(value.primitive_) != object_->end(); + } else { + throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); + } + } + void erase(size_t index) { + if (array_) throw std::runtime_error("Value is not an array: " + dump()); + array_->erase(array_->begin() + index); + } + void erase(const std::string & key) { + if (object_) throw std::runtime_error("Value is not an object: " + dump()); + object_->erase(key); + } + const Value& at(const Value & index) const { + return const_cast(this)->at(index); + } + Value& at(const Value & index) { + if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + if (is_array()) return array_->at(index.get()); + if (is_object()) return object_->at(index.primitive_); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + const Value& at(size_t index) const { + return const_cast(this)->at(index); + } + Value& at(size_t index) { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_array()) return array_->at(index); + if (is_object()) return object_->at(index); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + + template + T get(const std::string & key, T default_value) const { + if (!contains(key)) return default_value; + return at(key).get(); + } + + template + T get() const { + if (is_primitive()) return primitive_.get(); + throw std::runtime_error("get not defined for this value type: " + dump()); + } + + template <> + json get() const { + if (is_primitive()) return primitive_; + if (is_null()) return json(); + if (array_) { + std::vector res; + for (const auto& item : *array_) { + res.push_back(item.get()); + } + return res; + } + if (object_) { + json res = json::object(); + for (const auto& item : *object_) { + const auto & key = item.first; + auto json_value = item.second.get(); + if (key.is_string()) { + res[key.get()] = json_value; + } else if (key.is_primitive()) { + res[key.dump()] = json_value; + } else { + throw std::runtime_error("Invalid key type for conversion to JSON: " + key.dump()); + } + } + if (is_callable()) { + res["__callable__"] = true; + } + return res; + } + throw std::runtime_error("get not defined for this value type: " + dump()); + } + + std::string dump(int indent=-1, bool to_json=false) const { + std::ostringstream out; + dump(out, indent, 0, to_json ? '"' : '\''); + return out.str(); + } + + Value operator-() const { + if (is_number_integer()) + return -get(); + else + return -get(); + } + std::string to_str() const { + if (is_string()) return get(); + if (is_number_integer()) return std::to_string(get()); + if (is_number_float()) return std::to_string(get()); + if (is_boolean()) return get() ? "True" : "False"; + if (is_null()) return "None"; + return dump(); + } + Value operator+(const Value& rhs) const { + if (is_string() || rhs.is_string()) + return to_str() + rhs.to_str(); + else if (is_number_integer() && rhs.is_number_integer()) + return get() + rhs.get(); + else + return get() + rhs.get(); + } + Value operator-(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() - rhs.get(); + else + return get() - rhs.get(); + } + Value operator*(const Value& rhs) const { + if (is_string() && rhs.is_number_integer()) { + std::ostringstream out; + for (int i = 0, n = rhs.get(); i < n; ++i) { + out << to_str(); + } + return out.str(); + } + else if (is_number_integer() && rhs.is_number_integer()) + return get() * rhs.get(); + else + return get() * rhs.get(); + } + Value operator/(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() / rhs.get(); + else + return get() / rhs.get(); + } + Value operator%(const Value& rhs) const { + return get() % rhs.get(); + } +}; + +} // namespace minja + +namespace std { + template <> + struct hash { + size_t operator()(const minja::Value & v) const { + if (!v.is_hashable()) + throw std::runtime_error("Unsupported type for hashing: " + v.dump()); + return std::hash()(v.get()); + } + }; +} // namespace std + +namespace minja { + +static std::string error_location_suffix(const std::string & source, size_t pos) { + auto get_line = [&](size_t line) { + auto start = source.begin(); + for (size_t i = 1; i < line; ++i) { + start = std::find(start, source.end(), '\n') + 1; + } + auto end = std::find(start, source.end(), '\n'); + return std::string(start, end); + }; + auto start = source.begin(); + auto end = source.end(); + auto it = start + pos; + auto line = std::count(start, it, '\n') + 1; + auto max_line = std::count(start, end, '\n') + 1; + auto col = pos - std::string(start, it).rfind('\n'); + std::ostringstream out; + out << " at row " << line << ", column " << col << ":\n"; + if (line > 1) out << get_line(line - 1) << "\n"; + out << get_line(line) << "\n"; + out << std::string(col - 1, ' ') << "^" << "\n"; + if (line < max_line) out << get_line(line + 1) << "\n"; + + return out.str(); +} + +class Context : public std::enable_shared_from_this { + protected: + Value values_; + std::shared_ptr parent_; +public: + Context(Value && values, const std::shared_ptr & parent = nullptr) : values_(std::move(values)), parent_(parent) { + if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump()); + } + virtual ~Context() {} + + static std::shared_ptr builtins(); + static std::shared_ptr make(Value && values, const std::shared_ptr & parent = builtins()); + + std::vector keys() { + return values_.keys(); + } + virtual Value get(const Value & key) { + if (values_.contains(key)) return values_.at(key); + if (parent_) return parent_->get(key); + return Value(); + } + virtual Value & at(const Value & key) { + if (values_.contains(key)) return values_.at(key); + if (parent_) return parent_->at(key); + throw std::runtime_error("Undefined variable: " + key.dump()); + } + virtual bool contains(const Value & key) { + if (values_.contains(key)) return true; + if (parent_) return parent_->contains(key); + return false; + } + virtual void set(const Value & key, Value & value) { + values_.set(key, value); + } +}; + +struct Location { + std::shared_ptr source; + size_t pos; +}; + +class Expression { +protected: + virtual Value do_evaluate(const std::shared_ptr & context) const = 0; +public: + struct Arguments { + std::vector> args; + std::vector>> kwargs; + + void expectArgs(const std::string & method_name, const std::pair & pos_count, const std::pair & kw_count) const { + if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) { + std::ostringstream out; + out << method_name << " must have between " << pos_count.first << " and " << pos_count.second << " positional arguments and between " << kw_count.first << " and " << kw_count.second << " keyword arguments"; + throw std::runtime_error(out.str()); + } + } + + Value::Arguments evaluate(const std::shared_ptr & context) const { + Value::Arguments vargs; + for (const auto& arg : this->args) { + vargs.args.push_back(arg->evaluate(context)); + } + for (const auto& arg : this->kwargs) { + vargs.kwargs.push_back({arg.first, arg.second->evaluate(context)}); + } + return vargs; + } + }; + + using Parameters = std::vector>>; + + Location location; + + Expression(const Location & location) : location(location) {} + virtual ~Expression() = default; + + Value evaluate(const std::shared_ptr & context) const { + try { + return do_evaluate(context); + } catch (const std::runtime_error & e) { + std::ostringstream out; + out << e.what(); + if (location.source) out << error_location_suffix(*location.source, location.pos); + throw std::runtime_error(out.str()); + } + } +}; + +class VariableExpr : public Expression { + std::string name; +public: + VariableExpr(const Location & location, const std::string& n) + : Expression(location), name(n) {} + std::string get_name() const { return name; } + Value do_evaluate(const std::shared_ptr & context) const override { + if (!context->contains(name)) { + return Value(); + } + return context->at(name); + } +}; + +static void destructuring_assign(const std::vector & var_names, const std::shared_ptr & context, Value& item) { + if (var_names.size() == 1) { + Value name(var_names[0]); + context->set(name, item); + } else { + if (!item.is_array() || item.size() != var_names.size()) { + throw std::runtime_error("Mismatched number of variables and items in destructuring assignment"); + } + for (size_t i = 0; i < var_names.size(); ++i) { + context->set(var_names[i], item.at(i)); + } + } +} + +enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; + +class TemplateToken { +public: + enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Set, EndSet, Comment, Macro, EndMacro }; + + static std::string typeToString(Type t) { + switch (t) { + case Type::Text: return "text"; + case Type::Expression: return "expression"; + case Type::If: return "if"; + case Type::Else: return "else"; + case Type::Elif: return "elif"; + case Type::EndIf: return "endif"; + case Type::For: return "for"; + case Type::EndFor: return "endfor"; + case Type::Set: return "set"; + case Type::EndSet: return "endset"; + case Type::Comment: return "comment"; + case Type::Macro: return "macro"; + case Type::EndMacro: return "endmacro"; + } + return "Unknown"; + } + + TemplateToken(Type type, const Location & location, SpaceHandling pre, SpaceHandling post) : type(type), location(location), pre_space(pre), post_space(post) {} + virtual ~TemplateToken() = default; + + Type type; + Location location; + SpaceHandling pre_space = SpaceHandling::Keep; + SpaceHandling post_space = SpaceHandling::Keep; +}; + +struct TextTemplateToken : public TemplateToken { + std::string text; + TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {} +}; + +struct ExpressionTemplateToken : public TemplateToken { + std::unique_ptr expr; + ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {} +}; + +struct IfTemplateToken : public TemplateToken { + std::unique_ptr condition; + IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {} +}; + +struct ElifTemplateToken : public TemplateToken { + std::unique_ptr condition; + ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {} +}; + +struct ElseTemplateToken : public TemplateToken { + ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {} +}; + +struct EndIfTemplateToken : public TemplateToken { + EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {} +}; + +struct MacroTemplateToken : public TemplateToken { + std::unique_ptr name; + Expression::Parameters params; + MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr && n, Expression::Parameters && p) + : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {} +}; + +struct EndMacroTemplateToken : public TemplateToken { + EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {} +}; + +struct ForTemplateToken : public TemplateToken { + std::vector var_names; + std::unique_ptr iterable; + std::unique_ptr condition; + bool recursive; + ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector & vns, std::unique_ptr && iter, + std::unique_ptr && c, bool r) + : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} +}; + +struct EndForTemplateToken : public TemplateToken { + EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {} +}; + +struct SetTemplateToken : public TemplateToken { + std::string ns; + std::vector var_names; + std::unique_ptr value; + SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector & vns, std::unique_ptr && v) + : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} +}; + +struct EndSetTemplateToken : public TemplateToken { + EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {} +}; + +struct CommentTemplateToken : public TemplateToken { + std::string text; + CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {} +}; + +class TemplateNode { + Location location_; +protected: + virtual void do_render(std::ostringstream & out, const std::shared_ptr & context) const = 0; + +public: + TemplateNode(const Location & location) : location_(location) {} + void render(std::ostringstream & out, const std::shared_ptr & context) const { + try { + do_render(out, context); + } catch (const std::runtime_error & e) { + std::ostringstream err; + err << e.what(); + if (location_.source) err << error_location_suffix(*location_.source, location_.pos); + throw std::runtime_error(err.str()); + } + } + const Location & location() const { return location_; } + virtual ~TemplateNode() = default; + std::string render(const std::shared_ptr & context) const { + std::ostringstream out; + render(out, context); + return out.str(); + } +}; + +class SequenceNode : public TemplateNode { + std::vector> children; +public: + SequenceNode(const Location & location, std::vector> && c) + : TemplateNode(location), children(std::move(c)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + for (const auto& child : children) child->render(out, context); + } +}; + +class TextNode : public TemplateNode { + std::string text; +public: + TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {} + void do_render(std::ostringstream & out, const std::shared_ptr &) const override { + out << text; + } +}; + +class ExpressionNode : public TemplateNode { + std::unique_ptr expr; +public: + ExpressionNode(const Location & location, std::unique_ptr && e) : TemplateNode(location), expr(std::move(e)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + auto result = expr->evaluate(context); + if (result.is_string()) { + out << result.get(); + } else if (result.is_boolean()) { + out << (result.get() ? "True" : "False"); + } else if (!result.is_null()) { + out << result.dump(); + } + } +}; + +class IfNode : public TemplateNode { + std::vector, std::unique_ptr>> cascade; +public: + IfNode(const Location & location, std::vector, std::unique_ptr>> && c) + : TemplateNode(location), cascade(std::move(c)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + for (const auto& branch : cascade) { + auto enter_branch = true; + if (branch.first) { + enter_branch = branch.first->evaluate(context).to_bool(); + } + if (enter_branch) { + branch.second->render(out, context); + return; + } + } + } +}; + +class ForNode : public TemplateNode { + std::vector var_names; + std::unique_ptr iterable; + std::unique_ptr condition; + std::unique_ptr body; + bool recursive; + std::unique_ptr else_body; +public: + ForNode(const Location & location, std::vector && var_names, std::unique_ptr && iterable, + std::unique_ptr && condition, std::unique_ptr && body, bool recursive, std::unique_ptr && else_body) + : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} + + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + // https://jinja.palletsprojects.com/en/3.0.x/templates/#for + + auto iterable_value = iterable->evaluate(context); + Value::CallableType loop_function; + + std::function visit = [&](Value& iter) { + auto filtered_items = Value::array(); + if (!iter.is_null()) { + if (!iterable_value.is_array()) { + throw std::runtime_error("For loop iterable must be iterable: " + iterable_value.dump()); + } + for (size_t i = 0, n = iter.size(); i < n; ++i) { + auto item = iter.at(i); + destructuring_assign(var_names, context, item); + if (!condition || condition->evaluate(context).to_bool()) { + filtered_items.push_back(item); + } + } + } + if (filtered_items.empty()) { + if (else_body) { + else_body->render(out, context); + } + } else { + auto loop = recursive ? Value::callable(loop_function) : Value::object(); + loop.set("length", (int64_t) filtered_items.size()); + + size_t cycle_index = 0; + loop.set("cycle", Value::callable([&](const std::shared_ptr &, Value::Arguments & args) { + if (args.args.empty() || !args.kwargs.empty()) { + throw std::runtime_error("cycle() expects at least 1 positional argument and no named arg"); + } + auto item = args.args[cycle_index]; + cycle_index = (cycle_index + 1) % args.args.size(); + return item; + })); + auto loop_context = Context::make(Value::object(), context); + loop_context->set("loop", loop); + for (size_t i = 0, n = filtered_items.size(); i < n; ++i) { + auto & item = filtered_items.at(i); + destructuring_assign(var_names, loop_context, item); + loop.set("index", (int64_t) i + 1); + loop.set("index0", (int64_t) i); + loop.set("revindex", (int64_t) (n - i)); + loop.set("revindex0", (int64_t) (n - i - 1)); + loop.set("length", (int64_t) n); + loop.set("first", i == 0); + loop.set("last", i == (n - 1)); + loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value()); + loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value()); + body->render(out, loop_context); + } + } + }; + + if (recursive) { + loop_function = [&](const std::shared_ptr &, Value::Arguments & args) { + if (args.args.size() != 1 || !args.kwargs.empty() || !args.args[0].is_array()) { + throw std::runtime_error("loop() expects exactly 1 positional iterable argument"); + } + auto & items = args.args[0]; + visit(items); + return Value(); + }; + } + + visit(iterable_value); + } +}; + +class MacroNode : public TemplateNode { + std::unique_ptr name; + Expression::Parameters params; + std::unique_ptr body; + std::unordered_map named_param_positions; +public: + MacroNode(const Location & location, std::unique_ptr && n, Expression::Parameters && p, std::unique_ptr && b) + : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) { + for (size_t i = 0; i < params.size(); ++i) { + const auto & name = params[i].first; + if (!name.empty()) { + named_param_positions[name] = i; + } + } + } + void do_render(std::ostringstream &, const std::shared_ptr & macro_context) const override { + auto callable = Value::callable([&](const std::shared_ptr & context, Value::Arguments & args) { + auto call_context = macro_context; + std::vector param_set(params.size(), false); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto & arg = args.args[i]; + if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name()); + param_set[i] = true; + auto & param_name = params[i].first; + call_context->set(param_name, arg); + } + for (size_t i = 0, n = args.kwargs.size(); i < n; i++) { + auto & arg = args.kwargs[i]; + auto & arg_name = arg.first; + auto it = named_param_positions.find(arg_name); + if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name); + + call_context->set(arg_name, arg.second); + param_set[it->second] = true; + } + // Set default values for parameters that were not passed + for (size_t i = 0, n = params.size(); i < n; i++) { + if (!param_set[i] && params[i].second != nullptr) { + auto val = params[i].second->evaluate(context); + call_context->set(params[i].first, val); + } + } + return body->render(call_context); + }); + macro_context->set(name->get_name(), callable); + } +}; + +class SetNode : public TemplateNode { + std::string ns; + std::vector var_names; + std::unique_ptr value; + std::unique_ptr template_value; +public: + SetNode(const Location & location, const std::string & ns, const std::vector & vns, std::unique_ptr && v, std::unique_ptr && tv) + : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)), template_value(std::move(tv)) { + if (value && template_value) { + throw std::runtime_error("Cannot have both value and template value in set node"); + } + if (template_value && var_names.size() != 1) { + throw std::runtime_error("Destructuring assignment is only supported with a single variable name"); + } + } + void do_render(std::ostringstream &, const std::shared_ptr & context) const override { + if (!ns.empty()) { + if (var_names.size() != 1) { + throw std::runtime_error("Namespaced set only supports a single variable name"); + } + auto & name = var_names[0]; + auto ns_value = context->get(ns); + if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object"); + ns_value.set(name, this->value->evaluate(context)); + } else if (template_value) { + Value value { template_value->render(context) }; + context->set(var_names[0], value); + } else { + auto val = value->evaluate(context); + destructuring_assign(var_names, context, val); + } + } +}; + +class IfExpr : public Expression { + std::unique_ptr condition; + std::unique_ptr then_expr; + std::unique_ptr else_expr; +public: + IfExpr(const Location & location, std::unique_ptr && c, std::unique_ptr && t, std::unique_ptr && e) + : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (condition->evaluate(context).to_bool()) { + return then_expr->evaluate(context); + } + if (else_expr) { + return else_expr->evaluate(context); + } + return nullptr; + } +}; + +class LiteralExpr : public Expression { + Value value; +public: + LiteralExpr(const Location & location, const Value& v) + : Expression(location), value(v) {} + Value do_evaluate(const std::shared_ptr &) const override { return value; } +}; + +class ArrayExpr : public Expression { + std::vector> elements; +public: + ArrayExpr(const Location & location, std::vector> && e) + : Expression(location), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto result = Value::array(); + for (const auto& e : elements) { + result.push_back(e->evaluate(context)); + } + return result; + } +}; + +class DictExpr : public Expression { + std::vector, std::unique_ptr>> elements; +public: + DictExpr(const Location & location, std::vector, std::unique_ptr>> && e) + : Expression(location), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto result = Value::object(); + for (const auto& e : elements) { + result.set(e.first->evaluate(context), e.second->evaluate(context)); + } + return result; + } +}; + +class SliceExpr : public Expression { +public: + std::unique_ptr start, end; + SliceExpr(const Location & location, std::unique_ptr && s, std::unique_ptr && e) + : Expression(location), start(std::move(s)), end(std::move(e)) {} + Value do_evaluate(const std::shared_ptr &) const override { + throw std::runtime_error("SliceExpr not implemented"); + } +}; + +class SubscriptExpr : public Expression { + std::unique_ptr base; + std::unique_ptr index; +public: + SubscriptExpr(const Location & location, std::unique_ptr && b, std::unique_ptr && i) + : Expression(location), base(std::move(b)), index(std::move(i)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto target_value = base->evaluate(context); + if (auto slice = dynamic_cast(index.get())) { + if (!target_value.is_array()) throw std::runtime_error("Subscripting non-array"); + + auto start = slice->start ? slice->start->evaluate(context).get() : 0; + auto end = slice->end ? slice->end->evaluate(context).get() : target_value.size(); + auto result = Value::array(); + for (auto i = start; i < end; ++i) { + result.push_back(target_value.at(i)); + } + return result; + } else { + auto index_value = index->evaluate(context); + if (target_value.is_null()) { + if (auto t = dynamic_cast(base.get())) { + throw std::runtime_error("'" + t->get_name() + "' is " + (context->contains(t->get_name()) ? "null" : "not defined")); + } + throw std::runtime_error("Trying to access property '" + index_value.dump() + "' on null!"); + } + return target_value.get(index_value); + } + } +}; + +class UnaryOpExpr : public Expression { +public: + enum class Op { Plus, Minus, LogicalNot }; +private: + std::unique_ptr expr; + Op op; +public: + UnaryOpExpr(const Location & location, std::unique_ptr && e, Op o) + : Expression(location), expr(std::move(e)), op(o) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto e = expr->evaluate(context); + switch (op) { + case Op::Plus: return e; + case Op::Minus: return -e; + case Op::LogicalNot: return !e.to_bool(); + } + throw std::runtime_error("Unknown unary operator"); + } +}; + +class BinaryOpExpr : public Expression { +public: + enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot }; +private: + std::unique_ptr left; + std::unique_ptr right; + Op op; +public: + BinaryOpExpr(const Location & location, std::unique_ptr && l, std::unique_ptr && r, Op o) + : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto l = left->evaluate(context); + + auto do_eval = [&](const Value & l) -> Value { + if (op == Op::Is || op == Op::IsNot) { + auto t = dynamic_cast(right.get()); + if (!t) throw std::runtime_error("Right side of 'is' operator must be a variable"); + + auto eval = [&]() { + const auto & name = t->get_name(); + if (name == "none") return l.is_null(); + if (name == "boolean") return l.is_boolean(); + if (name == "integer") return l.is_number_integer(); + if (name == "float") return l.is_number_float(); + if (name == "number") return l.is_number(); + if (name == "string") return l.is_string(); + if (name == "mapping") return l.is_object(); + if (name == "iterable") return l.is_array(); + if (name == "sequence") return l.is_array(); + if (name == "defined") return !l.is_null(); + throw std::runtime_error("Unknown type for 'is' operator: " + name); + }; + auto value = eval(); + return Value(op == Op::Is ? value : !value); + } + + if (op == Op::And) { + if (!l.to_bool()) return Value(false); + return right->evaluate(context).to_bool(); + } else if (op == Op::Or) { + if (l.to_bool()) return Value(true); + return right->evaluate(context).to_bool(); + } + + auto r = right->evaluate(context); + switch (op) { + case Op::StrConcat: return l.to_str() + r.to_str(); + case Op::Add: return l + r; + case Op::Sub: return l - r; + case Op::Mul: return l * r; + case Op::Div: return l / r; + case Op::MulMul: return std::pow(l.get(), r.get()); + case Op::DivDiv: return l.get() / r.get(); + case Op::Mod: return l.get() % r.get(); + case Op::Eq: return l == r; + case Op::Ne: return l != r; + case Op::Lt: return l < r; + case Op::Gt: return l > r; + case Op::Le: return l <= r; + case Op::Ge: return l >= r; + case Op::In: return (r.is_array() || r.is_object()) && r.contains(l); + case Op::NotIn: return !(r.is_array() && r.contains(l)); + default: break; + } + throw std::runtime_error("Unknown binary operator"); + }; + + if (l.is_callable()) { + return Value::callable([l, do_eval](const std::shared_ptr & context, Value::Arguments & args) { + auto ll = l.call(context, args); + return do_eval(ll); //args[0].second); + }); + } else { + return do_eval(l); + } + } +}; + +static std::string strip(const std::string & s) { + static std::regex trailing_spaces_regex("^\\s+|\\s+$"); + return std::regex_replace(s, trailing_spaces_regex, ""); +} + +static std::string html_escape(const std::string & s) { + std::string result; + result.reserve(s.size()); + for (const auto & c : s) { + switch (c) { + case '&': result += "&"; break; + case '<': result += "<"; break; + case '>': result += ">"; break; + case '"': result += """; break; + case '\'': result += "'"; break; + default: result += c; break; + } + } + return result; +} + +class MethodCallExpr : public Expression { + std::unique_ptr object; + std::unique_ptr method; + Expression::Arguments args; +public: + MethodCallExpr(const Location & location, std::unique_ptr && obj, std::unique_ptr && m, Expression::Arguments && a) + : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto obj = object->evaluate(context); + if (obj.is_array()) { + if (method->get_name() == "append") { + args.expectArgs("append method", {1, 1}, {0, 0}); + obj.push_back(args.args[0]->evaluate(context)); + return Value(); + } else if (method->get_name() == "insert") { + args.expectArgs("insert method", {2, 2}, {0, 0}); + auto index = args.args[0]->evaluate(context).get(); + if (index < 0 || index > (int64_t) obj.size()) throw std::runtime_error("Index out of range for insert method"); + obj.insert(index, args.args[1]->evaluate(context)); + return Value(); + } + } else if (obj.is_object()) { + if (method->get_name() == "items") { + args.expectArgs("items method", {0, 0}, {0, 0}); + auto result = Value::array(); + for (const auto& key : obj.keys()) { + result.push_back(Value::array({key, obj.at(key)})); + } + return result; + } else if (method->get_name() == "get") { + args.expectArgs("get method", {1, 2}, {0, 0}); + auto key = args.args[0]->evaluate(context); + if (args.args.size() == 1) { + return obj.contains(key) ? obj.at(key) : Value(); + } else { + return obj.contains(key) ? obj.at(key) : args.args[1]->evaluate(context); + } + } else if (obj.contains(method->get_name())) { + auto callable = obj.at(method->get_name()); + if (!callable.is_callable()) { + throw std::runtime_error("Property '" + method->get_name() + "' is not callable"); + } + Value::Arguments vargs = args.evaluate(context); + return callable.call(context, vargs); + } + } else if (obj.is_string()) { + if (method->get_name() == "strip") { + args.expectArgs("strip method", {0, 0}, {0, 0}); + return Value(strip(obj.get())); + } + } + throw std::runtime_error("Unknown method: " + method->get_name()); + } +}; + +class CallExpr : public Expression { +public: + std::unique_ptr object; + Expression::Arguments args; + CallExpr(const Location & location, std::unique_ptr && obj, Expression::Arguments && a) + : Expression(location), object(std::move(obj)), args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto obj = object->evaluate(context); + if (!obj.is_callable()) { + throw std::runtime_error("Object is not callable: " + obj.dump(2)); + } + auto vargs = args.evaluate(context); + return obj.call(context, vargs); + } +}; + +class FilterExpr : public Expression { + std::vector> parts; +public: + FilterExpr(const Location & location, std::vector> && p) + : Expression(location), parts(std::move(p)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + Value result; + bool first = true; + for (const auto& part : parts) { + if (first) { + first = false; + result = part->evaluate(context); + } else { + if (auto ce = dynamic_cast(part.get())) { + auto target = ce->object->evaluate(context); + Value::Arguments args = ce->args.evaluate(context); + args.args.insert(args.args.begin(), result); + result = target.call(context, args); + } else { + auto callable = part->evaluate(context); + Value::Arguments args; + args.args.insert(args.args.begin(), result); + result = callable.call(context, args); + } + } + } + return result; + } + + void prepend(std::unique_ptr && e) { + parts.insert(parts.begin(), std::move(e)); + } +}; + +class Parser { +private: + using CharIterator = std::string::const_iterator; + + std::shared_ptr template_str; + CharIterator start, end, it; + Options options; + + Parser(const std::shared_ptr& template_str, const Options & options) : template_str(template_str), options(options) { + if (!template_str) throw std::runtime_error("Template string is null"); + start = it = this->template_str->begin(); + end = this->template_str->end(); + } + + bool consumeSpaces(SpaceHandling space_handling = SpaceHandling::Strip) { + if (space_handling == SpaceHandling::Strip) { + while (it != end && std::isspace(*it)) ++it; + } + return true; + } + + std::unique_ptr parseString() { + auto doParse = [&](char quote) -> std::unique_ptr { + if (it == end || *it != quote) return nullptr; + std::string result; + bool escape = false; + for (++it; it != end; ++it) { + if (escape) { + escape = false; + switch (*it) { + case 'n': result += '\n'; break; + case 'r': result += '\r'; break; + case 't': result += '\t'; break; + case 'b': result += '\b'; break; + case 'f': result += '\f'; break; + case '\\': result += '\\'; break; + default: + if (*it == quote) { + result += quote; + } else { + result += *it; + } + break; + } + } else if (*it == '\\') { + escape = true; + } else if (*it == quote) { + ++it; + return nonstd_make_unique(result); + } else { + result += *it; + } + } + return nullptr; + }; + + consumeSpaces(); + if (it == end) return nullptr; + if (*it == '"') return doParse('"'); + if (*it == '\'') return doParse('\''); + return nullptr; + } + + json parseNumber(CharIterator& it, const CharIterator& end) { + auto before = it; + consumeSpaces(); + auto start = it; + bool hasDecimal = false; + bool hasExponent = false; + + if (it != end && (*it == '-' || *it == '+')) ++it; + + while (it != end) { + if (std::isdigit(*it)) { + ++it; + } else if (*it == '.') { + if (hasDecimal) throw std::runtime_error("Multiple decimal points"); + hasDecimal = true; + ++it; + } else if (it != start && (*it == 'e' || *it == 'E')) { + if (hasExponent) throw std::runtime_error("Multiple exponents"); + hasExponent = true; + ++it; + } else { + break; + } + } + if (start == it) { + it = before; + return json(); // No valid characters found + } + + std::string str(start, it); + try { + return json::parse(str); + } catch (json::parse_error& e) { + throw std::runtime_error("Failed to parse number: '" + str + "' (" + std::string(e.what()) + ")"); + return json(); + } + } + + /** integer, float, bool, string */ + std::unique_ptr parseConstant() { + auto start = it; + consumeSpaces(); + if (it == end) return nullptr; + if (*it == '"' || *it == '\'') { + auto str = parseString(); + if (str) return nonstd_make_unique(*str); + } + static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)"); + auto token = consumeToken(prim_tok); + if (!token.empty()) { + if (token == "true" || token == "True") return nonstd_make_unique(true); + if (token == "false" || token == "False") return nonstd_make_unique(false); + if (token == "None") return nonstd_make_unique(nullptr); + throw std::runtime_error("Unknown constant token: " + token); + } + + auto number = parseNumber(it, end); + if (!number.is_null()) return nonstd_make_unique(number); + + it = start; + return nullptr; + } + + class expression_parsing_error : public std::runtime_error { + const CharIterator it; + public: + expression_parsing_error(const std::string & message, const CharIterator & it) + : std::runtime_error(message), it(it) {} + size_t get_pos(const CharIterator & begin) const { + return std::distance(begin, it); + } + }; + + bool peekSymbols(const std::vector & symbols) const { + for (const auto & symbol : symbols) { + if (std::distance(it, end) >= (int64_t) symbol.size() && std::string(it, it + symbol.size()) == symbol) { + return true; + } + } + return false; + } + + std::vector consumeTokenGroups(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + std::vector ret; + for (size_t i = 0, n = match.size(); i < n; ++i) { + ret.push_back(match[i].str()); + } + return ret; + } + it = start; + return {}; + } + std::string consumeToken(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + return match[0].str(); + } + it = start; + return ""; + } + + std::string consumeToken(const std::string & token, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + if (std::distance(it, end) >= (int64_t) token.size() && std::string(it, it + token.size()) == token) { + it += token.size(); + return token; + } + it = start; + return ""; + } + + std::unique_ptr parseExpression(bool allow_if_expr = true) { + auto left = parseLogicalOr(); + if (it == end) return left; + + if (!allow_if_expr) return left; + + static std::regex if_tok(R"(if\b)"); + if (consumeToken(if_tok).empty()) { + return left; + } + + auto location = get_location(); + auto if_expr = parseIfExpression(); + return nonstd_make_unique(location, std::move(if_expr.first), std::move(left), std::move(if_expr.second)); + } + + Location get_location() const { + return {template_str, (size_t) std::distance(start, it)}; + } + + std::pair, std::unique_ptr> parseIfExpression() { + auto condition = parseLogicalOr(); + if (!condition) throw std::runtime_error("Expected condition expression"); + + static std::regex else_tok(R"(else\b)"); + std::unique_ptr else_expr; + if (!consumeToken(else_tok).empty()) { + else_expr = parseExpression(); + if (!else_expr) throw std::runtime_error("Expected 'else' expression"); + } + return std::make_pair(std::move(condition), std::move(else_expr)); + } + + std::unique_ptr parseLogicalOr() { + auto left = parseLogicalAnd(); + if (!left) throw std::runtime_error("Expected left side of 'logical or' expression"); + + static std::regex or_tok(R"(or\b)"); + auto location = get_location(); + while (!consumeToken(or_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) throw std::runtime_error("Expected right side of 'or' expression"); + left = nonstd_make_unique(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or); + } + return left; + } + + std::unique_ptr parseLogicalNot() { + static std::regex not_tok(R"(not\b)"); + auto location = get_location(); + + if (!consumeToken(not_tok).empty()) { + auto sub = parseLogicalNot(); + if (!sub) throw std::runtime_error("Expected expression after 'not' keyword"); + return nonstd_make_unique(location, std::move(sub), UnaryOpExpr::Op::LogicalNot); + } + return parseLogicalCompare(); + } + + std::unique_ptr parseLogicalAnd() { + auto left = parseLogicalNot(); + if (!left) throw std::runtime_error("Expected left side of 'logical and' expression"); + + static std::regex and_tok(R"(and\b)"); + auto location = get_location(); + while (!consumeToken(and_tok).empty()) { + auto right = parseLogicalNot(); + if (!right) throw std::runtime_error("Expected right side of 'and' expression"); + left = nonstd_make_unique(location, std::move(left), std::move(right), BinaryOpExpr::Op::And); + } + return left; + } + + std::unique_ptr parseLogicalCompare() { + auto left = parseStringConcat(); + if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression"); + + static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\n\s]+in\b)"); + static std::regex not_tok(R"(not\b)"); + std::string op_str; + while (!(op_str = consumeToken(compare_tok)).empty()) { + auto location = get_location(); + if (op_str == "is") { + auto negated = !consumeToken(not_tok).empty(); + + auto identifier = parseIdentifier(); + if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword"); + + return nonstd_make_unique( + left->location, + std::move(left), std::move(identifier), + negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is); + } + auto right = parseStringConcat(); + if (!right) throw std::runtime_error("Expected right side of 'logical compare' expression"); + BinaryOpExpr::Op op; + if (op_str == "==") op = BinaryOpExpr::Op::Eq; + else if (op_str == "!=") op = BinaryOpExpr::Op::Ne; + else if (op_str == "<") op = BinaryOpExpr::Op::Lt; + else if (op_str == ">") op = BinaryOpExpr::Op::Gt; + else if (op_str == "<=") op = BinaryOpExpr::Op::Le; + else if (op_str == ">=") op = BinaryOpExpr::Op::Ge; + else if (op_str == "in") op = BinaryOpExpr::Op::In; + else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn; + else throw std::runtime_error("Unknown comparison operator: " + op_str); + left = nonstd_make_unique(get_location(), std::move(left), std::move(right), op); + } + return left; + } + + Expression::Parameters parseParameters() { + consumeSpaces(); + if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in param list"); + + Expression::Parameters result; + + while (it != end) { + if (!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call args"); + + if (auto ident = dynamic_cast(expr.get())) { + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected expression in for named arg"); + result.emplace_back(ident->get_name(), std::move(value)); + } else { + result.emplace_back(ident->get_name(), nullptr); + } + } else { + result.emplace_back(std::string(), std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + throw std::runtime_error("Expected closing parenthesis in call args"); + } + return result; + } + } + throw std::runtime_error("Expected closing parenthesis in call args"); + } + + Expression::Arguments parseCallArgs() { + consumeSpaces(); + if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in call args"); + + Expression::Arguments result; + + while (it != end) { + if (!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call args"); + + if (auto ident = dynamic_cast(expr.get())) { + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected expression in for named arg"); + result.kwargs.emplace_back(ident->get_name(), std::move(value)); + } else { + result.args.emplace_back(std::move(expr)); + } + } else { + result.args.emplace_back(std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + throw std::runtime_error("Expected closing parenthesis in call args"); + } + return result; + } + } + throw std::runtime_error("Expected closing parenthesis in call args"); + } + + std::unique_ptr parseIdentifier() { + static std::regex ident_regex(R"((?!not|is|and|or|del)[a-zA-Z_]\w*)"); + auto location = get_location(); + auto ident = consumeToken(ident_regex); + if (ident.empty()) + return nullptr; + return nonstd_make_unique(location, ident); + } + + std::unique_ptr parseStringConcat() { + auto left = parseMathPow(); + if (!left) throw std::runtime_error("Expected left side of 'string concat' expression"); + + static std::regex concat_tok(R"(~(?!\}))"); + if (!consumeToken(concat_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) throw std::runtime_error("Expected right side of 'string concat' expression"); + left = nonstd_make_unique(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat); + } + return left; + } + + std::unique_ptr parseMathPow() { + auto left = parseMathPlusMinus(); + if (!left) throw std::runtime_error("Expected left side of 'math pow' expression"); + + while (!consumeToken("**").empty()) { + auto right = parseMathPlusMinus(); + if (!right) throw std::runtime_error("Expected right side of 'math pow' expression"); + left = nonstd_make_unique(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul); + } + return left; + } + + std::unique_ptr parseMathPlusMinus() { + static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))"); + + auto left = parseMathMulDiv(); + if (!left) throw std::runtime_error("Expected left side of 'math plus/minus' expression"); + std::string op_str; + while (!(op_str = consumeToken(plus_minus_tok)).empty()) { + auto right = parseMathMulDiv(); + if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression"); + auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub; + left = nonstd_make_unique(get_location(), std::move(left), std::move(right), op); + } + return left; + } + + std::unique_ptr parseMathMulDiv() { + auto left = parseMathUnaryPlusMinus(); + if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression"); + + static std::regex mul_div_tok(R"(\*\*?|//?|%(?!\}))"); + std::string op_str; + while (!(op_str = consumeToken(mul_div_tok)).empty()) { + auto right = parseMathUnaryPlusMinus(); + if (!right) throw std::runtime_error("Expected right side of 'math mul/div' expression"); + auto op = op_str == "*" ? BinaryOpExpr::Op::Mul + : op_str == "**" ? BinaryOpExpr::Op::MulMul + : op_str == "/" ? BinaryOpExpr::Op::Div + : op_str == "//" ? BinaryOpExpr::Op::DivDiv + : BinaryOpExpr::Op::Mod; + left = nonstd_make_unique(get_location(), std::move(left), std::move(right), op); + } + + if (!consumeToken("|").empty()) { + auto expr = parseMathMulDiv(); + if (auto filter = dynamic_cast(expr.get())) { + filter->prepend(std::move(left)); + return expr; + } else { + std::vector> parts; + parts.emplace_back(std::move(left)); + parts.emplace_back(std::move(expr)); + return nonstd_make_unique(get_location(), std::move(parts)); + } + } + return left; + } + + std::unique_ptr call_func(const std::string & name, Expression::Arguments && args) const { + return nonstd_make_unique(get_location(), nonstd_make_unique(get_location(), name), std::move(args)); + } + + std::unique_ptr parseMathUnaryPlusMinus() { + static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))"); + auto op_str = consumeToken(unary_plus_minus_tok); + auto expr = parseValueExpression(); + if (!expr) throw std::runtime_error("Expected expr of 'unary plus/minus' expression"); + + if (!op_str.empty()) { + auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus; + return nonstd_make_unique(get_location(), std::move(expr), op); + } + return expr; + } + + std::unique_ptr parseValueExpression() { + auto parseValue = [&]() -> std::unique_ptr { + auto location = get_location(); + auto constant = parseConstant(); + if (constant) return nonstd_make_unique(location, *constant); + + static std::regex null_regex(R"(null\b)"); + if (!consumeToken(null_regex).empty()) return nonstd_make_unique(location, Value()); + + auto identifier = parseIdentifier(); + if (identifier) return identifier; + + auto braced = parseBracedExpressionOrArray(); + if (braced) return braced; + + auto array = parseArray(); + if (array) return array; + + auto dictionary = parseDictionary(); + if (dictionary) return dictionary; + + throw std::runtime_error("Expected value expression"); + }; + + auto value = parseValue(); + + while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) { + if (!consumeToken("[").empty()) { + std::unique_ptr index; + if (!consumeToken(":").empty()) { + auto slice_end = parseExpression(); + index = nonstd_make_unique(slice_end->location, nullptr, std::move(slice_end)); + } else { + auto slice_start = parseExpression(); + if (!consumeToken(":").empty()) { + consumeSpaces(); + if (peekSymbols({ "]" })) { + index = nonstd_make_unique(slice_start->location, std::move(slice_start), nullptr); + } else { + auto slice_end = parseExpression(); + index = nonstd_make_unique(slice_start->location, std::move(slice_start), std::move(slice_end)); + } + } else { + index = std::move(slice_start); + } + } + if (!index) throw std::runtime_error("Empty index in subscript"); + if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); + + value = nonstd_make_unique(value->location, std::move(value), std::move(index)); + } else if (!consumeToken(".").empty()) { + auto identifier = parseIdentifier(); + if (!identifier) throw std::runtime_error("Expected identifier in subscript"); + + consumeSpaces(); + if (peekSymbols({ "(" })) { + auto callParams = parseCallArgs(); + value = nonstd_make_unique(identifier->location, std::move(value), std::move(identifier), std::move(callParams)); + } else { + auto key = nonstd_make_unique(identifier->location, Value(identifier->get_name())); + value = nonstd_make_unique(identifier->location, std::move(value), std::move(key)); + } + } + consumeSpaces(); + } + + if (peekSymbols({ "(" })) { + auto location = get_location(); + auto callParams = parseCallArgs(); + value = nonstd_make_unique(location, std::move(value), std::move(callParams)); + } + return value; + } + + std::unique_ptr parseBracedExpressionOrArray() { + if (consumeToken("(").empty()) return nullptr; + + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in braced expression"); + + if (!consumeToken(")").empty()) { + return expr; // Drop the parentheses + } + + std::vector> tuple; + tuple.emplace_back(std::move(expr)); + + while (it != end) { + if (consumeToken(",").empty()) throw std::runtime_error("Expected comma in tuple"); + auto next = parseExpression(); + if (!next) throw std::runtime_error("Expected expression in tuple"); + tuple.push_back(std::move(next)); + + if (!consumeToken(")").empty()) { + return nonstd_make_unique(get_location(), std::move(tuple)); + } + } + throw std::runtime_error("Expected closing parenthesis"); + } + + std::unique_ptr parseArray() { + if (consumeToken("[").empty()) return nullptr; + + std::vector> elements; + if (!consumeToken("]").empty()) { + return nonstd_make_unique(get_location(), std::move(elements)); + } + auto first_expr = parseExpression(); + if (!first_expr) throw std::runtime_error("Expected first expression in array"); + elements.push_back(std::move(first_expr)); + + while (it != end) { + if (!consumeToken(",").empty()) { + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in array"); + elements.push_back(std::move(expr)); + } else if (!consumeToken("]").empty()) { + return nonstd_make_unique(get_location(), std::move(elements)); + } else { + throw std::runtime_error("Expected comma or closing bracket in array"); + } + } + throw std::runtime_error("Expected closing bracket"); + } + + std::unique_ptr parseDictionary() { + if (consumeToken("{").empty()) return nullptr; + + std::vector, std::unique_ptr>> elements; + if (!consumeToken("}").empty()) { + return nonstd_make_unique(get_location(), std::move(elements)); + } + + auto parseKeyValuePair = [&]() { + auto key = parseExpression(); + if (!key) throw std::runtime_error("Expected key in dictionary"); + if (consumeToken(":").empty()) throw std::runtime_error("Expected colon betweek key & value in dictionary"); + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in dictionary"); + elements.emplace_back(std::make_pair(std::move(key), std::move(value))); + }; + + parseKeyValuePair(); + + while (it != end) { + if (!consumeToken(",").empty()) { + parseKeyValuePair(); + } else if (!consumeToken("}").empty()) { + return nonstd_make_unique(get_location(), std::move(elements)); + } else { + throw std::runtime_error("Expected comma or closing brace in dictionary"); + } + } + throw std::runtime_error("Expected closing brace"); + } + + SpaceHandling parsePreSpace(const std::string& s) const { + if (s == "-") + return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + SpaceHandling parsePostSpace(const std::string& s) const { + if (s == "-") return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + using TemplateTokenVector = std::vector>; + using TemplateTokenIterator = TemplateTokenVector::const_iterator; + + std::vector parseVarNames() { + static std::regex varnames_regex(R"(((?:\w+)(?:[\n\s]*,[\n\s]*(?:\w+))*)[\n\s]*)"); + + std::vector group; + if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names"); + std::vector varnames; + std::istringstream iss(group[1]); + std::string varname; + while (std::getline(iss, varname, ',')) { + varnames.push_back(strip(varname)); + } + return varnames; + } + + std::runtime_error unexpected(const TemplateToken & token) const { + return std::runtime_error("Unexpected " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + std::runtime_error unterminated(const TemplateToken & token) const { + return std::runtime_error("Unterminated " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + + TemplateTokenVector tokenize() { + static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})"); + static std::regex expr_open_regex(R"(\{\{([-~])?)"); + static std::regex block_open_regex(R"(^\{%([-~])?[\s\n]*)"); + static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|set|endset|block|endblock|macro|endmacro)\b)"); + static std::regex text_regex(R"([\s\S\n]*?($|(?=\{\{|\{%|\{#)))"); + static std::regex expr_close_regex(R"([\s\n]*([-~])?\}\})"); + static std::regex block_close_regex(R"([\s\n]*([-~])?%\})"); + + TemplateTokenVector tokens; + std::vector group; + std::string text; + + try { + while (it != end) { + auto location = get_location(); + + if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + auto content = group[2]; + auto post_space = parsePostSpace(group[3]); + tokens.push_back(nonstd_make_unique(location, pre_space, post_space, content)); + } else if (!(group = consumeTokenGroups(expr_open_regex, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + auto expr = parseExpression(); + + if ((group = consumeTokenGroups(expr_close_regex)).empty()) { + throw std::runtime_error("Expected closing expression tag"); + } + + auto post_space = parsePostSpace(group[1]); + tokens.push_back(nonstd_make_unique(location, pre_space, post_space, std::move(expr))); + } else if (!(group = consumeTokenGroups(block_open_regex, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + + std::string keyword; + + auto parseBlockClose = [&]() -> SpaceHandling { + if ((group = consumeTokenGroups(block_close_regex)).empty()) throw std::runtime_error("Expected closing block tag"); + return parsePostSpace(group[1]); + }; + + if ((keyword = consumeToken(block_keyword_tok)).empty()) throw std::runtime_error("Expected block keyword"); + + if (keyword == "if") { + auto condition = parseExpression(); + if (!condition) throw std::runtime_error("Expected condition in if block"); + + auto post_space = parseBlockClose(); + tokens.push_back(nonstd_make_unique(location, pre_space, post_space, std::move(condition))); + } else if (keyword == "elif") { + auto condition = parseExpression(); + if (!condition) throw std::runtime_error("Expected condition in elif block"); + + auto post_space = parseBlockClose(); + tokens.push_back(nonstd_make_unique(location, pre_space, post_space, std::move(condition))); + } else if (keyword == "else") { + auto post_space = parseBlockClose(); + tokens.push_back(nonstd_make_unique(location, pre_space, post_space)); + } else if (keyword == "endif") { + auto post_space = parseBlockClose(); + tokens.push_back(nonstd_make_unique(location, pre_space, post_space)); + } else if (keyword == "for") { + static std::regex recursive_tok(R"(recursive\b)"); + static std::regex if_tok(R"(if\b)"); + + auto varnames = parseVarNames(); + static std::regex in_tok(R"(in\b)"); + if (consumeToken(in_tok).empty()) throw std::runtime_error("Expected 'in' keyword in for block"); + auto iterable = parseExpression(/* allow_if_expr = */ false); + if (!iterable) throw std::runtime_error("Expected iterable in for block"); + + std::unique_ptr condition; + if (!consumeToken(if_tok).empty()) { + condition = parseExpression(); + } + auto recursive = !consumeToken(recursive_tok).empty(); + + auto post_space = parseBlockClose(); + tokens.push_back(nonstd_make_unique(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive)); + } else if (keyword == "endfor") { + auto post_space = parseBlockClose(); + tokens.push_back(nonstd_make_unique(location, pre_space, post_space)); + } else if (keyword == "set") { + static std::regex namespaced_var_regex(R"((\w+)[\s\n]*\.[\s\n]*(\w+))"); + + std::string ns; + std::vector var_names; + std::unique_ptr value; + if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) { + ns = group[1]; + var_names.push_back(group[2]); + + if (consumeToken("=").empty()) throw std::runtime_error("Expected equals sign in set block"); + + value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in set block"); + } else { + var_names = parseVarNames(); + + if (!consumeToken("=").empty()) { + value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in set block"); + } + } + auto post_space = parseBlockClose(); + tokens.push_back(nonstd_make_unique(location, pre_space, post_space, ns, var_names, std::move(value))); + } else if (keyword == "endset") { + auto post_space = parseBlockClose(); + tokens.push_back(nonstd_make_unique(location, pre_space, post_space)); + } else if (keyword == "macro") { + auto macroname = parseIdentifier(); + if (!macroname) throw std::runtime_error("Expected macro name in macro block"); + auto params = parseParameters(); + + auto post_space = parseBlockClose(); + tokens.push_back(nonstd_make_unique(location, pre_space, post_space, std::move(macroname), std::move(params))); + } else if (keyword == "endmacro") { + auto post_space = parseBlockClose(); + tokens.push_back(nonstd_make_unique(location, pre_space, post_space)); + } else { + throw std::runtime_error("Unexpected block: " + keyword); + } + } else if (!(text = consumeToken(text_regex, SpaceHandling::Keep)).empty()) { + tokens.push_back(nonstd_make_unique(location, SpaceHandling::Keep, SpaceHandling::Keep, text)); + } else { + if (it != end) throw std::runtime_error("Unexpected character"); + } + } + return tokens; + } catch (const std::runtime_error & e) { + throw std::runtime_error(e.what() + error_location_suffix(*template_str, std::distance(start, it))); + } + } + + std::unique_ptr parseTemplate( + const TemplateTokenIterator & begin, + TemplateTokenIterator & it, + const TemplateTokenIterator & end, + bool fully = false) const { + std::vector> children; + while (it != end) { + const auto start = it; + const auto & token = *(it++); + if (auto if_token = dynamic_cast(token.get())) { + std::vector, std::unique_ptr>> cascade; + cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end)); + + while (it != end && (*it)->type == TemplateToken::Type::Elif) { + auto elif_token = dynamic_cast((*(it++)).get()); + cascade.emplace_back(std::move(elif_token->condition), parseTemplate(begin, it, end)); + } + + if (it != end && (*it)->type == TemplateToken::Type::Else) { + cascade.emplace_back(nullptr, parseTemplate(begin, ++it, end)); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) { + throw unterminated(**start); + } + children.emplace_back(nonstd_make_unique(token->location, std::move(cascade))); + } else if (auto for_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + auto else_body = std::unique_ptr(); + if (it != end && (*it)->type == TemplateToken::Type::Else) { + else_body = parseTemplate(begin, ++it, end); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) { + throw unterminated(**start); + } + children.emplace_back(nonstd_make_unique(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body))); + } else if (auto text_token = dynamic_cast(token.get())) { + SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep; + SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep; + + auto text = text_token->text; + if (pre_space == SpaceHandling::Strip) { + static std::regex leading_space_regex(R"(^(\s|\r|\n)+)"); + text = std::regex_replace(text, leading_space_regex, ""); + } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast((*(it - 2)).get())) { + static std::regex leading_line(R"(^[ \t]*\n)"); + text = std::regex_replace(text, leading_line, ""); + } + if (post_space == SpaceHandling::Strip) { + static std::regex trailing_space_regex(R"((\s|\r|\n)+$)"); + text = std::regex_replace(text, trailing_space_regex, ""); + } else if (options.lstrip_blocks && it != end) { + static std::regex trailing_last_line_space_regex(R"((^|\n)[ \t]*$)"); + text = std::regex_replace(text, trailing_last_line_space_regex, "$1"); + } + + if (it == end && !options.keep_trailing_newline) { + static std::regex r(R"([\n\r]$)"); + text = std::regex_replace(text, r, ""); // Strip one trailing newline + } + children.emplace_back(nonstd_make_unique(token->location, text)); + } else if (auto expr_token = dynamic_cast(token.get())) { + children.emplace_back(nonstd_make_unique(token->location, std::move(expr_token->expr))); + } else if (auto set_token = dynamic_cast(token.get())) { + if (set_token->value) { + children.emplace_back(nonstd_make_unique(token->location, set_token->ns, set_token->var_names, std::move(set_token->value), nullptr)); + } else { + auto value_template = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) { + throw unterminated(**start); + } + children.emplace_back(nonstd_make_unique(token->location, set_token->ns, set_token->var_names, nullptr, std::move(value_template))); + } + } else if (auto macro_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) { + throw unterminated(**start); + } + children.emplace_back(nonstd_make_unique(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body))); + } else if (auto comment_token = dynamic_cast(token.get())) { + // Ignore comments + } else if (dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get())) { + it--; // unconsume the token + break; // exit the loop + } else { + throw unexpected(**(it-1)); + } + } + if (fully && it != end) { + throw unexpected(**it); + } + if (children.empty()) { + return nonstd_make_unique(Location { template_str, 0 }, std::string()); + } else if (children.size() == 1) { + return std::move(children[0]); + } else { + return nonstd_make_unique(children[0]->location(), std::move(children)); + } + } + +public: + + static std::unique_ptr parse(const std::string& template_str, const Options & options) { + Parser parser(std::make_shared(template_str), options); + auto tokens = parser.tokenize(); + TemplateTokenIterator begin = tokens.begin(); + auto it = begin; + TemplateTokenIterator end = tokens.end(); + return parser.parseTemplate(begin, it, end, /* full= */ true); + } +}; + +static Value simple_function(const std::string & fn_name, const std::vector & params, const std::function &, Value & args)> & fn) { + std::map named_positions; + for (size_t i = 0, n = params.size(); i < n; i++) named_positions[params[i]] = i; + + return Value::callable([=](const std::shared_ptr & context, Value::Arguments & args) -> Value { + auto args_obj = Value::object(); + std::vector provided_args(params.size()); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto & arg = args.args[i]; + if (i < params.size()) { + args_obj.set(params[i], arg); + provided_args[i] = true; + } else { + throw std::runtime_error("Too many positional params for " + fn_name); + } + } + for (size_t i = 0, n = args.kwargs.size(); i < n; i++) { + auto & arg = args.kwargs[i]; + auto named_pos_it = named_positions.find(arg.first); + if (named_pos_it == named_positions.end()) { + throw std::runtime_error("Unknown argument " + arg.first + " for function " + fn_name); + } + provided_args[named_pos_it->second] = true; + args_obj.set(arg.first, arg.second); + } + return fn(context, args_obj); + }); +} + +inline std::shared_ptr Context::builtins() { + auto globals = Value::object(); + + globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr &, Value & args) -> Value { + throw std::runtime_error(args.at("message").get()); + })); + globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr &, Value & args) { + return Value(args.at("value").dump(args.get("indent", -1), /* tojson= */ true)); + })); + globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr &, Value & args) { + auto items = Value::array(); + if (args.contains("object")) { + auto & obj = args.at("object"); + if (!obj.is_null()) { + for (auto & key : obj.keys()) { + items.push_back(Value::array({key, obj.at(key)})); + } + } + } + return items; + })); + globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr &, Value & args) { + auto items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not a list"); + if (items.size() == 0) return Value(); + return items.at(items.size() - 1); + })); + globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr &, Value & args) { + auto & text = args.at("text"); + return text.is_null() ? text : Value(strip(text.get())); + })); + auto escape = simple_function("escape", { "text" }, [](const std::shared_ptr &, Value & args) { + return Value(html_escape(args.at("text").get())); + }); + globals.set("e", escape); + globals.set("escape", escape); + globals.set("joiner", simple_function("joiner", { "sep" }, [](const std::shared_ptr &, Value & args) { + auto sep = args.get("sep", ""); + auto first = std::make_shared(true); + return simple_function("", {}, [sep, first](const std::shared_ptr &, const Value &) -> Value { + if (*first) { + *first = false; + return ""; + } + return sep; + }); + return Value(html_escape(args.at("text").get())); + })); + globals.set("count", simple_function("count", { "items" }, [](const std::shared_ptr &, Value & args) { + return Value((int64_t) args.at("items").size()); + })); + globals.set("dictsort", simple_function("dictsort", { "value" }, [](const std::shared_ptr &, Value & args) { + if (args.size() != 1) throw std::runtime_error("dictsort expects exactly 1 argument (TODO: fix implementation)"); + auto & value = args.at("value"); + auto keys = value.keys(); + std::sort(keys.begin(), keys.end()); + auto res = Value::array(); + for (auto & key : keys) { + res.push_back(Value::array({key, value.at(key)})); + } + return res; + })); + globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr &, Value & args) { + auto do_join = [](Value & items, const std::string & sep) { + std::ostringstream oss; + auto first = true; + for (size_t i = 0, n = items.size(); i < n; ++i) { + if (first) first = false; + else oss << sep; + oss << items.at(i).to_str(); + } + return Value(oss.str()); + }; + auto sep = args.get("d", ""); + if (args.contains("items")) { + auto & items = args.at("items"); + return do_join(items, sep); + } else { + return simple_function("", {"items"}, [sep, do_join](const std::shared_ptr &, Value & args) { + auto & items = args.at("items"); + if (!items.to_bool() || !items.is_array()) throw std::runtime_error("join expects an array for items, got: " + items.dump()); + return do_join(items, sep); + }); + } + })); + globals.set("namespace", Value::callable([=](const std::shared_ptr &, Value::Arguments & args) { + auto ns = Value::object(); + args.expectArgs("namespace", {0, 0}, {0, std::numeric_limits::max()}); + for (auto & arg : args.kwargs) { + ns.set(arg.first, arg.second); + } + return ns; + })); + auto equalto = simple_function("equalto", { "expected", "actual" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("actual") == args.at("expected"); + }); + globals.set("equalto", equalto); + globals.set("==", equalto); + globals.set("length", simple_function("length", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + return (int64_t) items.size(); + })); + globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not iterable"); + return items; + })); + globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not iterable"); + std::unordered_set seen; + auto result = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto pair = seen.insert(items.at(i)); + if (pair.second) { + result.push_back(items.at(i)); + } + } + return result; + })); + auto make_filter = [](const Value & filter, Value & extra_args) -> Value { + return simple_function("", { "value" }, [=](const std::shared_ptr & context, Value & args) { + auto & value = args.at("value"); + Value::Arguments actual_args; + actual_args.args.emplace_back(value); + for (size_t i = 0, n = extra_args.size(); i < n; i++) { + actual_args.args.emplace_back(extra_args.at(i)); + } + return filter.call(context, actual_args); + }); + }; + // https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject + globals.set("reject", Value::callable([=](const std::shared_ptr & context, Value::Arguments & args) { + args.expectArgs("reject", {2, std::numeric_limits::max()}, {0, 0}); + auto & items = args.args[0]; + auto filter_fn = context->get(args.args[1]); + if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + + auto filter_args = Value::array(); + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.push_back(args.args[i]); + } + auto filter = make_filter(filter_fn, filter_args); + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + Value::Arguments filter_args; + filter_args.args.emplace_back(item); + auto pred_res = filter.call(context, filter_args); + if (!pred_res.to_bool()) { + res.push_back(item); + } + } + return res; + })); + globals.set("map", Value::callable([=](const std::shared_ptr & context, Value::Arguments & args) { + auto res = Value::array(); + if (args.args.size() == 1 && + ((args.has_named("attribute") && args.kwargs.size() == 1) || (args.has_named("default") && args.kwargs.size() == 2))) { + auto & items = args.args[0]; + auto attr_name = args.get_named("attribute"); + auto default_value = args.get_named("default"); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + auto attr = item.get(attr_name); + res.push_back(attr.is_null() ? default_value : attr); + } + } else if (args.kwargs.empty() && args.args.size() >= 2) { + auto fn = context->get(args.args[1]); + if (fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + Value::Arguments filter_args { {Value()}, {} }; + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.args.emplace_back(args.args[i]); + } + for (size_t i = 0, n = args.args[0].size(); i < n; i++) { + auto & item = args.args[0].at(i); + filter_args.args[0] = item; + res.push_back(fn.call(context, filter_args)); + } + } else { + throw std::runtime_error("Invalid or unsupported arguments for map"); + } + return res; + })); + globals.set("selectattr", Value::callable([=](const std::shared_ptr & context, Value::Arguments & args) { + args.expectArgs("selectattr", {2, std::numeric_limits::max()}, {0, 0}); + auto & items = args.args[0]; + auto attr_name = args.args[1].get(); + + bool has_test = false; + Value test_fn; + Value::Arguments test_args {{Value()}, {}}; + if (args.args.size() >= 3) { + has_test = true; + test_fn = context->get(args.args[2]); + if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump()); + for (size_t i = 3, n = args.args.size(); i < n; i++) { + test_args.args.emplace_back(args.args[i]); + } + test_args.kwargs = args.kwargs; + } + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + auto attr = item.get(attr_name); + if (has_test) { + test_args.args[0] = attr; + if (test_fn.call(context, test_args).to_bool()) { + res.push_back(item); + } + } else { + res.push_back(attr); + } + } + return res; + })); + globals.set("range", Value::callable([=](const std::shared_ptr &, Value::Arguments & args) { + std::vector startEndStep(3); + std::vector param_set(3); + if (args.args.size() == 1) { + startEndStep[1] = args.args[0].get(); + param_set[1] = true; + } else { + for (size_t i = 0; i < args.args.size(); i++) { + auto & arg = args.args[i]; + auto v = arg.get(); + startEndStep[i] = v; + param_set[i] = true; + } + } + for (auto & arg : args.kwargs) { + size_t i; + if (arg.first == "start") i = 0; + else if (arg.first == "end") i = 1; + else if (arg.first == "step") i = 2; + else throw std::runtime_error("Unknown argument " + arg.first + " for function range"); + + if (param_set[i]) { + throw std::runtime_error("Duplicate argument " + arg.first + " for function range"); + } + startEndStep[i] = arg.second.get(); + param_set[i] = true; + } + if (!param_set[1]) { + throw std::runtime_error("Missing required argument 'end' for function range"); + } + int64_t start = param_set[0] ? startEndStep[0] : 0; + int64_t end = startEndStep[1]; + int64_t step = param_set[2] ? startEndStep[2] : 1; + + auto res = Value::array(); + if (step > 0) { + for (int64_t i = start; i < end; i += step) { + res.push_back(Value(i)); + } + } else { + for (int64_t i = start; i > end; i += step) { + res.push_back(Value(i)); + } + } + return res; + })); + + return std::make_shared(std::move(globals)); +} + +inline std::shared_ptr Context::make(Value && values, const std::shared_ptr & parent) { + return std::make_shared(values.is_null() ? Value::object() : std::move(values), parent); +} + +} // namespace minja diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 25f2489961b90..86705386a0d61 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -123,6 +123,7 @@ llama_target_and_test(test-barrier.cpp) # llama_target_and_test(test-opt.cpp) # SLOW llama_target_and_test(test-backend-ops.cpp) llama_target_and_test(test-antiprompts.cpp) +llama_target_and_test(test-minja.cpp) llama_target_and_test(test-rope.cpp) diff --git a/tests/chat/contexts/simple.json b/tests/chat/contexts/simple.json new file mode 100644 index 0000000000000..fa4877616dcef --- /dev/null +++ b/tests/chat/contexts/simple.json @@ -0,0 +1,15 @@ +{ + "messages": [ + { + "role": "user", + "content": "What's your favourite LLM framework?" + }, + { + "role": "assistant", + "content": "llama.cpp!" + } + ], + "add_generation_prompt": true, + "bos_token": "<|startoftext|>", + "eos_token": "<|endoftext|>" +} \ No newline at end of file diff --git a/tests/chat/contexts/system.json b/tests/chat/contexts/system.json new file mode 100644 index 0000000000000..9c016f36910c6 --- /dev/null +++ b/tests/chat/contexts/system.json @@ -0,0 +1,19 @@ +{ + "messages": [ + { + "role": "system", + "content": "You only tell the truth." + }, + { + "role": "user", + "content": "What's your favourite LLM framework?" + }, + { + "role": "assistant", + "content": "llama.cpp!" + } + ], + "add_generation_prompt": true, + "bos_token": "<|startoftext|>", + "eos_token": "<|endoftext|>" +} \ No newline at end of file diff --git a/tests/chat/contexts/tool_use.json b/tests/chat/contexts/tool_use.json new file mode 100644 index 0000000000000..6345ef24b7876 --- /dev/null +++ b/tests/chat/contexts/tool_use.json @@ -0,0 +1,164 @@ +{ + "messages": [ + { + "role": "user", + "content": "Print a hello world message with python." + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "arguments": {"code": "print('Hello, World!')"}, + "name": "ipython" + } + } + ] + }, + { + "role": "tool", + "name": "ipython", + "content": "{\"stdout\": \"Hello, World!\"}" + }, + { + "role": "assistant", + "content": "Anything else?" + }, + { + "role": "user", + "content": "Test a tautology." + }, + { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_2", + "type": "function", + "function": { + "arguments": {"condition":true}, + "name": "test" + } + } + ] + }, + { + "role": "tool", + "name": "test", + "content": "true" + }, + { + "role": "assistant", + "content": "Truth is definitely true." + }, + { + "role": "user", + "content": "Check it on the web." + }, + { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_3", + "type": "function", + "function": { + "arguments": {"query": "what is truth anyway am I right?"}, + "name": "brave_search" + } + } + ] + }, + { + "role": "tool", + "name": "brave_search", + "content": "{\"title\":\"Truth: don't ask the web, ask an LLM instead!\",\"url\":\"https://en.wikipedia.org/wiki/Truth\"}" + }, + { + "role": "assistant", + "content": "I don't need the web to answer you but I did check, as you asked. What now?" + } + ], + "add_generation_prompt": true, + "bos_token": "<|startoftext|>", + "eos_token": "<|endoftext|>", + "builtin_tools": [ + "wolfram_alpha", + "brave_search" + ], + "cutting_knowledge_date": "2023-04-01", + "todays_date": "2024-09-03", + "tools": [ + { + "type": "function", + "function": { + "name": "ipython", + "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The code to run in the ipython interpreter." + } + }, + "required": ["code"] + } + } + }, + { + "type": "function", + "function": { + "name": "brave_search", + "description": "Executes a web search with Brave.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The query to search for." + } + }, + "required": ["query"] + } + } + }, + { + "type": "function", + "function": { + "name": "wolfram_alpha", + "description": "Executes a query with Wolfram Alpha.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The query to execute." + } + }, + "required": ["query"] + } + } + }, + { + "type": "function", + "function": { + "name": "test", + "description": "Runs a test.", + "parameters": { + "type": "object", + "properties": { + "condition": { + "type": "boolean", + "description": "The condition to test." + } + }, + "required": ["condition"] + } + } + } + ] +} \ No newline at end of file diff --git a/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-default-simple.txt b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-default-simple.txt new file mode 100644 index 0000000000000..8824912a4cbc2 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-default-simple.txt @@ -0,0 +1,5 @@ +<|startoftext|><|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-default-system.txt b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-default-system.txt new file mode 100644 index 0000000000000..eed13ce3d2ea0 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-default-system.txt @@ -0,0 +1,7 @@ +<|startoftext|><|im_start|>system +You only tell the truth.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use-simple.txt b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use-simple.txt new file mode 100644 index 0000000000000..6a8b5a5c86d89 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use-simple.txt @@ -0,0 +1,11 @@ +<|startoftext|><|im_start|>system +You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +For each function call return a json object with function name and arguments within XML tags as follows: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use-system.txt b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use-system.txt new file mode 100644 index 0000000000000..9435ec9b7f1e6 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use-system.txt @@ -0,0 +1,13 @@ +<|startoftext|><|im_start|>system +You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +For each function call return a json object with function name and arguments within XML tags as follows: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>system +You only tell the truth.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use-tool_use.txt b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use-tool_use.txt new file mode 100644 index 0000000000000..07e2883f450b2 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use-tool_use.txt @@ -0,0 +1,58 @@ +<|startoftext|><|im_start|>system +You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: {"type": "function", "function": {"name": "ipython", "description": "ipython(code: str) - Runs code in an ipython interpreter and returns the result of the execution after 60 seconds. + + Args: + code(str): The code to run in the ipython interpreter.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The code to run in the ipython interpreter."}}, "required": ["code"]}} +{"type": "function", "function": {"name": "brave_search", "description": "brave_search(query: str) - Executes a web search with Brave. + + Args: + query(str): The query to search for.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The query to search for."}}, "required": ["query"]}} +{"type": "function", "function": {"name": "wolfram_alpha", "description": "wolfram_alpha(query: str) - Executes a query with Wolfram Alpha. + + Args: + query(str): The query to execute.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The query to execute."}}, "required": ["query"]}} +{"type": "function", "function": {"name": "test", "description": "test(condition: bool) - Runs a test. + + Args: + condition(bool): The condition to test.", "parameters": {"type": "object", "properties": {"condition": {"type": "boolean", "description": "The condition to test."}}, "required": ["condition"]}} Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +For each function call return a json object with function name and arguments within XML tags as follows: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +Print a hello world message with python.<|im_end|> +<|im_start|>assistant + +{"name": "ipython", "arguments": {"code": "print('Hello, World!')"}} +<|im_end|> +<|im_start|>tool + +{"stdout": "Hello, World!"} + +<|im_end|><|im_start|>assistant +Anything else?<|im_end|> +<|im_start|>user +Test a tautology.<|im_end|> +<|im_start|>assistant + +{"name": "test", "arguments": {"condition": true}} +<|im_end|> +<|im_start|>tool + +true + +<|im_end|><|im_start|>assistant +Truth is definitely true.<|im_end|> +<|im_start|>user +Check it on the web.<|im_end|> +<|im_start|>assistant + +{"name": "brave_search", "arguments": {"query": "what is truth anyway am I right?"}} +<|im_end|> +<|im_start|>tool + +{"title":"Truth: don't ask the web, ask an LLM instead!","url":"https://en.wikipedia.org/wiki/Truth"} + +<|im_end|><|im_start|>assistant +I don't need the web to answer you but I did check, as you asked. What now?<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-default-simple.txt b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-default-simple.txt new file mode 100644 index 0000000000000..8824912a4cbc2 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-default-simple.txt @@ -0,0 +1,5 @@ +<|startoftext|><|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-default-system.txt b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-default-system.txt new file mode 100644 index 0000000000000..eed13ce3d2ea0 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-default-system.txt @@ -0,0 +1,7 @@ +<|startoftext|><|im_start|>system +You only tell the truth.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-tool_use-simple.txt b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-tool_use-simple.txt new file mode 100644 index 0000000000000..6a8b5a5c86d89 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-tool_use-simple.txt @@ -0,0 +1,11 @@ +<|startoftext|><|im_start|>system +You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +For each function call return a json object with function name and arguments within XML tags as follows: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-tool_use-system.txt b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-tool_use-system.txt new file mode 100644 index 0000000000000..9435ec9b7f1e6 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-tool_use-system.txt @@ -0,0 +1,13 @@ +<|startoftext|><|im_start|>system +You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +For each function call return a json object with function name and arguments within XML tags as follows: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>system +You only tell the truth.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-tool_use-tool_use.txt b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-tool_use-tool_use.txt new file mode 100644 index 0000000000000..07e2883f450b2 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-2-Pro-Mistral-7B-tool_use-tool_use.txt @@ -0,0 +1,58 @@ +<|startoftext|><|im_start|>system +You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: {"type": "function", "function": {"name": "ipython", "description": "ipython(code: str) - Runs code in an ipython interpreter and returns the result of the execution after 60 seconds. + + Args: + code(str): The code to run in the ipython interpreter.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The code to run in the ipython interpreter."}}, "required": ["code"]}} +{"type": "function", "function": {"name": "brave_search", "description": "brave_search(query: str) - Executes a web search with Brave. + + Args: + query(str): The query to search for.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The query to search for."}}, "required": ["query"]}} +{"type": "function", "function": {"name": "wolfram_alpha", "description": "wolfram_alpha(query: str) - Executes a query with Wolfram Alpha. + + Args: + query(str): The query to execute.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The query to execute."}}, "required": ["query"]}} +{"type": "function", "function": {"name": "test", "description": "test(condition: bool) - Runs a test. + + Args: + condition(bool): The condition to test.", "parameters": {"type": "object", "properties": {"condition": {"type": "boolean", "description": "The condition to test."}}, "required": ["condition"]}} Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +For each function call return a json object with function name and arguments within XML tags as follows: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +Print a hello world message with python.<|im_end|> +<|im_start|>assistant + +{"name": "ipython", "arguments": {"code": "print('Hello, World!')"}} +<|im_end|> +<|im_start|>tool + +{"stdout": "Hello, World!"} + +<|im_end|><|im_start|>assistant +Anything else?<|im_end|> +<|im_start|>user +Test a tautology.<|im_end|> +<|im_start|>assistant + +{"name": "test", "arguments": {"condition": true}} +<|im_end|> +<|im_start|>tool + +true + +<|im_end|><|im_start|>assistant +Truth is definitely true.<|im_end|> +<|im_start|>user +Check it on the web.<|im_end|> +<|im_start|>assistant + +{"name": "brave_search", "arguments": {"query": "what is truth anyway am I right?"}} +<|im_end|> +<|im_start|>tool + +{"title":"Truth: don't ask the web, ask an LLM instead!","url":"https://en.wikipedia.org/wiki/Truth"} + +<|im_end|><|im_start|>assistant +I don't need the web to answer you but I did check, as you asked. What now?<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-default-simple.txt b/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-default-simple.txt new file mode 100644 index 0000000000000..558a5087dba5b --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-default-simple.txt @@ -0,0 +1,7 @@ +<|startoftext|><|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-default-system.txt b/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-default-system.txt new file mode 100644 index 0000000000000..eed13ce3d2ea0 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-default-system.txt @@ -0,0 +1,7 @@ +<|startoftext|><|im_start|>system +You only tell the truth.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-tool_use-simple.txt b/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-tool_use-simple.txt new file mode 100644 index 0000000000000..6a8b5a5c86d89 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-tool_use-simple.txt @@ -0,0 +1,11 @@ +<|startoftext|><|im_start|>system +You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +For each function call return a json object with function name and arguments within XML tags as follows: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-tool_use-system.txt b/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-tool_use-system.txt new file mode 100644 index 0000000000000..9435ec9b7f1e6 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-tool_use-system.txt @@ -0,0 +1,13 @@ +<|startoftext|><|im_start|>system +You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +For each function call return a json object with function name and arguments within XML tags as follows: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>system +You only tell the truth.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-tool_use-tool_use.txt b/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-tool_use-tool_use.txt new file mode 100644 index 0000000000000..07e2883f450b2 --- /dev/null +++ b/tests/chat/goldens/NousResearch-Hermes-3-Llama-3.1-70B-tool_use-tool_use.txt @@ -0,0 +1,58 @@ +<|startoftext|><|im_start|>system +You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: {"type": "function", "function": {"name": "ipython", "description": "ipython(code: str) - Runs code in an ipython interpreter and returns the result of the execution after 60 seconds. + + Args: + code(str): The code to run in the ipython interpreter.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The code to run in the ipython interpreter."}}, "required": ["code"]}} +{"type": "function", "function": {"name": "brave_search", "description": "brave_search(query: str) - Executes a web search with Brave. + + Args: + query(str): The query to search for.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The query to search for."}}, "required": ["query"]}} +{"type": "function", "function": {"name": "wolfram_alpha", "description": "wolfram_alpha(query: str) - Executes a query with Wolfram Alpha. + + Args: + query(str): The query to execute.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The query to execute."}}, "required": ["query"]}} +{"type": "function", "function": {"name": "test", "description": "test(condition: bool) - Runs a test. + + Args: + condition(bool): The condition to test.", "parameters": {"type": "object", "properties": {"condition": {"type": "boolean", "description": "The condition to test."}}, "required": ["condition"]}} Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +For each function call return a json object with function name and arguments within XML tags as follows: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +Print a hello world message with python.<|im_end|> +<|im_start|>assistant + +{"name": "ipython", "arguments": {"code": "print('Hello, World!')"}} +<|im_end|> +<|im_start|>tool + +{"stdout": "Hello, World!"} + +<|im_end|><|im_start|>assistant +Anything else?<|im_end|> +<|im_start|>user +Test a tautology.<|im_end|> +<|im_start|>assistant + +{"name": "test", "arguments": {"condition": true}} +<|im_end|> +<|im_start|>tool + +true + +<|im_end|><|im_start|>assistant +Truth is definitely true.<|im_end|> +<|im_start|>user +Check it on the web.<|im_end|> +<|im_start|>assistant + +{"name": "brave_search", "arguments": {"query": "what is truth anyway am I right?"}} +<|im_end|> +<|im_start|>tool + +{"title":"Truth: don't ask the web, ask an LLM instead!","url":"https://en.wikipedia.org/wiki/Truth"} + +<|im_end|><|im_start|>assistant +I don't need the web to answer you but I did check, as you asked. What now?<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/Qwen-Qwen2-7B-Instruct-simple.txt b/tests/chat/goldens/Qwen-Qwen2-7B-Instruct-simple.txt new file mode 100644 index 0000000000000..1d9ab01acec3d --- /dev/null +++ b/tests/chat/goldens/Qwen-Qwen2-7B-Instruct-simple.txt @@ -0,0 +1,7 @@ +<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/Qwen-Qwen2-7B-Instruct-system.txt b/tests/chat/goldens/Qwen-Qwen2-7B-Instruct-system.txt new file mode 100644 index 0000000000000..e3a52d4de912e --- /dev/null +++ b/tests/chat/goldens/Qwen-Qwen2-7B-Instruct-system.txt @@ -0,0 +1,7 @@ +<|im_start|>system +You only tell the truth.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/Qwen-Qwen2-VL-7B-Instruct-simple.txt b/tests/chat/goldens/Qwen-Qwen2-VL-7B-Instruct-simple.txt new file mode 100644 index 0000000000000..1d9ab01acec3d --- /dev/null +++ b/tests/chat/goldens/Qwen-Qwen2-VL-7B-Instruct-simple.txt @@ -0,0 +1,7 @@ +<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/Qwen-Qwen2-VL-7B-Instruct-system.txt b/tests/chat/goldens/Qwen-Qwen2-VL-7B-Instruct-system.txt new file mode 100644 index 0000000000000..e3a52d4de912e --- /dev/null +++ b/tests/chat/goldens/Qwen-Qwen2-VL-7B-Instruct-system.txt @@ -0,0 +1,7 @@ +<|im_start|>system +You only tell the truth.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/Qwen-Qwen2.5-7B-Instruct-simple.txt b/tests/chat/goldens/Qwen-Qwen2.5-7B-Instruct-simple.txt new file mode 100644 index 0000000000000..b6e30b122d617 --- /dev/null +++ b/tests/chat/goldens/Qwen-Qwen2.5-7B-Instruct-simple.txt @@ -0,0 +1,7 @@ +<|im_start|>system +You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/Qwen-Qwen2.5-7B-Instruct-system.txt b/tests/chat/goldens/Qwen-Qwen2.5-7B-Instruct-system.txt new file mode 100644 index 0000000000000..e3a52d4de912e --- /dev/null +++ b/tests/chat/goldens/Qwen-Qwen2.5-7B-Instruct-system.txt @@ -0,0 +1,7 @@ +<|im_start|>system +You only tell the truth.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/Qwen-Qwen2.5-7B-Instruct-tool_use.txt b/tests/chat/goldens/Qwen-Qwen2.5-7B-Instruct-tool_use.txt new file mode 100644 index 0000000000000..7862ad435857f --- /dev/null +++ b/tests/chat/goldens/Qwen-Qwen2.5-7B-Instruct-tool_use.txt @@ -0,0 +1,56 @@ +<|im_start|>system +You are Qwen, created by Alibaba Cloud. You are a helpful assistant. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name": "ipython", "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The code to run in the ipython interpreter."}}, "required": ["code"]}}} +{"type": "function", "function": {"name": "brave_search", "description": "Executes a web search with Brave.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The query to search for."}}, "required": ["query"]}}} +{"type": "function", "function": {"name": "wolfram_alpha", "description": "Executes a query with Wolfram Alpha.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The query to execute."}}, "required": ["query"]}}} +{"type": "function", "function": {"name": "test", "description": "Runs a test.", "parameters": {"type": "object", "properties": {"condition": {"type": "boolean", "description": "The condition to test."}}, "required": ["condition"]}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +Print a hello world message with python.<|im_end|> +<|im_start|>assistant + +{"name": "ipython", "arguments": {"code": "print('Hello, World!')"}} +<|im_end|> +<|im_start|>user + +{"stdout": "Hello, World!"} +<|im_end|> +<|im_start|>assistant +Anything else?<|im_end|> +<|im_start|>user +Test a tautology.<|im_end|> +<|im_start|>assistant + +{"name": "test", "arguments": {"condition": true}} +<|im_end|> +<|im_start|>user + +true +<|im_end|> +<|im_start|>assistant +Truth is definitely true.<|im_end|> +<|im_start|>user +Check it on the web.<|im_end|> +<|im_start|>assistant + +{"name": "brave_search", "arguments": {"query": "what is truth anyway am I right?"}} +<|im_end|> +<|im_start|>user + +{"title":"Truth: don't ask the web, ask an LLM instead!","url":"https://en.wikipedia.org/wiki/Truth"} +<|im_end|> +<|im_start|>assistant +I don't need the web to answer you but I did check, as you asked. What now?<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/Qwen-Qwen2.5-Math-7B-Instruct-simple.txt b/tests/chat/goldens/Qwen-Qwen2.5-Math-7B-Instruct-simple.txt new file mode 100644 index 0000000000000..ce7ae7d425b4d --- /dev/null +++ b/tests/chat/goldens/Qwen-Qwen2.5-Math-7B-Instruct-simple.txt @@ -0,0 +1,7 @@ +<|im_start|>system +Please reason step by step, and put your final answer within \boxed{}.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/Qwen-Qwen2.5-Math-7B-Instruct-system.txt b/tests/chat/goldens/Qwen-Qwen2.5-Math-7B-Instruct-system.txt new file mode 100644 index 0000000000000..e3a52d4de912e --- /dev/null +++ b/tests/chat/goldens/Qwen-Qwen2.5-Math-7B-Instruct-system.txt @@ -0,0 +1,7 @@ +<|im_start|>system +You only tell the truth.<|im_end|> +<|im_start|>user +What's your favourite LLM framework?<|im_end|> +<|im_start|>assistant +llama.cpp!<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/Qwen-Qwen2.5-Math-7B-Instruct-tool_use.txt b/tests/chat/goldens/Qwen-Qwen2.5-Math-7B-Instruct-tool_use.txt new file mode 100644 index 0000000000000..b25b2054faccd --- /dev/null +++ b/tests/chat/goldens/Qwen-Qwen2.5-Math-7B-Instruct-tool_use.txt @@ -0,0 +1,56 @@ +<|im_start|>system +Please reason step by step, and put your final answer within \boxed{}. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name": "ipython", "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The code to run in the ipython interpreter."}}, "required": ["code"]}}} +{"type": "function", "function": {"name": "brave_search", "description": "Executes a web search with Brave.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The query to search for."}}, "required": ["query"]}}} +{"type": "function", "function": {"name": "wolfram_alpha", "description": "Executes a query with Wolfram Alpha.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The query to execute."}}, "required": ["query"]}}} +{"type": "function", "function": {"name": "test", "description": "Runs a test.", "parameters": {"type": "object", "properties": {"condition": {"type": "boolean", "description": "The condition to test."}}, "required": ["condition"]}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +Print a hello world message with python.<|im_end|> +<|im_start|>assistant + +{"name": "ipython", "arguments": {"code": "print('Hello, World!')"}} +<|im_end|> +<|im_start|>user + +{"stdout": "Hello, World!"} +<|im_end|> +<|im_start|>assistant +Anything else?<|im_end|> +<|im_start|>user +Test a tautology.<|im_end|> +<|im_start|>assistant + +{"name": "test", "arguments": {"condition": true}} +<|im_end|> +<|im_start|>user + +true +<|im_end|> +<|im_start|>assistant +Truth is definitely true.<|im_end|> +<|im_start|>user +Check it on the web.<|im_end|> +<|im_start|>assistant + +{"name": "brave_search", "arguments": {"query": "what is truth anyway am I right?"}} +<|im_end|> +<|im_start|>user + +{"title":"Truth: don't ask the web, ask an LLM instead!","url":"https://en.wikipedia.org/wiki/Truth"} +<|im_end|> +<|im_start|>assistant +I don't need the web to answer you but I did check, as you asked. What now?<|im_end|> +<|im_start|>assistant diff --git a/tests/chat/goldens/google-gemma-2-2b-it-simple.txt b/tests/chat/goldens/google-gemma-2-2b-it-simple.txt new file mode 100644 index 0000000000000..014eb2e8089c2 --- /dev/null +++ b/tests/chat/goldens/google-gemma-2-2b-it-simple.txt @@ -0,0 +1,5 @@ +<|startoftext|>user +What's your favourite LLM framework? +model +llama.cpp! +model diff --git a/tests/chat/goldens/meetkai-functionary-medium-v3.2-simple.txt b/tests/chat/goldens/meetkai-functionary-medium-v3.2-simple.txt new file mode 100644 index 0000000000000..3c20de4f5daad --- /dev/null +++ b/tests/chat/goldens/meetkai-functionary-medium-v3.2-simple.txt @@ -0,0 +1,21 @@ +<|startoftext|><|start_header_id|>system<|end_header_id|> + +You are capable of executing available function(s) if required. +Only execute function(s) when absolutely necessary. +Ask for the required input to:recipient==all +Use JSON for function arguments. +Respond in this format: +>>>${recipient} +${content} +Available functions: +// Supported function definitions that should be called when necessary. +namespace functions { + +} // namespace functions<|eot_id|><|start_header_id|>user<|end_header_id|> + +What's your favourite LLM framework?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +>>>all +llama.cpp!<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +>>> \ No newline at end of file diff --git a/tests/chat/goldens/meetkai-functionary-medium-v3.2-system.txt b/tests/chat/goldens/meetkai-functionary-medium-v3.2-system.txt new file mode 100644 index 0000000000000..a006497cf1f6f --- /dev/null +++ b/tests/chat/goldens/meetkai-functionary-medium-v3.2-system.txt @@ -0,0 +1,23 @@ +<|startoftext|><|start_header_id|>system<|end_header_id|> + +You are capable of executing available function(s) if required. +Only execute function(s) when absolutely necessary. +Ask for the required input to:recipient==all +Use JSON for function arguments. +Respond in this format: +>>>${recipient} +${content} +Available functions: +// Supported function definitions that should be called when necessary. +namespace functions { + +} // namespace functions<|eot_id|><|start_header_id|>system<|end_header_id|> + +You only tell the truth.<|eot_id|><|start_header_id|>user<|end_header_id|> + +What's your favourite LLM framework?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +>>>all +llama.cpp!<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +>>> \ No newline at end of file diff --git a/tests/chat/goldens/meetkai-functionary-medium-v3.2-tool_use.txt b/tests/chat/goldens/meetkai-functionary-medium-v3.2-tool_use.txt new file mode 100644 index 0000000000000..2cc3c7a8e6c1c --- /dev/null +++ b/tests/chat/goldens/meetkai-functionary-medium-v3.2-tool_use.txt @@ -0,0 +1 @@ +ERROR: can only concatenate str (not "dict") to str \ No newline at end of file diff --git a/tests/chat/goldens/meta-llama-Meta-Llama-3.1-8B-Instruct-simple.txt b/tests/chat/goldens/meta-llama-Meta-Llama-3.1-8B-Instruct-simple.txt new file mode 100644 index 0000000000000..23b6fcde3de1f --- /dev/null +++ b/tests/chat/goldens/meta-llama-Meta-Llama-3.1-8B-Instruct-simple.txt @@ -0,0 +1,11 @@ +<|startoftext|><|start_header_id|>system<|end_header_id|> + +Cutting Knowledge Date: December 2023 +Today Date: 26 Jul 2024 + +<|eot_id|><|start_header_id|>user<|end_header_id|> + +What's your favourite LLM framework?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +llama.cpp!<|eot_id|><|start_header_id|>assistant<|end_header_id|> + diff --git a/tests/chat/goldens/meta-llama-Meta-Llama-3.1-8B-Instruct-system.txt b/tests/chat/goldens/meta-llama-Meta-Llama-3.1-8B-Instruct-system.txt new file mode 100644 index 0000000000000..8d257a035a2bf --- /dev/null +++ b/tests/chat/goldens/meta-llama-Meta-Llama-3.1-8B-Instruct-system.txt @@ -0,0 +1,11 @@ +<|startoftext|><|start_header_id|>system<|end_header_id|> + +Cutting Knowledge Date: December 2023 +Today Date: 26 Jul 2024 + +You only tell the truth.<|eot_id|><|start_header_id|>user<|end_header_id|> + +What's your favourite LLM framework?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +llama.cpp!<|eot_id|><|start_header_id|>assistant<|end_header_id|> + diff --git a/tests/chat/goldens/meta-llama-Meta-Llama-3.1-8B-Instruct-tool_use.txt b/tests/chat/goldens/meta-llama-Meta-Llama-3.1-8B-Instruct-tool_use.txt new file mode 100644 index 0000000000000..0c2c6a921f583 --- /dev/null +++ b/tests/chat/goldens/meta-llama-Meta-Llama-3.1-8B-Instruct-tool_use.txt @@ -0,0 +1,118 @@ +<|startoftext|><|start_header_id|>system<|end_header_id|> + +Environment: ipython +Tools: wolfram_alpha, brave_search + +Cutting Knowledge Date: December 2023 +Today Date: 26 Jul 2024 + +<|eot_id|><|start_header_id|>user<|end_header_id|> + +Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. + +Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables. + +{ + "type": "function", + "function": { + "name": "ipython", + "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The code to run in the ipython interpreter." + } + }, + "required": [ + "code" + ] + } + } +} + +{ + "type": "function", + "function": { + "name": "brave_search", + "description": "Executes a web search with Brave.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The query to search for." + } + }, + "required": [ + "query" + ] + } + } +} + +{ + "type": "function", + "function": { + "name": "wolfram_alpha", + "description": "Executes a query with Wolfram Alpha.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The query to execute." + } + }, + "required": [ + "query" + ] + } + } +} + +{ + "type": "function", + "function": { + "name": "test", + "description": "Runs a test.", + "parameters": { + "type": "object", + "properties": { + "condition": { + "type": "boolean", + "description": "The condition to test." + } + }, + "required": [ + "condition" + ] + } + } +} + +Print a hello world message with python.<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +{"name": "ipython", "parameters": {"code": "print('Hello, World!')"}}<|eom_id|><|start_header_id|>ipython<|end_header_id|> + +"{\"stdout\": \"Hello, World!\"}"<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +Anything else?<|eot_id|><|start_header_id|>user<|end_header_id|> + +Test a tautology.<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +{"name": "test", "parameters": {"condition": true}}<|eom_id|><|start_header_id|>ipython<|end_header_id|> + +"true"<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +Truth is definitely true.<|eot_id|><|start_header_id|>user<|end_header_id|> + +Check it on the web.<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +<|python_tag|>brave_search.call(query="what is truth anyway am I right?")<|eom_id|><|start_header_id|>ipython<|end_header_id|> + +"{\"title\":\"Truth: don't ask the web, ask an LLM instead!\",\"url\":\"https://en.wikipedia.org/wiki/Truth\"}"<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +I don't need the web to answer you but I did check, as you asked. What now?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + diff --git a/tests/chat/goldens/microsoft-Phi-3.5-mini-instruct-simple.txt b/tests/chat/goldens/microsoft-Phi-3.5-mini-instruct-simple.txt new file mode 100644 index 0000000000000..a7f52dec6f9b0 --- /dev/null +++ b/tests/chat/goldens/microsoft-Phi-3.5-mini-instruct-simple.txt @@ -0,0 +1,5 @@ +<|user|> +What's your favourite LLM framework?<|end|> +<|assistant|> +llama.cpp!<|end|> +<|assistant|> diff --git a/tests/chat/goldens/microsoft-Phi-3.5-mini-instruct-system.txt b/tests/chat/goldens/microsoft-Phi-3.5-mini-instruct-system.txt new file mode 100644 index 0000000000000..2d32334ec616d --- /dev/null +++ b/tests/chat/goldens/microsoft-Phi-3.5-mini-instruct-system.txt @@ -0,0 +1,7 @@ +<|system|> +You only tell the truth.<|end|> +<|user|> +What's your favourite LLM framework?<|end|> +<|assistant|> +llama.cpp!<|end|> +<|assistant|> diff --git a/tests/chat/goldens/mistralai-Mixtral-8x7B-Instruct-v0.1-simple.txt b/tests/chat/goldens/mistralai-Mixtral-8x7B-Instruct-v0.1-simple.txt new file mode 100644 index 0000000000000..baf3e9057141c --- /dev/null +++ b/tests/chat/goldens/mistralai-Mixtral-8x7B-Instruct-v0.1-simple.txt @@ -0,0 +1 @@ +<|startoftext|> [INST] What's your favourite LLM framework? [/INST] llama.cpp!<|endoftext|> \ No newline at end of file diff --git a/tests/chat/goldens/mistralai-Mixtral-8x7B-Instruct-v0.1-system.txt b/tests/chat/goldens/mistralai-Mixtral-8x7B-Instruct-v0.1-system.txt new file mode 100644 index 0000000000000..3321c8b75c31d --- /dev/null +++ b/tests/chat/goldens/mistralai-Mixtral-8x7B-Instruct-v0.1-system.txt @@ -0,0 +1,3 @@ +<|startoftext|> [INST] You only tell the truth. + +What's your favourite LLM framework? [/INST] llama.cpp!<|endoftext|> \ No newline at end of file diff --git a/tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja b/tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja new file mode 100644 index 0000000000000..463f9fd74cdde --- /dev/null +++ b/tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja @@ -0,0 +1,4 @@ +{{bos_token}}{% for message in messages %}{{'<|im_start|>' + message['role'] + ' +' + message['content'] + '<|im_end|>' + ' +'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +' }}{% endif %} \ No newline at end of file diff --git a/tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja b/tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja new file mode 100644 index 0000000000000..149250bd540aa --- /dev/null +++ b/tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja @@ -0,0 +1,152 @@ +{%- macro json_to_python_type(json_spec) %} +{%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + +{%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} +{%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]"}} +{%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}} + {%- else %} + {{- "dict" }} + {%- endif %} +{%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} +{%- else %} + {{- "Any" }} +{%- endif %} +{%- endmacro %} + + +{{- bos_token }} +{{- '<|im_start|>system +' }} +{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} +{%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"type": "function", "function": ' }} + {{- '{"name": "' + tool.name + '", ' }} + {{- '"description": "' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + " + +" }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args: +" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- " + Returns: + " + tool.return.description }} + {%- endif %} + {{- '"' }} + {{- ', "parameters": ' }} + {%- if tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- " +" }} + {%- endif %} +{%- endfor %} +{{- " " }} +{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +' }} +{{- "For each function call return a json object with function name and arguments within XML tags as follows: +" }} +{{- " +" }} +{{- '{"name": , "arguments": } +' }} +{{- '<|im_end|> +' }} +{%- for message in messages %} + {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} + {{- '<|im_start|>' + message.role + ' +' + message.content + '<|im_end|>' + ' +' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- for tool_call in message.tool_calls %} + {{- ' + +' }} {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{' }} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"' }} + {{- ', '}} + {%- if tool_call.arguments is defined %} + {{- '"arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments|tojson }} + {%- endif %} + {%- endif %} + {{- '}' }} + {{- ' +' }} + {%- endfor %} + {{- '<|im_end|> +' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>tool +' }} + {%- endif %} + {{- ' +' }} + {{- message.content }} + {%- if not loop.last %} + {{- ' + +' }} + {%- else %} + {{- ' +' }} + {%- endif %} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>' }} + {%- elif loop.last %} + {{- '<|im_end|>' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant +' }} +{%- endif %} diff --git a/tests/chat/templates/NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja b/tests/chat/templates/NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja new file mode 100644 index 0000000000000..463f9fd74cdde --- /dev/null +++ b/tests/chat/templates/NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja @@ -0,0 +1,4 @@ +{{bos_token}}{% for message in messages %}{{'<|im_start|>' + message['role'] + ' +' + message['content'] + '<|im_end|>' + ' +'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +' }}{% endif %} \ No newline at end of file diff --git a/tests/chat/templates/NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja b/tests/chat/templates/NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja new file mode 100644 index 0000000000000..149250bd540aa --- /dev/null +++ b/tests/chat/templates/NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja @@ -0,0 +1,152 @@ +{%- macro json_to_python_type(json_spec) %} +{%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + +{%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} +{%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]"}} +{%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}} + {%- else %} + {{- "dict" }} + {%- endif %} +{%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} +{%- else %} + {{- "Any" }} +{%- endif %} +{%- endmacro %} + + +{{- bos_token }} +{{- '<|im_start|>system +' }} +{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} +{%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"type": "function", "function": ' }} + {{- '{"name": "' + tool.name + '", ' }} + {{- '"description": "' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + " + +" }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args: +" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- " + Returns: + " + tool.return.description }} + {%- endif %} + {{- '"' }} + {{- ', "parameters": ' }} + {%- if tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- " +" }} + {%- endif %} +{%- endfor %} +{{- " " }} +{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +' }} +{{- "For each function call return a json object with function name and arguments within XML tags as follows: +" }} +{{- " +" }} +{{- '{"name": , "arguments": } +' }} +{{- '<|im_end|> +' }} +{%- for message in messages %} + {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} + {{- '<|im_start|>' + message.role + ' +' + message.content + '<|im_end|>' + ' +' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- for tool_call in message.tool_calls %} + {{- ' + +' }} {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{' }} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"' }} + {{- ', '}} + {%- if tool_call.arguments is defined %} + {{- '"arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments|tojson }} + {%- endif %} + {%- endif %} + {{- '}' }} + {{- ' +' }} + {%- endfor %} + {{- '<|im_end|> +' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>tool +' }} + {%- endif %} + {{- ' +' }} + {{- message.content }} + {%- if not loop.last %} + {{- ' + +' }} + {%- else %} + {{- ' +' }} + {%- endif %} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>' }} + {%- elif loop.last %} + {{- '<|im_end|>' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant +' }} +{%- endif %} diff --git a/tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-70B-default.jinja b/tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-70B-default.jinja new file mode 100644 index 0000000000000..744756d517615 --- /dev/null +++ b/tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-70B-default.jinja @@ -0,0 +1,6 @@ +{{bos_token}}{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system +You are a helpful assistant.<|im_end|> +' }}{% endif %}{{'<|im_start|>' + message['role'] + ' +' + message['content'] + '<|im_end|>' + ' +'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +' }}{% endif %} \ No newline at end of file diff --git a/tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja b/tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja new file mode 100644 index 0000000000000..149250bd540aa --- /dev/null +++ b/tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja @@ -0,0 +1,152 @@ +{%- macro json_to_python_type(json_spec) %} +{%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + +{%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} +{%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]"}} +{%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}} + {%- else %} + {{- "dict" }} + {%- endif %} +{%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} +{%- else %} + {{- "Any" }} +{%- endif %} +{%- endmacro %} + + +{{- bos_token }} +{{- '<|im_start|>system +' }} +{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} +{%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"type": "function", "function": ' }} + {{- '{"name": "' + tool.name + '", ' }} + {{- '"description": "' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + " + +" }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args: +" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- " + Returns: + " + tool.return.description }} + {%- endif %} + {{- '"' }} + {{- ', "parameters": ' }} + {%- if tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- " +" }} + {%- endif %} +{%- endfor %} +{{- " " }} +{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +' }} +{{- "For each function call return a json object with function name and arguments within XML tags as follows: +" }} +{{- " +" }} +{{- '{"name": , "arguments": } +' }} +{{- '<|im_end|> +' }} +{%- for message in messages %} + {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} + {{- '<|im_start|>' + message.role + ' +' + message.content + '<|im_end|>' + ' +' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- for tool_call in message.tool_calls %} + {{- ' + +' }} {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{' }} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"' }} + {{- ', '}} + {%- if tool_call.arguments is defined %} + {{- '"arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments|tojson }} + {%- endif %} + {%- endif %} + {{- '}' }} + {{- ' +' }} + {%- endfor %} + {{- '<|im_end|> +' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>tool +' }} + {%- endif %} + {{- ' +' }} + {{- message.content }} + {%- if not loop.last %} + {{- ' + +' }} + {%- else %} + {{- ' +' }} + {%- endif %} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>' }} + {%- elif loop.last %} + {{- '<|im_end|>' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant +' }} +{%- endif %} diff --git a/tests/chat/templates/Qwen-Qwen2-7B-Instruct.jinja b/tests/chat/templates/Qwen-Qwen2-7B-Instruct.jinja new file mode 100644 index 0000000000000..a4c0b5993f324 --- /dev/null +++ b/tests/chat/templates/Qwen-Qwen2-7B-Instruct.jinja @@ -0,0 +1,6 @@ +{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system +You are a helpful assistant.<|im_end|> +' }}{% endif %}{{'<|im_start|>' + message['role'] + ' +' + message['content'] + '<|im_end|>' + ' +'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +' }}{% endif %} \ No newline at end of file diff --git a/tests/chat/templates/Qwen-Qwen2-VL-7B-Instruct.jinja b/tests/chat/templates/Qwen-Qwen2-VL-7B-Instruct.jinja new file mode 100644 index 0000000000000..6c226632394ae --- /dev/null +++ b/tests/chat/templates/Qwen-Qwen2-VL-7B-Instruct.jinja @@ -0,0 +1,7 @@ +{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system +You are a helpful assistant.<|im_end|> +{% endif %}<|im_start|>{{ message['role'] }} +{% if message['content'] is string %}{{ message['content'] }}<|im_end|> +{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|> +{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant +{% endif %} \ No newline at end of file diff --git a/tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja b/tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja new file mode 100644 index 0000000000000..bdf7919a96cfe --- /dev/null +++ b/tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja @@ -0,0 +1,54 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0]['role'] == 'system' %} + {{- messages[0]['content'] }} + {%- else %} + {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }} + {%- endif %} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }} + {%- else %} + {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- if message.content %} + {{- '\n' + message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments | tojson }} + {{- '}\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/tests/chat/templates/Qwen-Qwen2.5-Math-7B-Instruct.jinja b/tests/chat/templates/Qwen-Qwen2.5-Math-7B-Instruct.jinja new file mode 100644 index 0000000000000..11f6d3214a18e --- /dev/null +++ b/tests/chat/templates/Qwen-Qwen2.5-Math-7B-Instruct.jinja @@ -0,0 +1,54 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0]['role'] == 'system' %} + {{- messages[0]['content'] }} + {%- else %} + {{- 'Please reason step by step, and put your final answer within \\boxed{}.' }} + {%- endif %} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }} + {%- else %} + {{- '<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- if message.content %} + {{- '\n' + message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments | tojson }} + {{- '}\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/tests/chat/templates/google-gemma-2-2b-it.jinja b/tests/chat/templates/google-gemma-2-2b-it.jinja new file mode 100644 index 0000000000000..923ec253c8dbe --- /dev/null +++ b/tests/chat/templates/google-gemma-2-2b-it.jinja @@ -0,0 +1,4 @@ +{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + ' +' + message['content'] | trim + ' +' }}{% endfor %}{% if add_generation_prompt %}{{'model +'}}{% endif %} \ No newline at end of file diff --git a/tests/chat/templates/meetkai-functionary-medium-v3.2.jinja b/tests/chat/templates/meetkai-functionary-medium-v3.2.jinja new file mode 100644 index 0000000000000..74fd1e7af6f37 --- /dev/null +++ b/tests/chat/templates/meetkai-functionary-medium-v3.2.jinja @@ -0,0 +1,287 @@ +{# version=v3.llama3 #}{%- macro append_new_param_info(param_declaration, comment_info, examples_info, depth) -%} + {%- set offset = "" -%} + {%- if depth >= 1 -%} + {%- set offset = " " * depth -%} + {%- endif -%} + {%- if comment_info != "<|NONE|>" -%} + {{ "\n" + offset + comment_info }} + {%- if examples_info | length > 0 -%} + {# Append each example info #} + {%- for example in examples_info -%} + {{ "\n" + offset + "// " + example|string|replace("'", '"') }} + {%- endfor -%} + {%- endif -%} + {%- endif -%} + {{ "\n" + offset + param_declaration }} +{%- endmacro -%} + +{%- macro convert_data_type(param_type) -%} + {%- if param_type == "integer" or param_type == "float" -%} + {{ "number" }} + {%- else -%} + {{ param_type }} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_param_type(param) -%} + {%- set param_type = "any" -%} + + {%- if "type" in param -%} + {%- set raw_param_type = param["type"] -%} + {%- if raw_param_type is iterable and raw_param_type is not string -%} + {%- set param_type = raw_param_type | join(" | ") -%} + {%- else -%} + {%- set param_type = raw_param_type -%} + {%- endif -%} + {{ convert_data_type(param_type) }} + {%- elif "oneOf" in param -%} + {%- set one_of_types = param["oneOf"]|selectattr("type", "defined")|list -%} + {%- set one_of_types = one_of_types|map(attribute="type")|unique|list -%} + {{ convert_data_type(one_of_types | join(" | ")) }} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_format_param(param) -%} + {%- if "format" in param -%} + {{ param["format"] }} + {%- elif "oneOf" in param -%} + {%- set formats = [] -%} + {%- for item in param["oneOf"] -%} + {%- if "format" in item -%} + {%- if item["format"] == param["oneOf"][-1]["format"] -%} + {{ item["format"] }} + {%- else -%} + {{ item["format"] + " or "}} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ "<|NONE|>" }} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_param_info(param) -%} + {%- set param_type = param.get("type", "any") -%} + {%- set format_param = get_format_param(param) -%} + + {%- if "description" in param or "default" in param or format_param != "<|NONE|>" or param["maximum"] or param["minimum"] or param["maxLength"] or param["minLength"] -%} + {{ "//" }} + {%- if "description" in param -%} + {%- set desc = param["description"] -%} + {%- if not desc.endswith(".") -%} + {%- set desc = desc + "." -%} + {%- endif -%} + {{ " " + desc }} + {%- endif -%} + + {%- if "default" in param -%} + {%- set default_value = param["default"] -%} + {%- if param_type == "string" -%} + {%- set default_value = '"' ~ default_value ~ '"' -%} + {%- endif -%} + {{ " Default=" ~ default_value ~ "." }} + {%- endif -%} + + {%- set format_param = get_format_param(param) -%} + {%- if format_param != "<|NONE|>" -%} + {{ " Format=" ~ format_param }} + {%- endif -%} + + {%- for field, field_name in [("maximum", "Maximum"), ("minimum", "Minimum"), ("maxLength", "Maximum length"), ("minLength", "Minimum length")] -%} + {%- if field in param -%} + {{ " " + field_name ~ "=" ~ param[field] }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ "<|NONE|>"}} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_enum_option_str(enum_options) -%} + {%- for v in enum_options -%} + {%- if v is string -%} + {{ '"' + v + '"' }} + {%- else -%} + {{ v }} + {%- endif -%} + {%- if enum_options|length > 0 and v != enum_options[-1] -%} + {{ " | " }} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} + +{%- macro get_array_typescript(param_name, param_dic, depth) -%} + {%- set offset = '' -%} + {%- if depth >= 1 -%} + {%- set offset = " " * depth -%} + {%- endif -%} + {%- set items_info = param_dic.get('items', {}) -%} + + {%- if items_info|length == 0 -%} + {%- if param_name -%} + {{ "\n" + offset + param_name + ": []" }} + {%- else -%} + {{ "\n" + offset + "[]" }} + {%- endif -%} + {%- else -%} + {%- set array_type = get_param_type(items_info) -%} + {%- if array_type == 'object' -%} + {%- if param_name -%} + {{ "\n" + offset + param_name + ": {" }} + {%- else -%} + {{ "\n" + offset + "{" }} + {%- endif -%} + {{ get_parameter_typescript(items_info.get('properties', {}), items_info.get('required', []), depth + 1) -}} + {{- "\n" + offset + "}[]" }} + {%- elif array_type == 'array' -%} + {%- set item_info = get_array_typescript(None, items_info, depth + 1) -%} + {%- if not param_name -%} + {{ "\n" + item_info + "[]" }} + {%- else -%} + {{ "\n" + offset + param_name + ": " + item_info|trim + "[]" }} + {%- endif -%} + {%- else -%} + {%- if 'enum' in items_info -%} + {%- set item_type = get_enum_option_str(items_info['enum']) -%} + {%- if param_name is none -%} + {{ "(" + item_type + ")[]"}} + {%- else -%} + {{ "\n" + offset + param_name + ": (" + item_type + ")[]" }} + {%- endif -%} + {%- else -%} + {%- if param_name is none -%} + {{ "\n" + array_type + "[]" }} + {%- else -%} + {{ "\n" + offset + param_name + ": " + array_type + "[]," }} + {%- endif -%} + {%- endif -%} + {%- endif -%} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_parameter_typescript(properties, required_params, depth=0) -%} + {%- set res = "" -%} + {%- for param_name, param in properties.items() -%} + {%- if param is mapping -%} + {%- set comment_info = get_param_info(param) -%} + {# Param Examples #} + {%- set examples_info = [] -%} + {%- if "examples" in param -%} + {%- set examples_info = ["Example " + param_name + ":"] -%} + {%- set examples_info = examples_info + param["examples"] -%} + {%- endif -%} + + {# Param Name declaration #} + {%- set param_declaration = param_name -%} + {%- if required_params is iterable and param_name not in required_params -%} + {%- set param_declaration = param_declaration + "?" -%} + {%- endif -%} + + {%- set param_type = get_param_type(param) -%} + + {# Handle indentation based on depth #} + {%- set offset = "" -%} + {%- if depth >= 1 -%} + {%- set offset = " " * depth -%} + {%- endif -%} + + {%- if param_type == "object" -%} + {%- if comment_info != "<|NONE|>" -%} + {{ "\n" + offset + comment_info }} + {%- endif -%} + {%- if examples_info|length > 0 -%} + {%- for example in examples_info -%} + {{ "\n" + offset + "// " + example|string|replace("'", '"') }} + {%- endfor -%} + {%- endif -%} + {%- set param_declaration = param_declaration + ": {" -%} + {{ "\n" + offset + param_declaration -}} + {{- get_parameter_typescript(param.get("properties", {}), param.get("required", []), depth + 1) -}} + {{- "\n" + offset + "}," }} + {%- elif param_type == "array" -%} + {%- set item_info = param.get("items", {}) -%} + {%- if "type" not in item_info -%} + {%- set param_declaration = param_declaration + ": []," -%} + {{ append_new_param_info(param_declaration, comment_info, examples_info, depth) }} + {%- else -%} + {%- if comment_info != "<|NONE|>" -%} + {{ "\n" + offset + comment_info }} + {%- endif -%} + {%- if examples_info|length > 0 -%} + {%- for example in examples_info -%} + {{ "\n" + offset + "// " + example|string|replace("'", '"') }} + {%- endfor -%} + {%- endif -%} + {%- set array_declaration = get_array_typescript(param_declaration, param, depth) -%} + {%- if not array_declaration.endswith(",") -%} + {%- set array_declaration = array_declaration + "," -%} + {%- endif -%} + {{ array_declaration}} + {%- endif -%} + {%- else -%} + {%- if "enum" in param -%} + {%- set param_type = get_enum_option_str(param["enum"]) -%} + {%- endif -%} + {%- if "nullable" in param and param["nullable"] -%} + {%- set param_type = param_type + " | null" -%} + {%- endif -%} + {%- set param_declaration = param_declaration + ": " + param_type + "," -%} + {{ append_new_param_info(param_declaration, comment_info, examples_info, depth) }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} + +{%- macro generate_schema_from_functions(functions, namespace='functions') -%} + {{ "// Supported function definitions that should be called when necessary.\n" -}} + {{- "namespace " + namespace + " {\n\n" -}} + + {%- for function in functions -%} + {%- if function.get("function") -%} + {%- set function = function.get("function") -%} + {%- endif -%} + + {%- set function_name = function.get("name") -%} + {%- if function_name -%} + {%- set description = function.get('description', '') -%} + {%- set parameters = function.get('parameters', {}) -%} + {{- "// " + description + "\n" -}} + {{- "type " + function_name -}} + {%- if parameters and parameters.get("properties") -%} + {{- " = (_: {" -}} + {%- set required_params = parameters.get("required", []) -%} + {{ get_parameter_typescript(parameters.get("properties"), required_params, 0) -}} + {{- "\n}) => any;\n\n" }} + {%- else -%} + {{ " = () => any;\n\n" }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {{ "} // namespace " + namespace }} +{%- endmacro -%} +{%- if not tools -%} + {%- set tools = [] -%} +{%- endif -%} +{{ bos_token + '<|start_header_id|>system<|end_header_id|>\n\nYou are capable of executing available function(s) if required.\nOnly execute function(s) when absolutely necessary.\nAsk for the required input to:recipient==all\nUse JSON for function arguments.\nRespond in this format:\n>>>${recipient}\n${content}\nAvailable functions:\n' + generate_schema_from_functions(tools) + '<|eot_id|>' -}} +{%- if tools|length > 0 and tools|selectattr("type", "equalto", "code_interpreter")|list|length > 0 -%} + {{ '<|start_header_id|>system<|end_header_id|>\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at \'/mnt/data\' can be used to save and persist user files.<|eot_id|>' }} +{%- endif -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- elif message['role'] == 'tool' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- else -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}} + {%- if message['content'] -%} + {{ '>>>all\n' + message['content'] }} + {%- endif -%} + {%- if 'tool_calls' in message and message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {{ '>>>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }} + {%- endfor -%} + {%- endif -%} + {{ '<|eot_id|>' }} + {%- endif -%} +{%- endfor -%} +{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n>>>' }}{% endif %} \ No newline at end of file diff --git a/tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja b/tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja new file mode 100644 index 0000000000000..33089ace1be88 --- /dev/null +++ b/tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja @@ -0,0 +1,109 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not date_string is defined %} + {%- set date_string = "26 Jul 2024" %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "" %} +{%- endif %} + +{#- System message + builtin tools #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if builtin_tools is defined or tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{%- if builtin_tools is defined %} + {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} +{%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {%- if builtin_tools is defined and tool_call.name in builtin_tools %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- "<|python_tag|>" + tool_call.name + ".call(" }} + {%- for arg_name, arg_val in tool_call.arguments | items %} + {{- arg_name + '="' + arg_val + '"' }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- else %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {%- endif %} + {%- if builtin_tools is defined %} + {#- This means we're in ipython mode #} + {{- "<|eom_id|>" }} + {%- else %} + {{- "<|eot_id|>" }} + {%- endif %} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping or message.content is iterable %} + {{- message.content | tojson }} + {%- else %} + {{- message.content }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja b/tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja new file mode 100644 index 0000000000000..d1533d1526b2e --- /dev/null +++ b/tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja @@ -0,0 +1,8 @@ +{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|> +' + message['content'] + '<|end|> +'}}{% elif message['role'] == 'user' %}{{'<|user|> +' + message['content'] + '<|end|> +'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|> +' + message['content'] + '<|end|> +'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|> +' }}{% else %}{{ eos_token }}{% endif %} \ No newline at end of file diff --git a/tests/chat/templates/mistralai-Mixtral-8x7B-Instruct-v0.1.jinja b/tests/chat/templates/mistralai-Mixtral-8x7B-Instruct-v0.1.jinja new file mode 100644 index 0000000000000..40b37ad7f90d4 --- /dev/null +++ b/tests/chat/templates/mistralai-Mixtral-8x7B-Instruct-v0.1.jinja @@ -0,0 +1,24 @@ +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content'] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }} + {%- endif %} + {%- if message['role'] == 'user' %} + {%- if loop.first and system_message is defined %} + {{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }} + {%- else %} + {{- ' [INST] ' + message['content'] + ' [/INST]' }} + {%- endif %} + {%- elif message['role'] == 'assistant' %} + {{- ' ' + message['content'] + eos_token}} + {%- else %} + {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }} + {%- endif %} +{%- endfor %} diff --git a/tests/test-minja.cpp b/tests/test-minja.cpp new file mode 100644 index 0000000000000..ad835e0362e8e --- /dev/null +++ b/tests/test-minja.cpp @@ -0,0 +1,434 @@ +/* + Minimalistic Jinja templating engine for llama.cpp. C++11, no deps (single-header), decent language support but very few functions (easy to extend), just what’s needed for actual prompt templates. + + Models have increasingly complex templates (e.g. Llama 3.1, Hermes 2 Pro w/ tool_use), so we need a proper template engine to get the best out of them. + + Supports: + - Full expression syntax + - Statements `{{% … %}}`, variable sections `{{ … }}`, and comments `{# … #}` with pre/post space elision `{%- … -%}` / `{{- … -}}` / `{#- … -#}` + - `if` / `elif` / `else` / `endif` + - `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring + - `set` w/ namespaces & destructuring + - `macro` / `endmacro` + - Extensible filters collection: `count`, `dictsort`, `equalto`, `e` / `escape`, `items`, `join`, `joiner`, `namespace`, `raise_exception`, `range`, `reject`, `tojson`, `trim` + + Limitations: + - Not supporting most filters & pipes. Only the ones actually used in the templates are implemented. + https://jinja.palletsprojects.com/en/3.0.x/templates/#builtin-filters + - No difference between none and undefined + - Single namespace with all filters / tests / functions / macros / variables + - No tuples (templates seem to rely on lists only) + - No `if` expressions w/o `else` (but `if` statements are fine) + - No `{% raw %}`, `{% block … %}`, `{% include … %}`, `{% extends … %}, + + Model templates verified to work: + - Meta-Llama-3.1-8B-Instruct + - Phi-3.5-mini-instruct + - Hermes-2-Pro-Llama-3-8B (default & tool_use variants) + - Qwen2-VL-7B-Instruct, Qwen2-7B-Instruct + - Mixtral-8x7B-Instruct-v0.1 + + TODO: + - Simplify two-pass parsing + - Pass tokens to IfNode and such + - Macro nested set scope = global? + {%- macro get_param_type(param) -%} + {%- set param_type = "any" -%} + - Advertise in / link to https://jbmoelker.github.io/jinja-compat-tests/ +*/ +#include "minja.hpp" + +#include +#include +#include +#include + +static std::string read_file(const std::string &path) { + std::ifstream fs(path, std::ios_base::binary); + if (!fs.is_open()) { + throw std::runtime_error("Failed to open file: " + path); + } + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + std::string out; + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); + return out; +} + +static std::vector find_files(const std::string & folder, const std::string & ext) { + std::vector files; + for (const auto & entry : std::__fs::filesystem::directory_iterator(folder)) { + if (entry.path().extension() == ext) + files.push_back(entry.path().string()); + } + return files; +} + +static std::string filename_without_extension(const std::string & path) { + auto res = path; + auto pos = res.find_last_of('/'); + if (pos != std::string::npos) + res = res.substr(pos + 1); + pos = res.find_last_of('.'); + if (pos != std::string::npos) + res = res.substr(0, pos); + return res; +} + +static void assert_equals(const std::string & expected, const std::string & actual) { + if (expected != actual) { + std::cerr << "Expected: " << expected << std::endl; + std::cerr << "Actual: " << actual << std::endl; + std::cerr << std::flush; + throw std::runtime_error("Test failed"); + } +} + +static void announce_test(const std::string & name, const minja::Options & options) { + auto len = name.size(); + auto extract = minja::strip(name); + extract = json(name.substr(0, std::min(len, 50)) + (len > 50 ? " [...]" : "")).dump(); + extract = extract.substr(1, extract.size() - 2); + std::cout << "Testing: " << extract; + static const minja::Options default_options {}; + if (options.lstrip_blocks != default_options.lstrip_blocks) + std::cout << " lstrip_blocks=" << options.lstrip_blocks; + if (options.trim_blocks != default_options.trim_blocks) + std::cout << " trim_blocks=" << options.trim_blocks; + std::cout << std::endl << std::flush; +} + +static void test_render(const std::string & template_str, const json & bindings, const minja::Options & options, const std::string & expected, const json & expected_context = {}) { + announce_test(template_str, options); + auto root = minja::Parser::parse(template_str, options); + auto context = minja::Context::make(bindings); + std::string actual; + try { + actual = root->render(context); + } catch (const std::runtime_error & e) { + actual = "ERROR: " + std::string(e.what()); + } + + assert_equals(expected, actual); + + if (!expected_context.is_null()) { + // auto dump = context->dump(); + for (const auto & kv : expected_context.items()) { + auto value = context->get(kv.key()); + if (value != kv.value()) { + std::cerr << "Expected context value for " << kv.key() << ": " << kv.value() << std::endl; + std::cerr << "Actual value: " << value.dump() << std::endl; + std::cerr << std::flush; + throw std::runtime_error("Test failed"); + } + } + } + std::cout << "Test passed!" << std::endl << std::flush; +} + +static void test_error_contains(const std::string & template_str, const json & bindings, const minja::Options & options, const std::string & expected) { + announce_test(template_str, options); + try { + auto root = minja::Parser::parse(template_str, options); + auto context = minja::Context::make(bindings); + // auto copy = context.is_null() ? Value::object() : std::make_shared(context); + auto actual = root->render(context); + throw std::runtime_error("Expected error: " + expected + ", but got successful result instead: " + actual); + } catch (const std::runtime_error & e) { + std::string actual(e.what()); + if (actual.find(expected) == std::string::npos) { + std::cerr << "Expected: " << expected << std::endl; + std::cerr << "Actual: " << actual << std::endl; + std::cerr << std::flush; + throw std::runtime_error("Test failed"); + } + } + std::cout << " passed!" << std::endl << std::flush; +} + +static void test_template_features() { + test_render(R"({{ 'a' in {"a": 1} }},{{ 'a' in {} }})", {}, {}, "True,False"); + test_render(R"({{ 'a' in ["a"] }},{{ 'a' in [] }})", {}, {}, "True,False"); + test_render(R"({{ [{"a": 1}, {"a": 2}, {}] | selectattr("a", "equalto", 1) }})", {}, {}, R"([{'a': 1}])"); + test_render(R"({{ [{"a": 1}, {"a": 2}] | map(attribute="a") | list }})", {}, {}, "[1, 2]"); + test_render(R"({{ ["", "a"] | map("length") | list }})", {}, {}, "[0, 1]"); + test_render(R"({{ range(3) | last }})", {}, {}, "2"); + test_render(R"({% set foo = true %}{{ foo is defined }})", {}, {}, "True"); + test_render(R"({% set foo = true %}{{ not foo is defined }})", {}, {}, "False"); + test_render(R"({{ {"a": "b"} | tojson }})", {}, {}, R"({"a": "b"})"); + test_render(R"({{ {"a": "b"} }})", {}, {}, R"({'a': 'b'})"); + + std::string trim_tmpl = + "\n" + " {% if true %}Hello{% endif %} \n" + "...\n" + "\n"; + test_render( + trim_tmpl, + {}, { .trim_blocks = true }, "\n Hello...\n"); + test_render( + trim_tmpl, + {}, {}, "\n Hello \n...\n"); + test_render( + trim_tmpl, + {}, { .lstrip_blocks = true }, "\nHello \n...\n"); + test_render( + trim_tmpl, + {}, { .trim_blocks = true, .lstrip_blocks = true }, "\nHello...\n"); + + test_render( + R"({%- set separator = joiner(' | ') -%} + {%- for item in ["a", "b", "c"] %}{{ separator() }}{{ item }}{% endfor -%})", + {}, {}, "a | b | c"); + test_render("a\nb\n", {}, {}, "a\nb"); + test_render(" {{- ' a\n'}}", {}, {.trim_blocks = true}, " a\n"); + + test_render( + R"( + {%- for x in range(3) -%} + {%- if loop.first -%} + but first, mojitos! + {%- endif -%} + {{ loop.index }}{{ "," if not loop.last -}} + {%- endfor -%} + )", {}, {}, "but first, mojitos!1,2,3"); + test_render("{{ 'a' + [] | length + 'b' }}", {}, {}, "a0b"); + test_render("{{ [1, 2, 3] | join(', ') + '...' }}", {}, {}, "1, 2, 3..."); + test_render("{{ 'Tools: ' + [1, 2, 3] | reject('equalto', 2) | join(', ') + '...' }}", {}, {}, "Tools: 1, 3..."); + test_render("{{ [1, 2, 3] | join(', ') }}", {}, {}, "1, 2, 3"); + test_render("{% for i in range(3) %}{{i}},{% endfor %}", {}, {}, "0,1,2,"); + test_render("{% set foo %}Hello {{ 'there' }}{% endset %}{{ 1 ~ foo ~ 2 }}", {}, {}, "1Hello there2"); + test_render("{{ [1, False, null, True, 2, '3', 1, '3', False, null, True] | unique }}", {}, {}, + "[1, False, null, True, 2, '3']"); + test_render("{{ range(5) | length % 2 }}", {}, {}, "1"); + test_render("{{ range(5) | length % 2 == 1 }},{{ [] | length > 0 }}", {}, {}, "True,False"); + test_render( + "{{ messages[0]['role'] != 'system' }}", + {{"messages", json::array({json({{"role", "system"}})})}}, + {}, + "False"); + test_render( + R"( + {%- for x, y in [("a", "b"), ("c", "d")] -%} + {{- x }},{{ y -}}; + {%- endfor -%} + )", {}, {}, "a,b;c,d;"); + test_render("{{ 1 is not string }}", {}, {}, "True"); + test_render("{{ 'ab' * 3 }}", {}, {}, "ababab"); + test_render("{{ [1, 2, 3][-1] }}", {}, {}, "3"); + test_render( + "{%- for i in range(0) -%}NAH{% else %}OK{% endfor %}", + {}, {}, + "OK"); + test_render( + R"( + {%- for i in range(5) -%} + ({{ i }}, {{ loop.cycle('odd', 'even') }}), + {%- endfor -%} + )", {}, {}, "(0, odd),(1, even),(2, odd),(3, even),(4, odd),"); + + test_render( + "{%- for i in range(5) if i % 2 == 0 -%}\n" + "{{ i }}, first={{ loop.first }}, last={{ loop.last }}, index={{ loop.index }}, index0={{ loop.index0 }}, revindex={{ loop.revindex }}, revindex0={{ loop.revindex0 }}, prev={{ loop.previtem }}, next={{ loop.nextitem }},\n" + "{% endfor -%}", + {}, {}, + "0, first=True, last=False, index=1, index0=0, revindex=3, revindex0=2, prev=, next=2,\n" + "2, first=False, last=False, index=2, index0=1, revindex=2, revindex0=1, prev=0, next=4,\n" + "4, first=False, last=True, index=3, index0=2, revindex=1, revindex0=0, prev=2, next=,\n"); + + test_render( + R"( + {%- set res = [] -%} + {%- for c in ["<", ">", "&", '"'] -%} + {%- set _ = res.append(c | e) -%} + {%- endfor -%} + {{- res | join(", ") -}} + )", {}, {}, + R"(<, >, &, ")"); + test_render( + R"( + {%- set x = 1 -%} + {%- set y = 2 -%} + {%- macro foo(x, z, w=10) -%} + x={{ x }}, y={{ y }}, z={{ z }}, w={{ w -}} + {%- endmacro -%} + {{- foo(100, 3) -}} + )", {}, {}, + R"(x=100, y=2, z=3, w=10)"); + test_render( + R"( + {% macro input(name, value='', type='text', size=20) -%} + + {%- endmacro -%} + +

{{ input('username') }}

+

{{ input('password', type='password') }}

)", + {}, {}, R"( +

+

)"); + test_render( + R"( + {#- The values' default array should be created afresh at each call, unlike the equivalent Python function -#} + {%- macro foo(values=[]) -%} + {%- set _ = values.append(1) -%} + {{- values -}} + {%- endmacro -%} + {{- foo() }} {{ foo() -}})", + {}, {}, R"([1] [1])"); + test_render(R"({{ None | items | tojson }}; {{ {1: 2} | items | tojson }})", {}, {}, "[]; [[1, 2]]"); + test_render(R"({{ {1: 2, 3: 4, 5: 7} | dictsort | tojson }})", {}, {}, "[[1, 2], [3, 4], [5, 7]]"); + test_render(R"({{ {1: 2}.items() }})", {}, {}, "[[1, 2]]"); + test_render(R"({{ {1: 2}.get(1) }}; {{ {}.get(1) }}; {{ {}.get(1, 10) }})", {}, {}, "2; ; 10"); + test_render( + R"( + {%- for x in [1, 1.2, "a", true, True, false, False, None, [], [1], [1, 2], {}, {"a": 1}, {1: "b"}] -%} + {{- x | tojson -}}, + {%- endfor -%} + )", {}, {}, + R"(1,1.2,"a",True,True,False,False,null,[],[1],[1, 2],{},{"a": 1},{"1": "b"},)"); + test_render( + R"( + {%- set n = namespace(value=1, title='') -%} + {{- n.value }} "{{ n.title }}", + {%- set n.value = 2 -%} + {%- set n.title = 'Hello' -%} + {{- n.value }} "{{ n.title }}")", {}, {}, R"(1 "",2 "Hello")"); + test_error_contains( + "{{ (a.b.c) }}", + {{"a", json({{"b", {{"c", 3}}}})}}, + {}, + "'a' is not defined"); + test_render( + "{% set _ = a.b.append(c.d.e) %}{{ a.b }}", + json::parse(R"({ + "a": {"b": [1, 2]}, + "c": {"d": {"e": 3}} + })"), + {}, + "[1, 2, 3]"); + + test_render(R"( + {%- for x, y in z -%} + {{- x }},{{ y -}}; + {%- endfor -%} + )", {{"z", json({json({1, 10}), json({2, 20})})}}, {}, "1,10;2,20;"); + + test_render(" a {{ 'b' -}} c ", {}, {}, " a bc "); + test_render(" a {{- 'b' }} c ", {}, {}, " ab c "); + test_render("a\n{{- 'b' }}\nc", {}, {}, "ab\nc"); + test_render("a\n{{ 'b' -}}\nc", {}, {}, "a\nbc"); + + test_error_contains("{{ raise_exception('hey') }}", {}, {}, "hey"); + + test_render("{{ [] is iterable }}", {}, {}, "True"); + test_render("{{ [] is not number }}", {}, {}, "True"); + test_render("{% set x = [0, 1, 2, 3] %}{{ x[1:] }}{{ x[:2] }}{{ x[1:3] }}", {}, {}, "[1, 2, 3][0, 1][1, 2]"); + test_render("{{ ' a ' | trim }}", {}, {}, "a"); + test_render("{{ range(3) }}{{ range(4, 7) }}{{ range(0, 10, step=2) }}", {}, {}, "[0, 1, 2][4, 5, 6][0, 2, 4, 6, 8]"); + + test_render( + R"( {{ "a" -}} b {{- "c" }} )", {}, {}, + " abc "); + + test_error_contains("{% else %}", {}, {}, "Unexpected else"); + test_error_contains("{% endif %}", {}, {}, "Unexpected endif"); + test_error_contains("{% elif 1 %}", {}, {}, "Unexpected elif"); + test_error_contains("{% endfor %}", {}, {}, "Unexpected endfor"); + + test_error_contains("{% if 1 %}", {}, {}, "Unterminated if"); + test_error_contains("{% for x in 1 %}", {}, {}, "Unterminated for"); + test_error_contains("{% if 1 %}{% else %}", {}, {}, "Unterminated if"); + test_error_contains("{% if 1 %}{% else %}{% elif 1 %}{% endif %}", {}, {}, "Unterminated if"); + + test_render("{% if 1 %}{% elif 1 %}{% else %}{% endif %}", {}, {}, ""); + + test_render( + "{% set x = [] %}{% set _ = x.append(1) %}{{ x | tojson(indent=2) }}", {}, {}, + "[\n 1\n]"); + + test_render( + "{{ not [] }}", {}, {}, + "True"); + + test_render("{{ tool.function.name == 'ipython' }}", + json({{"tool", json({ + {"function", {{"name", "ipython"}}} + })}}), + {}, + "True"); + + test_render(R"( + {%- set user = "Olivier" -%} + {%- set greeting = "Hello " ~ user -%} + {{- greeting -}} + )", {}, {}, "Hello Olivier"); +} + +static void test_chat_templates_with_common_contexts_against_goldens() { + auto jinja_template_files = find_files("tests/chat/templates", ".jinja"); + auto context_files = find_files("tests/chat/contexts", ".json"); + + auto get_golden_file = [&](const std::string & tmpl_file, const std::string & ctx_file) { + auto tmpl_name = filename_without_extension(tmpl_file); + auto ctx_name = filename_without_extension(ctx_file); + auto golden_name = tmpl_name + "-" + ctx_name; + return "tests/chat/goldens/" + golden_name + ".txt"; + }; + auto fail_with_golden_instructions = [&]() { + throw std::runtime_error("To fetch templates and generate golden files, run `python tests/update_jinja_goldens.py`"); + }; + if (jinja_template_files.empty()) { + std::cerr << "No Jinja templates found in tests/chat/templates" << std::endl; + fail_with_golden_instructions(); + } + const auto options = minja::Options {.trim_blocks = true, .lstrip_blocks = true}; + for (const auto & tmpl_file : jinja_template_files) { + std::cout << "# Testing template: " << tmpl_file << std::endl << std::flush; + auto tmpl_str = read_file(tmpl_file); + auto tmpl = minja::Parser::parse(tmpl_str, options); + + auto found_goldens = false; + + for (const auto & ctx_file : context_files) { + auto ctx = json::parse(read_file(ctx_file)); + + auto golden_file = get_golden_file(tmpl_file, ctx_file); + if (!std::ifstream(golden_file).is_open()) { + continue; + } + found_goldens = true; + std::cout << " - " << golden_file << std::endl << std::flush; + + std::string actual; + try { + actual = tmpl->render(minja::Context::make(ctx)); + } catch (const std::runtime_error & e) { + actual = "ERROR: " + std::string(e.what()); + } + auto expected = read_file(golden_file); + assert_equals(expected, actual); + } + + if (!found_goldens) { + std::cerr << "No golden files found for " << tmpl_file << std::endl; + fail_with_golden_instructions(); + } + } +} + +/* + cmake -B build -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-minja -j && ./build/bin/test-minja +*/ +int main() { + test_template_features(); + + if (getenv("LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR")) { + fprintf(stderr, "\033[33mWARNING: Skipping slow tests on emulator.\n\033[0m"); + } else { + test_chat_templates_with_common_contexts_against_goldens(); + } + + return 0; +} \ No newline at end of file diff --git a/tests/update_jinja_goldens.py b/tests/update_jinja_goldens.py new file mode 100644 index 0000000000000..bd547cd20d7d0 --- /dev/null +++ b/tests/update_jinja_goldens.py @@ -0,0 +1,141 @@ +#!/usr/bin/env uv run +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "jinja2", +# "huggingface_hub", +# ] +# /// +''' + Fetches the Jinja2 templates of a few known models and use them to generate prompt goldens for a few predefined chat contexts. + + Examples: + python ./tests/update_jinja_goldens.py + + https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py +''' + +import datetime +import glob +import os +from huggingface_hub import hf_hub_download +import json +import jinja2 +import jinja2.ext +import re +# import requests + +model_ids = [ + "NousResearch/Hermes-3-Llama-3.1-70B", + "NousResearch/Hermes-2-Pro-Llama-3-8B", + "NousResearch/Hermes-2-Pro-Mistral-7B", + "meetkai/functionary-medium-v3.2", + "Qwen/Qwen2-7B-Instruct", + "Qwen/Qwen2-VL-7B-Instruct", + "Qwen/Qwen2.5-7B-Instruct", # "Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-Coder-7B-Instruct", + "Qwen/Qwen2.5-Math-7B-Instruct", # "Qwen/Qwen2.5-Math-72B-Instruct", + "microsoft/Phi-3.5-mini-instruct", + + # Gated models: + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "google/gemma-2-2b-it", + "mistralai/Mixtral-8x7B-Instruct-v0.1", +] + +def raise_exception(message: str): + raise ValueError(message) + +def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False): + return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys) + +def strftime_now(format): + return datetime.now().strftime(format) + +def handle_chat_template(model_id, variant, template_src): + print(f"# {model_id} @ {variant}") + model_name = model_id.replace("/", "-") + base_name = f'{model_name}-{variant}' if variant else model_name + template_file = f'tests/chat/templates/{base_name}.jinja' + print(f'template_file: {template_file}') + with open(template_file, 'w') as f: + f.write(template_src) + + print(f"- {template_file}") + + env = jinja2.Environment( + trim_blocks=True, + lstrip_blocks=True, + # keep_trailing_newline=False, + extensions=[ + jinja2.ext.loopcontrols + ]) + env.filters['tojson'] = tojson + env.globals['raise_exception'] = raise_exception + env.globals['strftime_now'] = strftime_now + + template_handles_tools = 'tools' in template_src + template_hates_the_system = 'System role not supported' in template_src + + template = env.from_string(template_src) + + context_files = glob.glob('tests/chat/contexts/*.json') + for context_file in context_files: + context_name = context_file.split("/")[-1].replace(".json", "") + with open(context_file, 'r') as f: + context = json.load(f) + + if not template_handles_tools and 'tools' in context: + continue + + if template_hates_the_system and any(m['role'] == 'system' for m in context['messages']): + continue + + output_file = f'tests/chat/goldens/{base_name}-{context_name}.txt' + print(f"- {output_file}") + try: + output = template.render(**context) + except: + # Some templates (e.g. Phi-3-medium-128k's) expect a non-null "content" key in each message. + for message in context["messages"]: + if message.get("content") is None: + message["content"] = "" + + try: + output = template.render(**context) + except Exception as e: + print(f" ERROR: {e}") + output = f"ERROR: {e}" + + with open(output_file, 'w') as f: + f.write(output) + + print() + +def main(): + for dir in ['tests/chat/templates', 'tests/chat/goldens']: + if not os.path.isdir(dir): + os.mkdir(dir) + + for model_id in model_ids: + # response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json") + # response.raise_for_status() + # config_str = response.text + with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f: + config_str = f.read() + + try: + config = json.loads(config_str) + except json.JSONDecodeError as e: + # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json + # (Remove extra '}' near the end of the file) + config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str)) + + chat_template = config['chat_template'] + if isinstance(chat_template, str): + handle_chat_template(model_id, None, chat_template) + else: + for ct in chat_template: + handle_chat_template(model_id, ct['name'], ct['template']) + +if __name__ == '__main__': + main() \ No newline at end of file