From 80858672b0042404c56a91c1fcf2ff2ced0b730e Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Tue, 20 Feb 2024 22:56:42 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E6=8E=A7=E5=88=B6=E5=8F=B0=E8=BE=93?= =?UTF-8?q?=E5=87=BA=E8=AF=B7=E6=B1=82=E5=93=8D=E5=BA=94=E8=BF=87=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/pipeline/process/handler.py | 9 +++++++++ pkg/pipeline/process/handlers/chat.py | 5 +++++ pkg/pipeline/process/handlers/command.py | 4 ++++ pkg/pipeline/process/process.py | 13 +++++++++---- pkg/provider/entities.py | 10 ++++++++++ 5 files changed, 37 insertions(+), 4 deletions(-) diff --git a/pkg/pipeline/process/handler.py b/pkg/pipeline/process/handler.py index 6d19e039..879b4cfe 100644 --- a/pkg/pipeline/process/handler.py +++ b/pkg/pipeline/process/handler.py @@ -23,3 +23,12 @@ async def handle( query: core_entities.Query, ) -> entities.StageProcessResult: raise NotImplementedError + + def cut_str(self, s: str) -> str: + """ + 取字符串第一行,最多20个字符,若有多行,或超过20个字符,则加省略号 + """ + s0 = s.split('\n')[0] + if len(s0) > 20 or '\n' in s: + s0 = s0[:20] + '...' + return s0 diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 26c99a2e..b3e8fa18 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -78,6 +78,8 @@ async def handle( async for result in query.use_model.requester.request(query): query.resp_messages.append(result) + self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') + if result.content is not None: text_length += len(result.content) @@ -86,6 +88,9 @@ async def handle( new_query=query ) except Exception as e: + + self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}') + yield entities.StageProcessResult( result_type=entities.ResultType.INTERRUPT, new_query=query, diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index d5873e38..7a669c50 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -91,6 +91,8 @@ async def handle( ) ) + self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}') + yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query @@ -106,6 +108,8 @@ async def handle( ) ) + self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}') + yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py index c24fdac2..6dbb7009 100644 --- a/pkg/pipeline/process/process.py +++ b/pkg/pipeline/process/process.py @@ -34,7 +34,12 @@ async def process( self.ap.logger.info(f"处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}") - if message_text.startswith('!') or message_text.startswith('!'): - return self.cmd_handler.handle(query) - else: - return self.chat_handler.handle(query) + async def generator(): + if message_text.startswith('!') or message_text.startswith('!'): + async for result in self.cmd_handler.handle(query): + yield result + else: + async for result in self.chat_handler.handle(query): + yield result + + return generator() diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 44866e2e..2db29d16 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -31,3 +31,13 @@ class Message(pydantic.BaseModel): tool_calls: typing.Optional[list[ToolCall]] = None tool_call_id: typing.Optional[str] = None + + def readable_str(self) -> str: + if self.content is not None: + return self.content + elif self.function_call is not None: + return f'{self.function_call.name}({self.function_call.arguments})' + elif self.tool_calls is not None: + return f'调用工具: {self.tool_calls[0].id}' + else: + return '未知消息'