Skip to content

Commit

Permalink
Merge pull request #685 from RockChinQ/feat/run-multi-adapter
Browse files Browse the repository at this point in the history
Feat: 支持同时运行多个适配器
  • Loading branch information
RockChinQ authored Feb 12, 2024
2 parents f951625 + 991a0aa commit 8af1741
Show file tree
Hide file tree
Showing 25 changed files with 192 additions and 116 deletions.
2 changes: 1 addition & 1 deletion pkg/audit/center/apigroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def _do(
headers: dict = {},
**kwargs
):
self._runtime_info['account_id'] = "{}".format(self.ap.im_mgr.bot_account_id)
self._runtime_info['account_id'] = "-1"

url = self.prefix + path
data = json.dumps(data)
Expand Down
14 changes: 13 additions & 1 deletion pkg/core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import asyncio
import traceback

import aioconsole

from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.requester import modelmgr as llm_model_mgr
Expand Down Expand Up @@ -72,11 +74,21 @@ async def run(self):
tasks = []

try:


tasks = [
asyncio.create_task(self.im_mgr.run()),
asyncio.create_task(self.ctrl.run())
]
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

async def interrupt(tasks):
await asyncio.sleep(1.5)
while await aioconsole.ainput("使用 exit 退出程序 > ") != 'exit':
pass
for task in tasks:
task.cancel()

await interrupt(tasks)

except asyncio.CancelledError:
pass
Expand Down
5 changes: 4 additions & 1 deletion pkg/core/boot.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ async def make_app() -> app.Application:
},
runtime_info={
"admin_id": "{}".format(system_cfg.data["admin-sessions"]),
"msg_source": platform_cfg.data["platform-adapter"],
"msg_source": [
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
for adapter_cfg in platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
],
},
)
ap.ctr_mgr = center_v2_api
Expand Down
5 changes: 3 additions & 2 deletions pkg/core/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ async def _check_output(self, query: entities.Query, result: pipeline_entities.S
if result.user_notice:
await self.ap.im_mgr.send(
query.message_event,
result.user_notice
result.user_notice,
query.adapter
)
if result.debug_notice:
self.ap.logger.debug(result.debug_notice)
Expand Down Expand Up @@ -150,7 +151,7 @@ async def process_query(self, query: entities.Query):
except Exception as e:
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id}: {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
# traceback.print_exc()
traceback.print_exc()
finally:
self.ap.logger.debug(f"Query {query} processed")

Expand Down
7 changes: 7 additions & 0 deletions pkg/core/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..provider.requester import entities
from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter


class LauncherTypes(enum.Enum):
Expand Down Expand Up @@ -44,6 +45,9 @@ class Query(pydantic.BaseModel):
message_chain: mirai.MessageChain
"""消息链,platform收到的消息链"""

adapter: msadapter.MessageSourceAdapter
"""适配器对象"""

session: typing.Optional[Session] = None
"""会话对象,由前置处理器设置"""

Expand All @@ -68,6 +72,9 @@ class Query(pydantic.BaseModel):
resp_message_chain: typing.Optional[mirai.MessageChain] = None
"""回复消息链,从resp_messages包装而得"""

class Config:
arbitrary_types_allowed = True


class Conversation(pydantic.BaseModel):
"""对话"""
Expand Down
7 changes: 5 additions & 2 deletions pkg/core/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import mirai

from . import entities
from ..platform import adapter as msadapter


class QueryPool:
Expand All @@ -29,7 +30,8 @@ async def add_query(
launcher_id: int,
sender_id: int,
message_event: mirai.MessageEvent,
message_chain: mirai.MessageChain
message_chain: mirai.MessageChain,
adapter: msadapter.MessageSourceAdapter
) -> entities.Query:
async with self.condition:
query = entities.Query(
Expand All @@ -40,7 +42,8 @@ async def add_query(
message_event=message_event,
message_chain=message_chain,
resp_messages=[],
resp_message_chain=None
resp_message_chain=None,
adapter=adapter
)
self.queries.append(query)
self.query_id_counter += 1
Expand Down
2 changes: 1 addition & 1 deletion pkg/pipeline/longtext/longtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def initialize(self):

async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain)))
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
Expand Down
5 changes: 3 additions & 2 deletions pkg/pipeline/longtext/strategies/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mirai.models.base import MiraiBaseModel

from .. import strategy as strategy_model
from ....core import entities as core_entities


class ForwardMessageDiaplay(MiraiBaseModel):
Expand Down Expand Up @@ -37,7 +38,7 @@ def __str__(self):

class ForwardComponentStrategy(strategy_model.LongTextStrategy):

async def process(self, message: str) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
display = ForwardMessageDiaplay(
title="群聊的聊天记录",
brief="[聊天记录]",
Expand All @@ -48,7 +49,7 @@ async def process(self, message: str) -> list[MessageComponent]:

node_list = [
ForwardMessageNode(
sender_id=self.ap.im_mgr.bot_account_id,
sender_id=query.adapter.bot_account_id,
sender_name='QQ用户',
message_chain=MessageChain([message])
)
Expand Down
3 changes: 2 additions & 1 deletion pkg/pipeline/longtext/strategies/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mirai.models.message import MessageComponent

from .. import strategy as strategy_model
from ....core import entities as core_entities


class Text2ImageStrategy(strategy_model.LongTextStrategy):
Expand All @@ -21,7 +22,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")

async def process(self, message: str) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time()))
Expand Down
3 changes: 2 additions & 1 deletion pkg/pipeline/longtext/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mirai.models.message import MessageComponent

from ...core import app
from ...core import entities as core_entities


class LongTextStrategy(metaclass=abc.ABCMeta):
Expand All @@ -18,5 +19,5 @@ async def initialize(self):
pass

@abc.abstractmethod
async def process(self, message: str) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
return []
3 changes: 2 additions & 1 deletion pkg/pipeline/respback/respback.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ async def process(self, query: core_entities.Query, stage_inst_name: str) -> ent

await self.ap.im_mgr.send(
query.message_event,
query.resp_message_chain
query.resp_message_chain,
adapter=query.adapter
)

return entities.StageProcessResult(
Expand Down
2 changes: 1 addition & 1 deletion pkg/pipeline/resprule/resprule.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def process(self, query: core_entities.Query, stage_inst_name: str) -> ent
use_rule = use_rule[str(query.launcher_id)]

for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule)
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
if res.matching:
query.message_chain = res.replacement

Expand Down
5 changes: 3 additions & 2 deletions pkg/pipeline/resprule/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import mirai

from ...core import app
from ...core import app, entities as core_entities
from . import entities


Expand All @@ -24,7 +24,8 @@ async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
"""判断消息是否匹配规则
"""
Expand Down
8 changes: 5 additions & 3 deletions pkg/pipeline/resprule/rules/atbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities


class AtBotRule(rule_model.GroupRespondRule):
Expand All @@ -12,11 +13,12 @@ async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:

if message_chain.has(mirai.At(self.ap.im_mgr.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(self.ap.im_mgr.bot_account_id))
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(query.adapter.bot_account_id))
return entities.RuleJudgeResult(
matching=True,
replacement=message_chain,
Expand Down
4 changes: 3 additions & 1 deletion pkg/pipeline/resprule/rules/prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities


class PrefixRule(rule_model.GroupRespondRule):
Expand All @@ -10,7 +11,8 @@ async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
prefixes = rule_dict['prefix']

Expand Down
4 changes: 3 additions & 1 deletion pkg/pipeline/resprule/rules/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities


class RandomRespRule(rule_model.GroupRespondRule):
Expand All @@ -12,7 +13,8 @@ async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
random_rate = rule_dict['random']

Expand Down
4 changes: 3 additions & 1 deletion pkg/pipeline/resprule/rules/regexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities


class RegExpRule(rule_model.GroupRespondRule):
Expand All @@ -12,7 +13,8 @@ async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
regexps = rule_dict['regexp']

Expand Down
4 changes: 2 additions & 2 deletions pkg/platform/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async def is_muted(self, group_id: int) -> bool:
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None]
):
"""注册事件监听器
Expand All @@ -84,7 +84,7 @@ def register_listener(
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None]
):
"""注销事件监听器
Expand Down
Loading

0 comments on commit 8af1741

Please sign in to comment.