From 0e70ba686e6c717a0aa41d88284e2a392c2bd0cd Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 18 Dec 2024 11:05:29 +0200
Subject: [PATCH] server : add "tokens" output (#10853)

* server : add "tokens" output

ggml-ci

* server : update readme

ggml-ci

* server : return tokens ids only if requested

ggml-ci

* tests : improve "tokens" type check

Co-authored-by: Xuan Son Nguyen

* server : remove "tokens" from the OAI endpoint

ggml-ci

---------

Co-authored-by: Xuan Son Nguyen
---
 examples/server/README.md                     |  8 +++-
 examples/server/server.cpp                    | 38 ++++++++++++++-----
 examples/server/tests/unit/test_completion.py | 16 ++++++--
 3 files changed, 46 insertions(+), 16 deletions(-)

diff --git a/examples/server/README.md b/examples/server/README.md
index 63a7bf43a920d..ecd24c899fc86 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -438,19 +438,22 @@ These words will not be included in the completion, so make sure to add them to
 
 `cache_prompt`: Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation) enabling this option can cause nondeterministic results. Default: `true`
 
+`return_tokens`: Return the raw generated token ids in the `tokens` field. Otherwise `tokens` remains empty. Default: `false`
+
 `samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]` - these are all the available values.
 
 `timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
 
 **Response format**
 
-- Note: In streaming mode (`stream`), only `content` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
+- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
 
 - `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has the following structure:
 
 ```json
 {
-  "content": "",
+  "content": "",
+  "tokens": [ generated token ids if requested ],
   "probs": [
     {
       "prob": float,
@@ -468,6 +471,7 @@ These words will not be included in the completion, so make sure to add them to
 Notice that each `probs` is an array of length `n_probs`.
 
 - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
+- `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request.
 - `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
 - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.).
 - `model`: The path to the model loaded with `-m`
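For reference, a minimal non-streaming client sketch of the `return_tokens` flow documented above. It is not part of the patch; it assumes a `llama-server` instance listening on `http://localhost:8080` and the third-party `requests` package.

```python
# Sketch of a client for the /completion endpoint documented above.
# Assumes a llama-server instance at http://localhost:8080 and the
# `requests` package; adjust host/port for your setup.
import requests

res = requests.post(
    "http://localhost:8080/completion",
    json={
        "prompt": "I believe the meaning of life is",
        "n_predict": 8,
        "return_tokens": True,  # ask the server to also return raw token ids
    },
)
res.raise_for_status()
body = res.json()

print(body["content"])  # generated text
print(body["tokens"])   # raw token ids (a list of ints)
```

With `return_tokens` omitted or set to `false`, the response still contains a `tokens` field, but it is an empty array.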
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 71566b94e61bb..40aac33f0bf13 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -79,8 +79,9 @@ enum error_type {
 };
 
 struct slot_params {
-    bool stream       = true;
-    bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
+    bool stream        = true;
+    bool cache_prompt  = true; // remember the prompt to avoid reprocessing all prompt
+    bool return_tokens = false;
 
     int32_t n_keep    = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -199,6 +200,7 @@ struct server_task {
 
         params.stream           = json_value(data, "stream",           false);
         params.cache_prompt     = json_value(data, "cache_prompt",     true);
+        params.return_tokens    = json_value(data, "return_tokens",    false);
         params.n_predict        = json_value(data, "n_predict",        json_value(data, "max_tokens", defaults.n_predict));
         params.n_indent         = json_value(data, "n_indent",         defaults.n_indent);
         params.n_keep           = json_value(data, "n_keep",           defaults.n_keep);
@@ -468,7 +470,10 @@ struct completion_token_output {
 
 struct server_task_result_cmpl_final : server_task_result {
     int index = 0;
-    std::string content;
+
+    std::string  content;
+    llama_tokens tokens;
+
     bool stream;
     result_timings timings;
     std::string prompt;
@@ -510,6 +515,7 @@ struct server_task_result_cmpl_final : server_task_result {
         json res = json {
             {"index",       index},
            {"content",     stream ? "" : content}, // in stream mode, content is already in last partial chunk
+            {"tokens",      stream ? llama_tokens {} : tokens},
             {"id_slot",     id_slot},
             {"stop",        true},
             {"model",       oaicompat_model},
@@ -539,9 +545,9 @@ struct server_task_result_cmpl_final : server_task_result {
         json choices = json::array({json{
             {"finish_reason", finish_reason},
             {"index", 0},
-            {"message", json{
+            {"message", json {
                 {"content", content},
-                {"role", "assistant"}
+                {"role",    "assistant"}
             }
         }}});
 
@@ -605,7 +611,9 @@ struct server_task_result_cmpl_final : server_task_result {
 
 struct server_task_result_cmpl_partial : server_task_result {
     int index = 0;
-    std::string content;
+
+    std::string  content;
+    llama_tokens tokens;
 
     int32_t n_decoded;
     int32_t n_prompt_tokens;
@@ -637,6 +645,7 @@ struct server_task_result_cmpl_partial : server_task_result {
         json res = json {
             {"index",            index},
             {"content",          content},
+            {"tokens",           tokens},
             {"stop",             false},
             {"id_slot",          id_slot},
             {"tokens_predicted", n_decoded},
@@ -678,7 +687,7 @@ struct server_task_result_cmpl_partial : server_task_result {
             json second_ret = json{
                 {"choices", json::array({json{{"finish_reason", nullptr},
                                                 {"index", 0},
-                                                {"delta", json{
+                                                {"delta", json {
                                                 {"content", content}}}
                                                 }})},
                 {"created", t},
@@ -693,7 +702,7 @@ struct server_task_result_cmpl_partial : server_task_result {
             {"finish_reason", nullptr},
             {"index", 0},
             {"delta",
-                json{
+                json {
                     {"content", content},
                 }},
         }});
@@ -955,8 +964,11 @@ struct server_slot {
 
     size_t last_nl_pos = 0;
 
-    std::string generated_text;
+    std::string  generated_text;
+    llama_tokens generated_tokens;
+
     llama_tokens cache_tokens;
+
     std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
@@ -1000,6 +1012,7 @@ struct server_slot {
         n_sent_token_probs = 0;
         task_type          = SERVER_TASK_TYPE_COMPLETION;
 
+        generated_tokens.clear();
         generated_token_probs.clear();
     }
 
@@ -1740,8 +1753,10 @@ struct server_context {
         const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
         slot.sampled = result.tok;
 
-        // search stop word and delete it
         slot.generated_text += token_str;
+        if (slot.params.return_tokens) {
+            slot.generated_tokens.push_back(result.tok);
+        }
         slot.has_next_token = true;
 
         // check if there is incomplete UTF-8 character at the end
@@ -1766,6 +1781,7 @@ struct server_context {
             break;
         }
 
+        // search stop word and delete it
        if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
 
@@ -1918,6 +1934,7 @@ struct server_context {
         res->id      = slot.id_task;
         res->index   = slot.index;
         res->content = tkn.text_to_send;
+        res->tokens  = { tkn.tok };
 
         res->n_decoded       = slot.n_decoded;
         res->n_prompt_tokens = slot.n_prompt_tokens;
@@ -1958,6 +1975,7 @@ struct server_context {
 
         res->index   = slot.index;
         res->content = slot.generated_text;
+        res->tokens  = slot.generated_tokens;
         res->timings = slot.get_timings();
         res->prompt  = common_detokenize(ctx, slot.prompt_tokens, true);
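The partial results above attach the sampled token id to every streamed chunk, so a streaming client can accumulate the raw ids alongside the text without re-tokenizing. A minimal sketch, not part of the patch, again assuming a `llama-server` instance on `http://localhost:8080` and the `requests` package:

```python
# Sketch of a streaming client that collects the per-chunk "tokens" arrays.
# The server emits Server-sent events, one "data: {...}" line per chunk.
import json
import requests

token_ids = []
text = ""

with requests.post(
    "http://localhost:8080/completion",
    json={"prompt": "Hello", "n_predict": 16, "stream": True},
    stream=True,
) as res:
    res.raise_for_status()
    for line in res.iter_lines():
        if not line.startswith(b"data: "):
            continue  # skip keep-alives / blank separator lines
        chunk = json.loads(line[len(b"data: "):])
        text += chunk["content"]
        token_ids.extend(chunk["tokens"])  # one id per generated token
        if chunk["stop"]:
            break

print(text)
print(token_ids)
```

Note that the final (stop) chunk carries empty `content` and `tokens` in stream mode, since everything was already delivered in the partial chunks.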
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index 062ebcd4a05cc..36aee57dd3638 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -10,16 +10,17 @@ def create_server():
     global server
     server = ServerPreset.tinyllama2()
 
-@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
-    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
-    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
+    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
 ])
-def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
+def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
     global server
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": n_predict,
         "prompt": prompt,
+        "return_tokens": return_tokens,
     })
     assert res.status_code == 200
     assert res.body["timings"]["prompt_n"] == n_prompt
@@ -27,6 +28,11 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,
     assert res.body["truncated"] == truncated
     assert type(res.body["has_new_line"]) == bool
     assert match_regex(re_content, res.body["content"])
+    if return_tokens:
+        assert len(res.body["tokens"]) > 0
+        assert all(type(tok) == int for tok in res.body["tokens"])
+    else:
+        assert res.body["tokens"] == []
 
 
 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
@@ -56,6 +62,8 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
             assert data["generation_settings"]["seed"] == server.seed
             assert match_regex(re_content, content)
         else:
+            assert len(data["tokens"]) > 0
+            assert all(type(tok) == int for tok in data["tokens"])
             content += data["content"]
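In the spirit of the tests above, a small round-trip check (not part of the patch): the ids returned in `tokens` should detokenize back to roughly the returned `content`. This sketch assumes a running `llama-server` on `http://localhost:8080` whose `/detokenize` endpoint accepts `{"tokens": [...]}` and returns `{"content": "..."}`, plus the `requests` package.

```python
# Round-trip sketch: detokenize the returned token ids and compare with "content".
import requests

base = "http://localhost:8080"

res = requests.post(f"{base}/completion", json={
    "prompt": "I believe the meaning of life is",
    "n_predict": 8,
    "return_tokens": True,
}).json()

detok = requests.post(f"{base}/detokenize", json={"tokens": res["tokens"]}).json()

# Detokenization can differ in leading whitespace, so compare loosely.
assert detok["content"].strip() == res["content"].strip()
```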