Commit
tool-call: integrate minja & tool-call to server when --jinja is set
ochafik committed Sep 25, 2024
1 parent 3cfc21e commit e309c6a
Showing 10 changed files with 514 additions and 163 deletions.
12 changes: 10 additions & 2 deletions common/arg.cpp
@@ -1844,13 +1844,21 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(llama_arg(
{"--jinja"},
"use jinja template for chat (default: disabled)",
[](gpt_params & params) {
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(llama_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
"set custom jinja chat template (default: template taken from model's metadata)\n"
"if suffix/prefix are specified, template will be disabled\n"
"only commonly used templates are accepted:\nhttps://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template",
"only commonly used templates are accepted (unless --jinja is set before this flag):\n"
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template",
[](gpt_params & params, const std::string & value) {
if (!llama_chat_verify_template(value)) {
if (!llama_chat_verify_template(value, params.use_jinja)) {
throw std::runtime_error(format(
"error: the supplied chat template is not supported: %s\n"
"note: llama.cpp does not use jinja parser, we only support commonly used templates\n",
26 changes: 17 additions & 9 deletions common/common.cpp
@@ -1510,16 +1510,20 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
// Chat template utils
//

bool llama_chat_verify_template(const std::string & tmpl) {
bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
llama_chat_message chat[] = {{"user", "test"}};
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0, use_jinja);
return res >= 0;
}

std::string llama_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & msgs,
bool add_ass) {
bool add_ass,
bool use_jinja,
const std::string & tools,
const char * bos_token,
const char * eos_token) {
int alloc_size = 0;
bool fallback = false; // indicate if we must fallback to default chatml
std::vector<llama_chat_message> chat;
@@ -1532,7 +1536,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
std::vector<char> buf(alloc_size);

// run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools.empty() ? nullptr : tools.data(), bos_token, eos_token);

// error: chat template is not supported
if (res < 0) {
@@ -1542,7 +1546,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
throw std::runtime_error("this custom template is not supported");
} else {
// If the built-in template is not supported, we default to chatml
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, bos_token, eos_token);
fallback = true;
}
}
@@ -1553,7 +1557,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
res = llama_chat_apply_template(
fallback ? nullptr : model,
fallback ? "chatml" : ptr_tmpl,
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, bos_token, eos_token);
}

std::string formatted_chat(buf.data(), res);
@@ -1564,17 +1568,21 @@ std::string llama_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & past_msg,
const llama_chat_msg & new_msg,
bool add_ass) {
bool add_ass,
bool use_jinja,
const std::string & tools,
const char * bos_token,
const char * eos_token) {
std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, bos_token, eos_token);
std::vector<llama_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
ss << "\n";
};
// format chat with new_msg
chat_new.push_back(new_msg);
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, bos_token, eos_token);
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
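For orientation, a hedged sketch of how a caller could use the widened wrapper, passing a tools JSON string and explicit BOS/EOS pieces. The parameter order follows the new signature declared in common.h below; the tool definition and token strings are made-up placeholders:

```cpp
#include <string>
#include <vector>

#include "common.h"  // llama_chat_msg, extended llama_chat_apply_template

// `model` is assumed to be a llama_model * loaded elsewhere.
std::string format_with_tools(const struct llama_model * model) {
    std::vector<llama_chat_msg> msgs = {
        {"system", "You are a helpful assistant."},
        {"user",   "What is the weather in Paris?"},
    };

    // Hypothetical OpenAI-style tool list, passed through as JSON text.
    const std::string tools = R"([{"type": "function", "function": {
        "name": "get_weather",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
    }}])";

    // An empty template string falls back to the template in the model metadata.
    return llama_chat_apply_template(model, "", msgs, /* add_ass   */ true,
                                     /* use_jinja */ true, tools,
                                     /* bos_token */ "<s>", /* eos_token */ "</s>");
}
```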
23 changes: 19 additions & 4 deletions common/common.h
@@ -285,6 +285,7 @@ struct gpt_params {
std::string public_path = ""; // NOLINT
std::string chat_template = ""; // NOLINT
std::string system_prompt = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;

std::vector<std::string> api_keys;
@@ -469,29 +470,43 @@ std::string llama_detokenize(
// Chat template utils
//

// same with llama_chat_message, but uses std::string
// same as llama_chat_message, but uses std::string and std::vector
struct llama_chat_msg {
std::string role;
std::string content;
std::string tool;
struct llama_tool_call {
std::string name;
std::string arguments;
};
std::vector<llama_tool_call> tool_calls;
};
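As an aside, a small illustration of how the extended struct could represent an assistant turn that requests a tool call, plus the follow-up turn carrying the result. The tool name, arguments, and the meaning of the `tool` field are assumptions for illustration, not taken from this diff:

```cpp
#include "common.h"  // llama_chat_msg with the new tool / tool_calls members

// Hypothetical assistant message that asks for a tool invocation.
static llama_chat_msg make_tool_call_msg() {
    llama_chat_msg msg;
    msg.role    = "assistant";
    msg.content = "";                       // no plain text, only the call
    msg.tool_calls.push_back({"get_weather", R"({"city": "Paris"})"});
    return msg;
}

// Hypothetical follow-up turn feeding the tool's result back to the model;
// assumption: `tool` names the tool that produced this content.
static llama_chat_msg make_tool_result_msg() {
    llama_chat_msg msg;
    msg.role    = "tool";
    msg.tool    = "get_weather";
    msg.content = R"({"temperature_c": 21})";
    return msg;
}
```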

// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool llama_chat_verify_template(const std::string & tmpl);
bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false);

// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error
std::string llama_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & chat,
bool add_ass);
bool add_ass,
bool use_jinja = false,
const std::string & tools = "",
const char * bos_token = nullptr,
const char * eos_token = nullptr);

// Format single message, while taking into account the position of that message in chat history
std::string llama_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & past_msg,
const llama_chat_msg & new_msg,
bool add_ass);
bool add_ass,
bool use_jinja = false,
const std::string & tools = "",
const char * bos_token = nullptr,
const char * eos_token = nullptr);

// Returns an example of formatted chat
std::string llama_chat_format_example(const struct llama_model * model,
4 changes: 3 additions & 1 deletion examples/server/server.cpp
@@ -2781,6 +2781,8 @@ int main(int argc, char ** argv) {
{ "system_prompt", ctx_server.system_prompt.c_str() },
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params.n_parallel },
{ "bos_token", llama_token_to_piece(ctx_server.ctx, llama_token_bos(ctx_server.model), true) },
{ "eos_token", llama_token_to_piece(ctx_server.ctx, llama_token_eos(ctx_server.model), true) },
{ "chat_template", curr_tmpl.c_str() },
};
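With those two fields added, a GET /props response takes roughly the following shape. The token strings, template, and settings below are placeholders rather than real server output; the sketch just parses them with nlohmann::json, the same library server.cpp uses:

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Approximate shape of a GET /props response after this change; all
    // values are placeholders, not actual output from a running server.
    const json props = json::parse(R"({
        "system_prompt": "",
        "default_generation_settings": {},
        "total_slots": 1,
        "bos_token": "<s>",
        "eos_token": "</s>",
        "chat_template": "{% for message in messages %}...{% endfor %}"
    })");

    // A client that wants to render the Jinja template itself can now pick up
    // the special tokens in the same round trip.
    std::cout << props["bos_token"] << " / " << props["eos_token"] << "\n";
    return 0;
}
```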

@@ -2854,7 +2856,7 @@
return;
}

json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template, params.use_jinja);

std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
ctx_server.queue_results.add_waiting_tasks(tasks);
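The request body handed to oaicompat_completion_params_parse can now carry OpenAI-style `tools` and `tool_choice` fields, which is what the Python test steps below exercise. A hedged sketch of such a body, written as the JSON that json::parse(req.body) would consume; the tool schema and model name are hypothetical, and how the parser handles these fields is not shown in this excerpt:

```cpp
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Roughly what json::parse(req.body) could yield for a tool-call request.
json example_chat_completion_body() {
    return json::parse(R"({
        "model": "placeholder-model",
        "messages": [
            {"role": "user", "content": "What is the weather in Paris?"}
        ],
        "tools": [
            {"type": "function", "function": {
                "name": "get_weather",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}}
                }
            }}
        ],
        "tool_choice": "auto"
    })");
}
```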
43 changes: 35 additions & 8 deletions examples/server/tests/features/steps/steps.py
@@ -75,6 +75,8 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.server_seed = None
context.user_api_key = None
context.response_format = None
context.tools = None
context.tool_choice = None
context.temperature = None
context.lora_file = None
context.disable_ctx_shift = False
@@ -363,6 +365,13 @@ def step_max_tokens(context, max_tokens):
def step_response_format(context, response_format):
context.response_format = json.loads(response_format)

@step('tools {tools}')
def step_tools(context, tools):
context.tools = json.loads(tools)

@step('tool choice {tool_choice}')
def step_tool_choice(context, tool_choice):
context.tool_choice = tool_choice

@step('{temperature:f} temperature')
def step_temperature(context, temperature):
@@ -497,6 +506,11 @@ async def step_oai_chat_completions(context, api_error):
response_format=context.response_format
if hasattr(context, 'response_format') else None,

tools=context.tools
if hasattr(context, 'tools') else None,

tool_choice=context.tool_choice,

user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None,

@@ -567,6 +581,9 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'enable_streaming') else None,
response_format=context.response_format
if hasattr(context, 'response_format') else None,
tools=context.tools
if hasattr(context, 'tools') else None,
tool_choice=context.tool_choice,
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None)

@@ -580,16 +597,18 @@ async def step_oai_chat_completions(context):
context.base_url,
'/chat/completions',
True, # async_client
model=context.model
if hasattr(context, 'model') else None,
n_predict=context.n_predict
if hasattr(context, 'n_predict') else None,
model=context.model,
# if hasattr(context, 'model') else None,
n_predict=context.n_predict,
# if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None,
response_format=context.response_format
if hasattr(context, 'response_format') else None,
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None)
response_format=context.response_format,
# if hasattr(context, 'response_format') else None,
tools=context.tools,# if hasattr(context, 'tools') else None,
tool_choice=context.tool_choice, # if hasattr(context, 'tool_choice') else None,
user_api_key=context.user_api_key)
# if hasattr(context, 'user_api_key') else None)


@step('all prompts are predicted')
@@ -974,6 +993,8 @@ async def oai_chat_completions(user_prompt,
n_predict=None,
enable_streaming=None,
response_format=None,
tools=None,
tool_choice=None,
user_api_key=None,
expect_api_error=None) -> int | dict[str, Any]:
if debug:
@@ -1001,6 +1022,10 @@ async def oai_chat_completions(user_prompt,
}
if response_format is not None:
payload['response_format'] = response_format
if tools is not None:
payload['tools'] = tools
if tool_choice is not None:
payload['tool_choice'] = tool_choice
completion_response = {
'content': '',
'timings': {
@@ -1065,6 +1090,8 @@ async def oai_chat_completions(user_prompt,
max_tokens=n_predict,
stream=enable_streaming,
response_format=payload.get('response_format') or openai.NOT_GIVEN,
tools=payload.get('tools'),
tool_choice=payload.get('tool_choice'),
seed=seed,
temperature=payload['temperature']
)
