From 2a0e9d03b8496b26836446546c7c0f48b40c3cdf Mon Sep 17 00:00:00 2001
From: 黄宇扬
Date: Thu, 18 Jul 2024 11:04:30 +0800
Subject: [PATCH] ftllm.server and ftllm.webui: add system_prompt support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tools/fastllm_pytools/llm.py                  | 149 ++++++++++++------
 .../openai_server/fastllm_completion.py       |  36 +----
 tools/fastllm_pytools/web_demo.py             |  13 +-
 3 files changed, 121 insertions(+), 77 deletions(-)

diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index 22dcac7b..f87cf272 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -415,6 +415,31 @@ def __init__ (self, path : str,
         # Since the number of tokens is limited and fairly small, caching this result to reduce calls works well.
         # The cache is not automatic, to avoid locking the cache dict in multi-threaded calls and to leave the choice open for different scenarios.
         self.tokenizer_decode_token_cache = None
+
+    def apply_chat_template(
+        self,
+        conversation: List[Dict[str, str]],
+        chat_template: Optional[str] = None,
+        add_generation_prompt: bool = False,
+        **kwargs,
+    ) -> str:
+        messages = []
+        for it in conversation:
+            if it["role"] == "system":
+                messages += ["system", it["content"]]
+        for it in conversation:
+            if it["role"] != "system":
+                messages += [it["role"], it["content"]]
+        poss = []
+        lens = []
+        all = b''
+        for i in range(len(messages)):
+            messages[i] = messages[i].encode()
+            all += messages[i]
+            poss.append(0 if i == 0 else poss[-1] + lens[-1])
+            lens.append(len(messages[i]))
+        str = fastllm_lib.apply_chat_template(self.model, all, len(messages), (ctypes.c_int * len(poss))(*poss), (ctypes.c_int * len(lens))(*lens)).decode()
+        return str
 
     def generate(
         self,
@@ -611,32 +636,57 @@ def response(self,
         return ret;
 
     def stream_response(self,
-                        query: str,
+                        query: Union[str, List[Dict[str, str]]],
                         history: List[Tuple[str, str]] = None,
                         max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
-                        one_by_one = True, stop_token_ids: List[int] = None):
+                        one_by_one = True, stop_token_ids: List[int] = None, add_generation_prompt = True):
+        conversation = None
+        if (isinstance(query, List)):
+            conversation = query
         if (self.hf_tokenizer != None and hasattr(self.hf_tokenizer, "chat_template") and self.hf_tokenizer.chat_template != ""):
-            lastlen = 0
-            for cur in self.stream_chat(tokenizer = self.hf_tokenizer,
-                                        query = query,
-                                        history = history,
-                                        max_length = max_length,
-                                        do_sample = do_sample,
-                                        top_p = top_p, top_k = top_k,
-                                        temperature = temperature,
-                                        repeat_penalty = repeat_penalty,
-                                        stop_token_ids = stop_token_ids):
-                if one_by_one:
-                    ret = cur[0][lastlen:]
-                    if (ret.encode().find(b'\xef\xbf\xbd') == -1):
-                        lastlen = len(cur[0])
-                        yield ret
-                    else:
-                        yield ""
+            tokenizer = self.hf_tokenizer
+            type = None
+            if (hasattr(tokenizer, "name")
+                and tokenizer.name == "GLMTokenizer"
+                and hasattr(tokenizer, "build_chat_input")):
+                type = "ChatGLM3"
+            if (not(history)):
+                history = [];
+            if (type == "ChatGLM3"):
+                input = tokenizer.build_chat_input(query, history=history)["input_ids"].reshape(-1).tolist()
+            else:
+                prompt = ""
+                if (conversation != None and len(conversation) != 0):
+                    prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt = add_generation_prompt, tokenize = False)
+                else:
+                    prompt = query if self.direct_query else self.get_prompt(query, history)
+                input = tokenizer.encode(prompt)
+            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
+            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
+                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
+                                                           False, stop_token_len, stop_token_list)
+            tokens = [];
+            while True:
+                if not(fastllm_lib.can_fetch_response_llm_model(self.model, handle)):
+                    continue
+                cur = fastllm_lib.fetch_response_llm_model(self.model, handle)
+                if (cur == -1):
+                    break
+                tokens.append(cur)
+                ret = tokenizer.decode(tokens)
+                if (ret.encode().find(b'\xef\xbf\xbd') == -1):
+                    tokens.clear()
+                    yield ret
                 else:
-                    yield cur[0]
+                    yield ""
+            if len(tokens) > 0:
+                yield tokenizer.decode(tokens)
         else:
-            prompt = query if self.direct_query else self.get_prompt(query, history);
+            prompt = ""
+            if (conversation != None and len(conversation) != 0):
+                prompt = self.apply_chat_template(conversation)
+            else:
+                prompt = query if self.direct_query else self.get_prompt(query, history)
             stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids);
             handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(), ctypes.c_int(max_length), ctypes.c_bool(do_sample),
                                                                ctypes.c_float(top_p), ctypes.c_int(top_k),
@@ -677,10 +727,13 @@ def add_cache(self,
         exit(0)
 
     async def stream_response_async(self,
-                        query: str,
+                        query: Union[str, List[Dict[str, str]]],
                         history: List[Tuple[str, str]] = None,
                         max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
-                        one_by_one = True, stop_token_ids: List[int] = None):
+                        one_by_one = True, stop_token_ids: List[int] = None, add_generation_prompt = True):
+        conversation = None
+        if (isinstance(query, List)):
+            conversation = query
         if (self.hf_tokenizer != None and hasattr(self.hf_tokenizer, "chat_template") and self.hf_tokenizer.chat_template != ""):
             tokenizer = self.hf_tokenizer
             type = None
@@ -693,12 +746,16 @@ async def stream_response_async(self,
             if (type == "ChatGLM3"):
                 input = tokenizer.build_chat_input(query, history=history)["input_ids"].reshape(-1).tolist()
             else:
-                prompt = query if self.direct_query else self.get_prompt(query, history);
-                input = tokenizer.encode(prompt);
+                prompt = ""
+                if (conversation != None and len(conversation) != 0):
+                    prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt = add_generation_prompt, tokenize = False)
+                else:
+                    prompt = query if self.direct_query else self.get_prompt(query, history)
+                input = tokenizer.encode(prompt)
             stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
             handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                            max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
-                                                           False, stop_token_len, stop_token_list);
+                                                           False, stop_token_len, stop_token_list)
             tokens = [];
             while True:
                 if not(fastllm_lib.can_fetch_response_llm_model(self.model, handle)):
@@ -717,38 +774,42 @@ async def stream_response_async(self,
             if len(tokens) > 0:
                 yield tokenizer.decode(tokens)
         else:
-            prompt = query if self.direct_query else self.get_prompt(query, history);
-            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids);
+            prompt = ""
+            if (conversation != None and len(conversation) != 0):
+                prompt = self.apply_chat_template(conversation)
+            else:
+                prompt = query if self.direct_query else self.get_prompt(query, history)
+            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
             handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(), ctypes.c_int(max_length), ctypes.c_bool(do_sample),
                                                                ctypes.c_float(top_p), ctypes.c_int(top_k),
                                                                ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
-                                                               stop_token_len, stop_token_list);
-            res = "";
-            ret = b'';
-            fail_cnt = 0;
+                                                               stop_token_len, stop_token_list)
+            res = ""
+            ret = b''
+            fail_cnt = 0
             while True:
                 if not(fastllm_lib.can_fetch_response_llm_model(self.model, handle)):
                     await asyncio.sleep(0)
                     continue
-                ret += fastllm_lib.fetch_response_str_llm_model(self.model, handle);
-                cur = "";
+                ret += fastllm_lib.fetch_response_str_llm_model(self.model, handle)
+                cur = ""
                 try:
-                    cur = ret.decode();
-                    ret = b'';
+                    cur = ret.decode()
+                    ret = b''
                 except:
-                    fail_cnt += 1;
+                    fail_cnt += 1
                     if (fail_cnt == 20):
-                        break;
+                        break
                     else:
-                        continue;
-                fail_cnt = 0;
+                        continue
+                fail_cnt = 0
                 if (cur == ""):
-                    break;
+                    break
                 if one_by_one:
-                    yield cur;
+                    yield cur
                 else:
-                    res += cur;
-                    yield res;
+                    res += cur
+                    yield res
 
     def stream_response_raw(self,
                             input_tokens: List[int],
diff --git a/tools/fastllm_pytools/openai_server/fastllm_completion.py b/tools/fastllm_pytools/openai_server/fastllm_completion.py
index 67286541..743e9c23 100644
--- a/tools/fastllm_pytools/openai_server/fastllm_completion.py
+++ b/tools/fastllm_pytools/openai_server/fastllm_completion.py
@@ -104,34 +104,9 @@ async def create_chat_completion(
         # In the fastllm examples, history can only be alternating question/answer pairs; system prompt is not supported yet
         if len(conversation) == 0:
             raise Exception("Empty msg")
-
-            for i in range(len(conversation)):
-                msg = conversation[i]
-                if msg.role == "system":
-                    # fastllm does not support system prompt yet
-                    continue
-                elif msg.role == "user":
-                    if i + 1 < len(conversation):
-                        next_msg = conversation[i + 1]
-                        if next_msg.role == "assistant":
-                            history.append((msg.content, next_msg.content))
-                    else:
-                        # Roles must alternate: user, assistant, user, assistant
-                        raise Exception("fastllm requires that the prompt words must appear alternately in the roles of user and assistant.")
-                elif msg.role == "assistant":
-                    if i - 1 < 0:
-                        raise Exception("fastllm Not Support assistant prompt in first message")
-                    else:
-                        pre_msg = conversation[i - 1]
-                        if pre_msg.role != "user":
-                            raise Exception("In FastLLM, The message role before the assistant msg must be user")
-                else:
-                    raise NotImplementedError(f"prompt role {msg.role } not supported yet")
-
-            last_msg = conversation[-1]
-            if last_msg.role != "user":
-                raise Exception("last msg role must be user")
-            query = last_msg.content
+            messages = []
+            for msg in conversation:
+                messages.append({"role": msg.role, "content": msg.content})
 
         except Exception as e:
             logging.error("Error in applying chat template from request: %s", e)
@@ -147,11 +122,10 @@ async def create_chat_completion(
         max_length = request.max_tokens if request.max_tokens else 8192
         input_token_len = 0; # self.model.get_input_token_len(query, history)
         #logging.info(request)
-        logging.info(f"fastllm input: {query}")
-        logging.info(f"fastllm history: {history}")
+        logging.info(f"fastllm input message: {messages}")
         #logging.info(f"input tokens: {input_token_len}")
         # Results from stream_response do not include token usage statistics
-        result_generator = self.model.stream_response_async(query, history,
+        result_generator = self.model.stream_response_async(messages,
             max_length = max_length, do_sample = True,
             top_p = request.top_p, top_k = request.top_k, temperature = request.temperature,
             repeat_penalty = frequency_penalty, one_by_one = True)
diff --git a/tools/fastllm_pytools/web_demo.py b/tools/fastllm_pytools/web_demo.py
index be566168..bd7500c8 100644
--- a/tools/fastllm_pytools/web_demo.py
+++ b/tools/fastllm_pytools/web_demo.py
@@ -25,11 +25,13 @@ def parse_args():
 def get_model():
     args = parse_args()
     model = make_normal_llm_model(args)
+    model.set_verbose(True)
     return model
 
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+system_prompt = st.sidebar.text_input("system_prompt", "")
 max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 8192, 512, step = 1)
 top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step = 0.01)
 top_k = st.sidebar.slider("top_k", 1, 50, 1, step = 1)
@@ -55,8 +57,15 @@ def get_model():
     with st.chat_message("assistant"):
         message_placeholder = st.empty()
         full_response = ""
-        for chunk in model.stream_response(prompt,
-                                           st.session_state.messages,
+        messages = []
+        if system_prompt != "":
+            messages.append({"role": "system", "content": system_prompt})
+        for his in st.session_state.messages:
+            messages.append({"role": "user", "content": his[0]})
+            messages.append({"role": "assistant", "content": his[1]})
+        messages.append({"role": "user", "content": prompt})
+
+        for chunk in model.stream_response(messages,
                                            max_length = max_new_tokens,
                                            top_k = top_k,
                                            top_p = top_p,
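
Usage sketch (not part of the patch): the snippet below shows how the message-list form of stream_response introduced here can carry a system prompt, mirroring the web_demo.py change above. It assumes fastllm_pytools is importable and that llm.model() can load a converted model; the "model.flm" path is a placeholder.

from fastllm_pytools import llm

# Assumed: llm.model() loads a converted fastllm model; the path is hypothetical.
model = llm.model("model.flm")

# OpenAI-style message list; with this patch the system message is passed through
# the chat template instead of being rejected or dropped.
messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Hello!"},
]

# stream_response now accepts either a plain query string or a message list.
for chunk in model.stream_response(messages, max_length = 512, top_k = 1, top_p = 0.8):
    print(chunk, end = "", flush = True)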