From 171e3c1f8aabde1bebff581e8d17763e31b35dc9 Mon Sep 17 00:00:00 2001
From: wangyumu
Date: Sun, 5 Nov 2023 17:18:05 +0800
Subject: [PATCH] Fixes #360 add stop_token_ids

Signed-off-by: wangyumu
---
 include/fastllm.h            |  2 +-
 src/models/basellm.cpp       | 11 ++++++++
 src/models/llama.cpp         |  6 +++++
 tools/fastllm_pytools/llm.py | 51 +++++++++++++++++++++++++-----------
 tools/src/pytools.cpp        | 14 ++++++++--
 5 files changed, 66 insertions(+), 18 deletions(-)

diff --git a/include/fastllm.h b/include/fastllm.h
index 22fde50f..2010e728 100644
--- a/include/fastllm.h
+++ b/include/fastllm.h
@@ -44,7 +44,7 @@ namespace fastllm {
         float temperature = 1.0; // temperature, usually between 0.1 and 1.0; larger values give more diverse results
         bool output_logits = false; // whether to return logits
         bool enable_hash_id = false; // attach a hash id to the session
-
+        std::multiset <int> stop_token_ids;
 
         bool IsSimpleGreedy() const {
             if (fabs(repeat_penalty - 1) > 1e-8) {
diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index 18af45b5..3b67c9d2 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -208,6 +208,11 @@ namespace fastllm {
                 inputTokens[i] = std::vector <float> {(float)ret[i]};
                 if (ret[i] == eos_token_id) {
                     isEnding[i] = true;
+                } else {
+                    auto itStopTk = generationConfig.stop_token_ids.find(ret[i]);
+                    if (itStopTk != generationConfig.stop_token_ids.end()) {
+                        isEnding[i] = true;
+                    }
                 }
                 if (isEnding[i]) {
                     curStrings.push_back("");
@@ -659,6 +664,12 @@ printf("tot = %d\n", tot);
                     if (curRet == model->eos_token_id) {
                         it.second->isEnding = true;
                     } else {
+                        auto itStopTk = it.second->generationConfig.stop_token_ids.find(curRet);
+                        if (itStopTk != it.second->generationConfig.stop_token_ids.end()) {
+                            it.second->isEnding = true;
+                        }
+                    }
+                    if (it.second->isEnding == false) {
                         it.second->currentTokens = std::vector <int> {curRet};
                         it.second->resultTokenQueue.push(curRet);
                         it.second->tokens.Push(curRet);
diff --git a/src/models/llama.cpp b/src/models/llama.cpp
index c3446225..aca39854 100644
--- a/src/models/llama.cpp
+++ b/src/models/llama.cpp
@@ -990,6 +990,12 @@ namespace fastllm {
                     if (curRet == model->eos_token_id) {
                         it.second->isEnding = true;
                     } else {
+                        auto itStopTk = it.second->generationConfig.stop_token_ids.find(curRet);
+                        if (itStopTk != it.second->generationConfig.stop_token_ids.end()) {
+                            it.second->isEnding = true;
+                        }
+                    }
+                    if (it.second->isEnding == false) {
                         it.second->currentTokens = std::vector <int> {curRet};
                         it.second->resultTokenQueue.push(curRet);
                         it.second->tokens.Push(curRet);
diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index ca42e516..82428e55 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -21,7 +21,8 @@
 
 fastllm_lib.launch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_void_p, ctypes.c_int,
                                                   ctypes.c_bool, ctypes.c_float, ctypes.c_int,
-                                                  ctypes.c_float, ctypes.c_float, ctypes.c_bool]
+                                                  ctypes.c_float, ctypes.c_float, ctypes.c_bool,
+                                                  ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
 fastllm_lib.launch_response_llm_model.restype = ctypes.c_int
 
 fastllm_lib.fetch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
@@ -37,7 +38,8 @@
 
 fastllm_lib.launch_response_str_llm_model.argtype = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int,
                                                      ctypes.c_bool, ctypes.c_float, ctypes.c_int,
-                                                     ctypes.c_float, ctypes.c_float, ctypes.c_bool]
+                                                     ctypes.c_float, ctypes.c_float, ctypes.c_bool,
+                                                     ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
 fastllm_lib.launch_response_str_llm_model.restype = ctypes.c_int
 
 fastllm_lib.fetch_response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
@@ -199,19 +201,29 @@ def tokenizer_decode_token(self, token_id: int) -> bytes:
                 break
         return buffer_bytes[:result_len]
 
+    def stop_token_ctypes(self, stop_token_ids):
+        if stop_token_ids is None:
+            return 0, None
+        else:
+            return ctypes.c_int(len(stop_token_ids)), (ctypes.c_int * len(stop_token_ids))(*stop_token_ids)
+
     def response_logits(self,
                         query: str,
                         history: List[Tuple[str, str]] = None,
-                        tokenizer = None) -> str:
+                        tokenizer = None,
+                        stop_token_ids: List[int] = None,
+                        ) -> str:
         prompt = query if self.direct_query else self.get_prompt(query, history);
+        stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
         if (tokenizer == None):
             handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
                                                                ctypes.c_int(1), ctypes.c_bool(False), ctypes.c_float(1), ctypes.c_int(1),
-                                                               ctypes.c_float(1), ctypes.c_float(1), ctypes.c_bool(True));
+                                                               ctypes.c_float(1), ctypes.c_float(1), ctypes.c_bool(True),
+                                                               stop_token_len, stop_token_list);
         else:
             input = tokenizer.encode(prompt);
             handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
-                                                           1, False, 1, 1, 1, 1, True);
+                                                           1, False, 1, 1, 1, 1, True, stop_token_len, stop_token_list);
         vocab_size = fastllm_lib.get_tokenizer_vocab_size(self.model);
         logits = list(range(vocab_size))
         array = (ctypes.c_float * (vocab_size * 4))(*logits);
@@ -224,7 +236,8 @@ def response_logits(self,
     def response(self,
                  query: str,
                  history: List[Tuple[str, str]] = None,
-                 max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0) -> str:
+                 max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
+                 stop_token_ids: List[int] = None) -> str:
         ret = "";
         for i in self.stream_response(query = query,
                                       history = history,
@@ -233,7 +246,8 @@ def response(self,
                                       top_p = top_p, top_k = top_k,
                                       temperature = temperature,
                                       repeat_penalty = repeat_penalty,
-                                      one_by_one = True):
+                                      one_by_one = True,
+                                      stop_token_ids = stop_token_ids):
             ret += i;
         return ret;
 
@@ -241,11 +255,13 @@ def stream_response(self,
                         query: str,
                         history: List[Tuple[str, str]] = None,
                         max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
-                        one_by_one = True):
+                        one_by_one = True, stop_token_ids: List[int] = None):
         prompt = query if self.direct_query else self.get_prompt(query, history);
+        stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids);
         handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
                                                            ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
-                                                           ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False));
+                                                           ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
+                                                           stop_token_len, stop_token_list);
         res = "";
         ret = b'';
         fail_cnt = 0;
@@ -273,12 +289,15 @@ def stream_response(self,
     def stream_response_raw(self,
                             input_tokens: List[int],
                             max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
-                            one_by_one = True
+                            one_by_one = True,
+                            stop_token_ids: List[int] = None
                             ):
+        stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
         handle = fastllm_lib.launch_response_llm_model(self.model, len(input_tokens),
                                                        (ctypes.c_int * len(input_tokens))(*input_tokens),
                                                        ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
-                                                       ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False))
+                                                       ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
+                                                       stop_token_len, stop_token_list)
 
         # a long-tail char may need several tokens before it can be generated, so only bytes are returned and the string.decode strategy is left to the caller
         # this also makes it easy to count output tokens and to control decoding when the utf-8 sequence is incomplete
@@ -298,14 +317,15 @@ def stream_response_raw(self,
             yield total_bytes
 
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192,
-             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, **kwargs):
+             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, stop_token_ids: List[int] = None, **kwargs):
         if (not(history)):
             history = [];
         prompt = query if self.direct_query else self.get_prompt(query, history);
         input = tokenizer.encode(prompt);
+        stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
         handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                        max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
-                                                       False);
+                                                       False, stop_token_len, stop_token_list);
 
         result = [];
         while True:
@@ -319,14 +339,15 @@ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max
     def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values = None,
                     max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
-                    return_past_key_values = False, **kwargs) -> str:
+                    return_past_key_values = False, stop_token_ids: List[int] = None, **kwargs) -> str:
         if (not(history)):
             history = [];
         prompt = query if self.direct_query else self.get_prompt(query, history);
         input = tokenizer.encode(prompt);
+        stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
         handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                        max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
-                                                       False);
+                                                       False, stop_token_len, stop_token_list);
 
         tokens = [];
         while True:
             cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
diff --git a/tools/src/pytools.cpp b/tools/src/pytools.cpp
index 933c0bf6..7e4f4b41 100644
--- a/tools/src/pytools.cpp
+++ b/tools/src/pytools.cpp
@@ -241,7 +241,8 @@ extern "C" {
 
     DLL_EXPORT int launch_response_str_llm_model(int modelId, char *content,
                                                  int max_length, bool do_sample, float top_p, int top_k,
-                                                 float temperature, float repeat_penalty, bool output_logits) {
+                                                 float temperature, float repeat_penalty, bool output_logits,
+                                                 int stop_token_len, int * stop_token_ids) {
         auto model = models.GetModel(modelId);
         std::vector <int> tokens;
         auto v = model->weight.tokenizer.Encode(content);
@@ -249,6 +250,10 @@ extern "C" {
            tokens.push_back((int)((float*)v.cpuData)[i]);
        }
        auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits);
+        for(int i = 0; i < stop_token_len; i++ )
+        {
+            config.stop_token_ids.insert(stop_token_ids[i]);
+        }
        return model->LaunchResponseTokens(tokens, config);
     }
 
@@ -261,12 +266,17 @@ extern "C" {
     DLL_EXPORT int launch_response_llm_model(int modelId, int len, int *values,
                                              int max_length, bool do_sample, float top_p, int top_k,
-                                             float temperature, float repeat_penalty, bool output_logits) {
+                                             float temperature, float repeat_penalty, bool output_logits,
+                                             int stop_token_len, int * stop_token_ids) {
         std::vector <int> input;
         for (int i = 0; i < len; i++) {
             input.push_back(values[i]);
         }
         auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                   output_logits);
+        for(int i = 0; i < stop_token_len; i++ )
+        {
+            config.stop_token_ids.insert(stop_token_ids[i]);
+        }
         auto model = models.GetModel(modelId);
         return model->LaunchResponseTokens(input, config);
     }
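
Note on the Python/C boundary: the stop_token_ids list is handed to the C entry points as a length plus an int pointer, which is what the added ctypes.c_int, ctypes.POINTER(ctypes.c_int) argtypes describe. A minimal standalone sketch of the marshalling done by the new stop_token_ctypes helper (no fastllm install required; the token ids below are arbitrary placeholders):

    # Sketch of the list -> (length, C int array) conversion; the array decays to
    # the int* expected by the new "int stop_token_len, int *stop_token_ids"
    # parameters in tools/src/pytools.cpp.
    import ctypes

    def stop_token_ctypes(stop_token_ids):
        if stop_token_ids is None:
            # length 0 plus a NULL pointer: the C-side insert loop then does nothing
            return 0, None
        return ctypes.c_int(len(stop_token_ids)), (ctypes.c_int * len(stop_token_ids))(*stop_token_ids)

    stop_token_len, stop_token_list = stop_token_ctypes([2, 100007])   # placeholder ids
    print(stop_token_len.value, list(stop_token_list))                 # -> 2 [2, 100007]

On the C side each of those ids is inserted into config.stop_token_ids (a std::multiset<int>), and generation ends as soon as the sampled token equals eos_token_id or matches any entry in that set.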
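
For reviewers, a hedged usage sketch of the new parameter through the high-level bindings in tools/fastllm_pytools/llm.py; the model path and the stop token id are placeholders, not values taken from this patch:

    from fastllm_pytools import llm

    model = llm.model("/path/to/model.flm")   # assumed: an already converted .flm model

    # With this patch, generation stops when the sampled token is eos_token_id or any
    # id listed in stop_token_ids; omitting the argument keeps the previous behaviour.
    print(model.response("Hello", stop_token_ids = [100007]))

    for chunk in model.stream_response("Hello", stop_token_ids = [100007]):
        print(chunk, end = "", flush = True)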