Skip to content

Commit

Permalink
feat: message.content 支持 mirai.MessageChain 对象 (#741)
Browse files Browse the repository at this point in the history
  • Loading branch information
RockChinQ committed Mar 31, 2024
1 parent 2e9229a commit 8b00373
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ bard.json
res/instance_id.json
.DS_Store
/data
botpy.log
botpy.log*
16 changes: 12 additions & 4 deletions pkg/pipeline/cntfilter/cntfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,17 @@ async def process(
query
)
elif stage_inst_name == 'PostContentFilterStage':
return await self._post_process(
query.resp_messages[-1].content,
query
)
# 仅处理 query.resp_messages[-1].content 是 str 的情况
if isinstance(query.resp_messages[-1].content, str):
return await self._post_process(
query.resp_messages[-1].content,
query
)
else:
self.ap.logger.debug(f"resp_messages[-1] 不是 str 类型,跳过内容过滤器检查。")
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')
13 changes: 12 additions & 1 deletion pkg/pipeline/longtext/longtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,19 @@ async def initialize(self):
await self.strategy_impl.initialize()

async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
# 检查是否包含非 Plain 组件
contains_non_plain = False

for msg in query.resp_message_chain:
if not isinstance(msg, Plain):
contains_non_plain = True
break

if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))

return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
Expand Down
2 changes: 1 addition & 1 deletion pkg/pipeline/process/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def handle(
query.resp_messages.append(
llm_entities.Message(
role='plugin',
content=str(mc),
content=mc,
)
)

Expand Down
5 changes: 4 additions & 1 deletion pkg/pipeline/wrapper/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ async def process(
new_query=query
)
elif query.resp_messages[-1].role == 'plugin':
query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
else:
query.resp_message_chain = query.resp_messages[-1].content

yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
Expand Down
6 changes: 4 additions & 2 deletions pkg/provider/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import enum
import pydantic

import mirai


class FunctionCall(pydantic.BaseModel):
name: str
Expand All @@ -28,7 +30,7 @@ class Message(pydantic.BaseModel):
name: typing.Optional[str] = None
"""名称,仅函数调用返回时设置"""

content: typing.Optional[str] = None
content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None
"""内容"""

function_call: typing.Optional[FunctionCall] = None
Expand All @@ -41,7 +43,7 @@ class Message(pydantic.BaseModel):

def readable_str(self) -> str:
if self.content is not None:
return self.content
return str(self.content)
elif self.function_call is not None:
return f'{self.function_call.name}({self.function_call.arguments})'
elif self.tool_calls is not None:
Expand Down

0 comments on commit 8b00373

Please sign in to comment.