Tool call support (Llama 3.x, Functionary v3, Hermes 2 Pro, Mistral Nemo, generic) w/ lazy grammars & minimalist Jinja engine #9639

Draft: wants to merge 182 commits into base: master
Commits
5b6d504
`grammar`: trigger words + refactor of antiprompts
ochafik Sep 25, 2024
eaca756
`minja`: minimalist Jinja templating engine for LLM chat templates
ochafik Sep 25, 2024
26c175b
`json`: build_grammar helper
ochafik Sep 25, 2024
3cfc21e
`tool-call`: basic Functionary 3.2, Llama 3.1, Hermes 2 Pro grammar g…
ochafik Sep 25, 2024
e309c6a
`tool-call`: integrate minja & tool-call to server when --jinja is set
ochafik Sep 25, 2024
41103c0
`server`: add --chat-template-file
ochafik Sep 25, 2024
4706bdb
`tool-call`: support Functionary v3 vs. v3-llama3.1 variants
ochafik Sep 25, 2024
8f25531
`tool-call`: add basic usage example to server readme
ochafik Sep 25, 2024
33ea20e
Merge remote-tracking branch 'origin/master' into tool-call
ochafik Sep 25, 2024
d15dcfb
`tool-call`: add output example to readme
ochafik Sep 25, 2024
97d0620
`minja`: fetch more templates (add models from test-chat-template)
ochafik Sep 25, 2024
e983c9d
`tool-call`: fix llama_chat_apply_template signature / test-chat-temp…
ochafik Sep 25, 2024
45b243b
`minja`: fix llama_chat_apply_template + add use_jinja param to vali…
ochafik Sep 26, 2024
9e366b3
`server`: fix trailing comma in completions_seed
ochafik Sep 26, 2024
a774093
`tool-call`: add server tests for llama 3.1
ochafik Sep 26, 2024
d928ff4
`server`: catch errors in oaicompat_completion_params_parse instead o…
ochafik Sep 26, 2024
ab25e3f
`tool-call`: allow empty message content when there's tool_calls in f…
ochafik Sep 26, 2024
1b62801
fix editorconfig lints
ochafik Sep 26, 2024
76d2938
fix flake8 lints
ochafik Sep 26, 2024
c124ab4
`minja`: add str.endswith
ochafik Sep 26, 2024
595e11c
`tool-call`: fix/test functionary v3
ochafik Sep 26, 2024
94377d7
`server`: catch errors in format_final_response_oaicompat instead of …
ochafik Sep 26, 2024
059babd
`minja`: try to please gcc
ochafik Sep 26, 2024
4cd82d6
`tool-call`: fix pyright type errors
ochafik Sep 26, 2024
2eb29bf
`tool-call`: update chat templates/goldens
ochafik Sep 26, 2024
5f5be9c
`minja`: gcc tweaks
ochafik Sep 26, 2024
8e4a9ba
`minja`: allow none input to selectattr, and add safe passthrough filter
ochafik Sep 26, 2024
0c87013
`tool-call`: test/fix functionary-medium-v3.1's template (can "look" …
ochafik Sep 26, 2024
749a21c
gcc appeasement
ochafik Sep 26, 2024
3d2650c
fix gcc build
ochafik Sep 26, 2024
d7ec84f
`tool-call`: allow <|python_tag|> in functionary-medium-3.1
ochafik Sep 26, 2024
cf7bece
`tool-call`: factor chat template away from legacy API
ochafik Sep 26, 2024
9cfe4d7
`tool-call`: refactor llama_chat_template class + use in validate_mod…
ochafik Sep 26, 2024
296331b
`minja`: update chat template goldens w/ llama.3.1 arguments workaround
ochafik Sep 26, 2024
50685f8
`minja`: add str.title()
ochafik Sep 26, 2024
5840e10
`tool-call`: merge & fix jinja template tests into test-chat-template
ochafik Sep 26, 2024
2926089
fix lints
ochafik Sep 26, 2024
c88c932
fix gcc error + lint
ochafik Sep 26, 2024
10f9fe8
`tool-call`: fix tool call return format
ochafik Sep 26, 2024
8299fac
`tool-call`: adapt very simple agent + docker isolation from https://…
ochafik Sep 26, 2024
f9c1743
`minja`: fix iterables
ochafik Sep 27, 2024
1e5c0e7
`chat-template`: fix jinja tests (make safe a passthrough)
ochafik Sep 27, 2024
9295ca9
`tool-call`: fix agent type lints
ochafik Sep 27, 2024
27cd07a
`json`: fix grammar conversion typo
ochafik Sep 27, 2024
6610ecf
`server`: rm bad debug code
ochafik Sep 27, 2024
0abfa36
`tool-call`: move usage examples to examples/agent
ochafik Sep 27, 2024
f62e688
`tool-call`: fix crash / test non-tool call case (added llama_sampler…
ochafik Sep 27, 2024
e33b342
`tool-call`: fix passing of tools to template + allow agent to finish
ochafik Sep 27, 2024
e62b5de
`tool-call`: fix functionary-small-3.2 (first tool starts w/ name\n, …
ochafik Sep 27, 2024
86e4f99
Update README.md
ochafik Sep 27, 2024
2f25ee3
Update README.md
ochafik Sep 27, 2024
0093a5e
`minja`: fix identifiers parsing (when start w/ not/is/etc) and lstri…
ochafik Sep 27, 2024
701b664
`minja`: add `indent` filter to support command-r-plus's chat templates
ochafik Sep 27, 2024
887951b
`minja`: generate chat goldens w/ fixed date to support Llama-3.2-3B-…
ochafik Sep 27, 2024
0c85bc7
`tool-call`: test tool call style detection
ochafik Sep 28, 2024
d983516
`tool-call`: let the tool call handler expand chat template, moving b…
ochafik Sep 28, 2024
8b2cf35
`tool-call`: fix grammar trigger crash
ochafik Sep 28, 2024
7cef90c
`tool-call`: more eager function call parsing for Functionary & Llama…
ochafik Sep 28, 2024
55cf337
`tool-call`: better error reporting for server tests
ochafik Sep 28, 2024
c657857
`tool-call`: cleanup tools.py
ochafik Sep 28, 2024
6e0053a
`chat-template`: enumerate files w/ C API rather than private using s…
ochafik Sep 28, 2024
05bbba9
`tool-call`: only match json eagerly for Llama 3.2
ochafik Sep 28, 2024
ef2a020
`tool-call`: make agent async
ochafik Sep 28, 2024
e6be59c
`antiprompts`: fix gcc8 build (avoid recursive struct)
ochafik Sep 28, 2024
9358d1f
`minja`: fix gcc8 build of test
ochafik Sep 28, 2024
1b32ac1
`chat-template`: fix test-arg
ochafik Sep 28, 2024
0ae1112
`agent`: try to fix pyright lint
ochafik Sep 28, 2024
dbda025
`tool-call`: test messages -> template -> grammar -> tool call parser
ochafik Sep 28, 2024
b10ef04
`chat-template`: tweak --chat-template error message when --jinja is set
ochafik Sep 28, 2024
bc3e0c0
`tool-call`: Qwen 2.5 Instruct also requires object arguments
ochafik Sep 28, 2024
a072f30
`tests`: attempt to find assets for tests run from build subfolder
ochafik Sep 28, 2024
ad6719e
`tests`: fix typo
ochafik Sep 28, 2024
22493c8
`tests`: fix test-chat-template run from build
ochafik Sep 28, 2024
c87c121
`tool-call`: fix memory leak in test
ochafik Sep 28, 2024
8738d94
`minja`: qualify std::nullptr_t type for msys2 build
ochafik Sep 28, 2024
cb7912e
`chat-template`: add phi-3.5-vision-instruct
ochafik Sep 28, 2024
9ac4b04
`tool-call`: add fs_list_files to common, w/ win32 impl for msys2 build
ochafik Sep 28, 2024
277f385
`minja`: attempt to handle windows' crlf
ochafik Sep 30, 2024
0fc5ad7
`minja`: avoid c++20 struct initializers in test
ochafik Sep 30, 2024
d9451fd
`antiprompts`: avoid c++20 struct initializers in test
ochafik Sep 30, 2024
c36a196
`tool-call`: prepare possible externalization of minja + factor tool …
ochafik Oct 1, 2024
c76b145
`tool-call`: fix Makefile
ochafik Oct 1, 2024
5b01402
`agent`: add brave_search & fetch_page tools + move to examples/agent…
ochafik Oct 2, 2024
f3538e7
update tools
ochafik Oct 2, 2024
9e502e8
`tool-call`: promote getting chat templates w/ dedicated script rathe…
ochafik Oct 2, 2024
b559d64
Update README.md
ochafik Oct 2, 2024
2428b73
`agent`: ditch openai dependency, use cache_prompt and expose seed
ochafik Oct 2, 2024
e2a9ab6
`agent`: --openai flag (auto-fetches OPENAI_API_KEY), improved logging
ochafik Oct 2, 2024
6f2191d
`agent`: remove *lots* of cruft from tool definitions derived from Fa…
ochafik Oct 2, 2024
26e76f9
`agent`: allow interactive chat by default, and don't reuse sessions
ochafik Oct 2, 2024
6b4a454
`agent`: hard-code max_results=10 in brave_search
ochafik Oct 2, 2024
fa8df0c
`agent`: drop fastify.py -> simpler serve_tools.py, and expose other …
ochafik Oct 2, 2024
ece12b0
`antiprompts`: ensure partial match is at end of string (or else serv…
ochafik Oct 3, 2024
b4fc1e8
`tool-call`: adjust triggers to most common tool call variations from…
ochafik Oct 3, 2024
da02397
`agent`: support more providers (+ extract serve_tools_inside_docker.sh)
ochafik Oct 3, 2024
366efc8
`tool-call`: fix llama 3.x tc parsing when there are spaces before "n…
ochafik Oct 3, 2024
21a3c90
`agent`: tool tweaks (remove ansi escapes from python output, update …
ochafik Oct 3, 2024
a151ddc
`agent`: handle function errors and don't stringify str outputs
ochafik Oct 4, 2024
241acc2
`agent`: disable brave_search when BRAVE_SEARCH_API_KEY unset
ochafik Oct 7, 2024
3325069
`tool-call`: accept `{"type": "function", "name": "fn"` for llama 3.x
ochafik Oct 7, 2024
e753f15
`agent`: move openapi helpers to their own file
ochafik Oct 8, 2024
7576487
`tool-call`: fix grammar roots
ochafik Oct 22, 2024
fa8462f
fix root
ochafik Oct 22, 2024
9f5ab97
`tool-calls`: add generic tool call style as default
ochafik Oct 22, 2024
b53362a
Update test-tool-call.cpp
ochafik Oct 22, 2024
7f2429e
`tool-calls`: fix grammar regression
ochafik Oct 22, 2024
db4bf93
Merge remote-tracking branch 'origin/master' into tool-call
ochafik Oct 22, 2024
351aecb
Update llama-sampling.cpp
ochafik Oct 22, 2024
a4f12a4
`minja`: fix string subscripts, add string pipe to support Mistral-Ne…
ochafik Oct 22, 2024
fc80ad2
`tool-call`: Log tool call style name, ensure returned content not null
ochafik Oct 22, 2024
3e12b9b
`tool-calls`: basic Nemo support, default parallel to true if templat…
ochafik Oct 23, 2024
2b49440
`tool-call`: fix previous commit's parallel arg
ochafik Oct 23, 2024
5f4aef1
Merge remote-tracking branch 'origin/master' into tool-call
ochafik Oct 23, 2024
4394e1c
Update tool-call.cpp
ochafik Oct 23, 2024
414f6f1
Merge branch 'tool-call' of github.com:ochafik/llama.cpp into tool-call
ochafik Oct 23, 2024
267e630
`agent`: isolate tools container + log its outgoing HTTP & HTTPS traf…
ochafik Oct 24, 2024
f5320af
`tool-call`: return tool_call.id (required by Nemo)
ochafik Oct 24, 2024
0f5d639
`agent`: display http errors nicely
ochafik Oct 24, 2024
d338bfb
`agent`: ditch aiohttp & define REQUESTS_CA_BUNDLE to fix http proxyi…
ochafik Oct 24, 2024
c2926e4
Update README.md
ochafik Oct 24, 2024
03b8641
`agent`: fix deps + make docker compose setup easier to debug
ochafik Oct 24, 2024
0f4fc8c
`agent`: fix no-cache issue in squid for brave tool
ochafik Oct 24, 2024
5c414a3
`agent`: simplify tools setup
ochafik Oct 25, 2024
30bd00b
`agent`: fix tools setup
ochafik Oct 25, 2024
080982e
`tool-call`: test MistralNemo in forced tools server tests (w/ parall…
ochafik Oct 27, 2024
ec9f3b1
nits
ochafik Oct 27, 2024
9a86ea7
`tool-call`: slow tool call integration tests
ochafik Oct 28, 2024
c88095e
space nits
ochafik Oct 28, 2024
7fde6d0
`tool_call`: test no tool call on a real model + rename scenarios
ochafik Oct 28, 2024
dd6d024
`tool-call`: script to prefetch models used in server tests
ochafik Oct 28, 2024
168add7
Update tool_call.feature
ochafik Oct 28, 2024
ec547e4
`tool-call`: add tests: tool_call=none, parallel_tool_calls=true
ochafik Oct 28, 2024
b51c71c
`tool-call`: remove duplicate script to fetch templates
ochafik Oct 28, 2024
74d71a6
`agent`: simplify syntax (default tools to local w/ default port)
ochafik Oct 28, 2024
b825440
`tool-call`: use Q4_K_M models
ochafik Oct 28, 2024
aefac1e
`tool-call`: update scripts/fetch_server_test_models.py
ochafik Oct 28, 2024
64287a3
`tool-call`: test Hermes-3-Llama-3.1-8B
ochafik Oct 29, 2024
fa4c111
`tool-call`: use functionary-small-v3.2-Q8_0.gguf in test (Q4_K_M too…
ochafik Oct 29, 2024
773ff91
`tool-call`: force printing of lazy grammar trigger tokens to regular…
ochafik Oct 29, 2024
92c384a
nits
ochafik Oct 29, 2024
3ebdb2b
`tool-call`: support tool_use variant in llama_chat_template_from_mod…
ochafik Oct 30, 2024
35ac17f
`tool-call`: fix missing initializer errors
ochafik Oct 30, 2024
5227321
`tool-call`: when slow server tests fail, hint to run `python scripts…
ochafik Oct 30, 2024
e4d5449
`tool-calls`: test Qwen2.5-7B-Instruct-Q4_K_M.gguf
ochafik Oct 30, 2024
61655b9
Merge remote-tracking branch 'origin/master' into tool-call
ochafik Oct 31, 2024
be9de3e
Update llama-sampling.cpp
ochafik Oct 31, 2024
542853b
`tool-call`: greedy sampling in server tests + tweak prompt
ochafik Oct 31, 2024
7d9c90f
`tool-call`: nemo tweak (accept raw sql again)
ochafik Oct 31, 2024
e8d9d71
Update tool_call.feature
ochafik Oct 31, 2024
c395d48
`tool-call`: behaviour-based detection of template features
ochafik Oct 31, 2024
f5b7825
`tool-call`: code_interpreter & system + tool call support for all ji…
ochafik Oct 31, 2024
c773516
`tool-call`: don't use -fa w/ Mistral-Nemo (hard crashes?)
ochafik Oct 31, 2024
b35aa4a
`tool-call`: add LLAMA_UPDATE_GOLDENS env for test-chat-template
ochafik Oct 31, 2024
9477c54
`tool-call`: functionary-small-v3.2 test now green
ochafik Oct 31, 2024
c4a8050
Update README.md
ochafik Oct 31, 2024
f5f7475
nits
ochafik Oct 31, 2024
fe967b6
Update README.md
ochafik Oct 31, 2024
479c152
`tool-call`: fix qwen template test
ochafik Oct 31, 2024
bc52c0a
`agent`: add missing tool name in response!
ochafik Oct 31, 2024
c059aec
`agent`: memorize, search_memory (sqlite-vec + sqlite-lembed), fetch …
ochafik Nov 9, 2024
5789f69
`minja`: don't explode upon referencing a field on an array (fixes He…
ochafik Nov 9, 2024
f9b1969
Update README.md
ochafik Nov 9, 2024
adc673c
agent: add --think "tool", default to local tools endpoint, support -…
ochafik Dec 5, 2024
1afa312
Merge remote-tracking branch 'origin/master' into tool-call
ochafik Dec 6, 2024
30fbcb2
agent: more robust squid config
ochafik Dec 6, 2024
a469f53
agent: update readme
ochafik Dec 6, 2024
cbe395d
minja: remove tests (now in https://github.com/google/minja)
ochafik Dec 6, 2024
1fd5f1a
Update README.md
ochafik Dec 6, 2024
5d0033f
minja: sync @ https://github.com/google/minja/commit/916c181c0d4a6f96…
ochafik Dec 7, 2024
1f0b157
tool-call: add firefunction-v2 style
ochafik Dec 7, 2024
93a5245
tool-calls: migrate tests to pytest
ochafik Dec 10, 2024
055053c
Merge remote-tracking branch 'origin/master' into tool-call
ochafik Dec 14, 2024
1e2115f
tool-calls: shorter name: grammar_triggers
ochafik Dec 14, 2024
7bfcd0a
Merge remote-tracking branch 'origin/master' into tool-call
ochafik Dec 14, 2024
7e3feff
tool-call: stabilize server tests
ochafik Dec 15, 2024
e70ce3f
Merge remote-tracking branch 'origin/master' into tool-call
ochafik Dec 26, 2024
f0bd693
Update test-tool-call.cpp
ochafik Dec 26, 2024
f645887
Update minja.hpp https://github.com/google/minja/commit/202aa2f3de21b…
ochafik Dec 26, 2024
0e87ae2
rm trailing spaces
ochafik Dec 27, 2024
0a5d527
Update fetch_server_test_models.py
ochafik Dec 27, 2024
a2fe8a4
Fix tool-call server tests
ochafik Dec 27, 2024
523ebf8
Simplify tool call grammars when there's only 1 tool
ochafik Dec 27, 2024
8 changes: 8 additions & 0 deletions .editorconfig
@@ -40,3 +40,11 @@ indent_style = tab
[examples/cvector-generator/*.txt]
trim_trailing_whitespace = unset
insert_final_newline = unset

[{tests/chat/templates/*.jinja,tests/chat/goldens/*.txt}]
indent_style = unset
indent_size = unset
end_of_line = unset
charset = unset
trim_trailing_whitespace = unset
insert_final_newline = unset
22 changes: 22 additions & 0 deletions Makefile
@@ -49,6 +49,7 @@ BUILD_TARGETS = \

# Binaries only useful for tests
TEST_TARGETS = \
tests/test-antiprompts \
tests/test-arg-parser \
tests/test-autorelease \
tests/test-backend-ops \
@@ -57,13 +58,15 @@ TEST_TARGETS = \
tests/test-grammar-integration \
tests/test-grammar-parser \
tests/test-json-schema-to-grammar \
tests/test-minja \
tests/test-llama-grammar \
tests/test-log \
tests/test-model-load-cancel \
tests/test-quantize-fns \
tests/test-quantize-perf \
tests/test-rope \
tests/test-sampling \
tests/test-tool-call \
tests/test-tokenizer-0 \
tests/test-tokenizer-1-bpe \
tests/test-tokenizer-1-spm
@@ -984,6 +987,7 @@ OBJ_COMMON = \
$(DIR_COMMON)/sampling.o \
$(DIR_COMMON)/speculative.o \
$(DIR_COMMON)/build-info.o \
$(DIR_COMMON)/tool-call.o \
$(DIR_COMMON)/json-schema-to-grammar.o

OBJ_ALL = $(OBJ_GGML) $(OBJ_LLAMA) $(OBJ_COMMON)
@@ -1361,7 +1365,10 @@ llama-server: \
examples/server/httplib.h \
examples/server/index.html.hpp \
examples/server/loading.html.hpp \
common/chat-template.hpp \
common/json.hpp \
common/minja.hpp \
common/tool-call.h \
$(OBJ_ALL)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
@@ -1469,6 +1476,21 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
	$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

tests/test-antiprompts: tests/test-antiprompts.cpp \
	$(OBJ_ALL)
	$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

tests/test-tool-call: tests/test-tool-call.cpp \
	$(OBJ_ALL)
	$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

tests/test-minja: tests/test-minja.cpp \
	$(OBJ_ALL)
	$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

tests/test-opt: tests/test-opt.cpp \
	$(OBJ_GGML)
	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
3 changes: 3 additions & 0 deletions common/CMakeLists.txt
@@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
arg.cpp
arg.h
base64.hpp
chat-template.hpp
common.cpp
common.h
console.cpp
@@ -64,12 +65,14 @@ add_library(${TARGET} STATIC
json.hpp
log.cpp
log.h
minja.hpp
ngram-cache.cpp
ngram-cache.h
sampling.cpp
sampling.h
speculative.cpp
speculative.h
tool-call.cpp
)

if (BUILD_SHARED_LIBS)
43 changes: 39 additions & 4 deletions common/arg.cpp
@@ -1889,24 +1889,59 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            }
        }
    ).set_examples({LLAMA_EXAMPLE_SERVER}));
    add_opt(common_arg(
        {"--jinja"},
        "use jinja template for chat (default: disabled)",
        [](common_params & params) {
            params.use_jinja = true;
        }
    ).set_examples({LLAMA_EXAMPLE_SERVER}));
    add_opt(common_arg(
        {"--chat-template"}, "JINJA_TEMPLATE",
        string_format(
            "set custom jinja chat template (default: template taken from model's metadata)\n"
            "if suffix/prefix are specified, template will be disabled\n"
            "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
            "list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
        ),
        [](common_params & params, const std::string & value) {
-           if (!common_chat_verify_template(value)) {
+           if (!common_chat_verify_template(value, params.use_jinja)) {
                throw std::runtime_error(string_format(
-                   "error: the supplied chat template is not supported: %s\n"
-                   "note: llama.cpp does not use jinja parser, we only support commonly used templates\n",
-                   value.c_str()
+                   "error: the supplied chat template is not supported: %s%s\n",
+                   value.c_str(),
+                   params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates"
                ));
            }
            params.chat_template = value;
        }
    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
    add_opt(common_arg(
        {"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
        "set custom jinja chat template file (default: template taken from model's metadata)\n"
        "if suffix/prefix are specified, template will be disabled\n"
        "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
        "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template",
        [](common_params & params, const std::string & value) {
            std::ifstream file(value);
            if (!file) {
                throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
            }
            std::string chat_template;
            std::copy(
                std::istreambuf_iterator<char>(file),
                std::istreambuf_iterator<char>(),
                std::back_inserter(chat_template)
            );
            if (!common_chat_verify_template(chat_template, params.use_jinja)) {
                throw std::runtime_error(string_format(
                    "error: the supplied chat template is not supported: %s%s\n",
                    value.c_str(),
                    params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates"
                ));
            }
            params.chat_template = chat_template;
        }
    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
    add_opt(common_arg(
        {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
        string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
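For readers who want to try the file-loading step of `--chat-template-file` in isolation, here is a minimal sketch. `read_template_file` is a hypothetical helper name, not something this PR defines. It uses the same `std::istreambuf_iterator` idiom as the handler above, but it opens the file in binary mode, which the handler does not do, and it leaves out the `common_chat_verify_template` check.

```cpp
#include <fstream>
#include <iterator>
#include <stdexcept>
#include <string>

// Hypothetical helper (not part of the PR): read a whole template file into a string.
// Assumption: binary mode is used here to keep line endings untouched; the PR's
// handler opens the file in the default text mode.
static std::string read_template_file(const std::string & path) {
    std::ifstream file(path, std::ios::binary);
    if (!file) {
        // Same failure condition as the handler: the stream could not be opened.
        throw std::runtime_error("error: failed to open file '" + path + "'");
    }
    // Construct the string directly from the stream's character range.
    return std::string(
        std::istreambuf_iterator<char>(file),
        std::istreambuf_iterator<char>());
}
```

Building the string straight from the iterator pair avoids the separate `std::copy` and `std::back_inserter` step; both approaches should yield the same contents.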
209 changes: 209 additions & 0 deletions common/chat-template.hpp
@@ -0,0 +1,209 @@
/*
Copyright 2024 Google LLC

Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once

#include "minja.hpp"
#include <json.hpp>
#include <string>
#include <vector>

using json = nlohmann::ordered_json;

namespace minja {

class chat_template {
  public:

  private:
    bool _supports_tools = true;
    // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
    bool _requires_object_arguments = false;
    bool _supports_system_role = true;
    bool _supports_parallel_tool_calls = false;
    std::string _source;
    std::string _bos_token;
    std::string _eos_token;
    std::shared_ptr<minja::TemplateNode> _template_root;

    bool renders_needles(
        const std::vector<std::string> & needles,
        const nlohmann::ordered_json & messages,
        const nlohmann::ordered_json & tools,
        bool add_generation_prompt,
        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
    {
        try {
            auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
            for (const auto & needle : needles) {
                if (prompt.find(needle) == std::string::npos) {
                    return false;
                }
            }
            return true;
        } catch (const std::exception & e) {
            return false;
        }
    }

  public:
    chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
        : _source(source), _bos_token(bos_token), _eos_token(eos_token)
    {
        _template_root = minja::Parser::parse(_source, {
            /* .trim_blocks = */ true,
            /* .lstrip_blocks = */ true,
            /* .keep_trailing_newline = */ false,
        });
        _supports_tools = source.find("tools") != std::string::npos;
        _requires_object_arguments =
            source.find("tool_call.arguments | items") != std::string::npos
            || source.find("tool_call.arguments | tojson") != std::string::npos;
        _supports_parallel_tool_calls = source.find("tool_call_id") != std::string::npos;

        _supports_system_role = renders_needles({"<System Needle>"}, {
            {{"role", "system"}, {"content", "<System Needle>"}},
            {{"role", "user"}, {"content", "Hey"}}
        }, {}, false);
    }

    const std::string & source() const { return _source; }
    bool supports_tools() const { return _supports_tools; }
    bool supports_parallel_tool_calls() const { return _supports_parallel_tool_calls; }

    std::string apply(
        const nlohmann::ordered_json & messages,
        const nlohmann::ordered_json & tools,
        bool add_generation_prompt,
        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
    {
        json actual_messages;

        // First, "fix" messages so they have a chance to be rendered correctly by the template

        if (_requires_object_arguments || !_supports_system_role || !_supports_tools) {
            actual_messages = json::array();

            std::string pending_system;
            auto flush_sys = [&]() {
                if (!pending_system.empty()) {
                    actual_messages.push_back({
                        {"role", "user"},
                        {"content", pending_system},
                    });
                    pending_system.clear();
                }
            };
            for (const auto & message_ : messages) {
                auto message = message_;
                if (!message.contains("role") || !message.contains("content")) {
                    throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
                }
                std::string role = message.at("role");

                if (message.contains("tool_calls")) {
                    if (_requires_object_arguments || !_supports_tools) {
                        for (auto & tool_call : message.at("tool_calls")) {
                            if (tool_call["type"] == "function") {
                                auto & function = tool_call.at("function");
                                std::string arguments = function.at("arguments");
                                function["arguments"] = json::parse(arguments);
                            }
                        }
                    }
                    if (!_supports_tools) {
                        auto content = message.at("content");
                        auto tool_calls = json::array();
                        for (const auto & tool_call : message.at("tool_calls")) {
                            if (tool_call.at("type") != "function") {
                                continue;
                            }
                            const auto & function = tool_call.at("function");
                            auto tc = json {
                                {"name", function.at("name")},
                                {"arguments", function.at("arguments")},
                            };
                            if (tool_call.contains("id")) {
                                tc["id"] = tool_call["id"];
                            }
                            tool_calls.push_back(tc);
                        }
                        auto obj = json {
                            {"tool_calls", tool_calls},
                        };
                        if (!content.is_null() && content != "") {
                            obj["content"] = content;
                        }
                        message["content"] = obj.dump(2);
                        message.erase("tool_calls");
                    }
                }
                if (!_supports_tools && role == "tool") {
                    message["role"] = "user";
                    auto obj = json {
                        {"tool_response", {
                            {"tool", message.at("name")},
                            {"content", message.at("content")},
                        }},
                    };
                    if (message.contains("tool_call_id")) {
                        obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
                    }
                    message["content"] = obj.dump(2);
                    message.erase("name");
                }

                // std::string content = message["content"];
                if (!message["content"].is_null() && !_supports_system_role) {
                    std::string content = message.at("content");
                    if (role == "system") {
                        if (!pending_system.empty()) pending_system += "\n";
                        pending_system += content;
                        continue;
                    } else {
                        if (role == "user") {
                            if (!pending_system.empty()) {
                                message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
                                pending_system.clear();
                            }
                        } else {
                            flush_sys();
                        }
                    }
                }
                actual_messages.push_back(message);
            }
            flush_sys();
        } else {
            actual_messages = messages;
        }

        auto context = minja::Context::make(json({
            {"messages", actual_messages},
            {"add_generation_prompt", add_generation_prompt},
            {"bos_token", _bos_token},
            {"eos_token", _eos_token},
        }));

        if (!tools.is_null()) {
            auto tools_val = minja::Value(tools);
            context->set("tools", tools_val);
        }
        if (!extra_context.is_null()) {
            for (auto & kv : extra_context.items()) {
                minja::Value val(kv.value());
                context->set(kv.key(), val);
            }
        }

        return _template_root->render(context);
    }
};

} // namespace minja
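
The system-message handling in `apply` can be hard to follow inside the JSON plumbing, so the sketch below isolates it under simplifying assumptions. `Message` and `fold_system_messages` are hypothetical names introduced only for illustration. The struct stands in for the `nlohmann::ordered_json` messages the PR works on, and the null-content check and all tool-call rewriting are left out. It tries to mirror what happens when a template does not render the system role: consecutive system messages are joined with newlines and either prepended to the next user message or emitted as a standalone user message.

```cpp
#include <string>
#include <vector>

// Simplified stand-in for a chat message (assumption: the PR uses nlohmann::ordered_json).
struct Message {
    std::string role;
    std::string content;
};

// Hypothetical helper: fold system messages into user messages, following the
// pending_system / flush_sys logic of chat_template::apply.
static std::vector<Message> fold_system_messages(const std::vector<Message> & messages) {
    std::vector<Message> out;
    std::string pending_system;

    // Emit any buffered system text as its own user message.
    auto flush_sys = [&]() {
        if (!pending_system.empty()) {
            out.push_back({"user", pending_system});
            pending_system.clear();
        }
    };

    for (auto message : messages) { // copy, since the content may be rewritten
        if (message.role == "system") {
            // Buffer system text; several system messages are joined by newlines.
            if (!pending_system.empty()) pending_system += "\n";
            pending_system += message.content;
            continue;
        }
        if (message.role == "user") {
            // Prepend buffered system text to the user message.
            if (!pending_system.empty()) {
                message.content = pending_system + (message.content.empty() ? "" : "\n" + message.content);
                pending_system.clear();
            }
        } else {
            // Any other role forces the buffered text out first.
            flush_sys();
        }
        out.push_back(message);
    }
    flush_sys(); // trailing system messages become a final user message
    return out;
}
```

With the input `system "A"`, `system "B"`, `user "Hi"`, `assistant "Yo"`, `system "C"`, this sketch returns three messages: a user message `"A\nB\nHi"`, the assistant message unchanged, and a final user message `"C"`.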