Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: 支持同时运行多个适配器 #685

Merged
merged 4 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/audit/center/apigroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def _do(
headers: dict = {},
**kwargs
):
self._runtime_info['account_id'] = "{}".format(self.ap.im_mgr.bot_account_id)
self._runtime_info['account_id'] = "-1"

url = self.prefix + path
data = json.dumps(data)
Expand Down
14 changes: 13 additions & 1 deletion pkg/core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import asyncio
import traceback

import aioconsole

from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.requester import modelmgr as llm_model_mgr
Expand Down Expand Up @@ -72,11 +74,21 @@ async def run(self):
tasks = []

try:


tasks = [
asyncio.create_task(self.im_mgr.run()),
asyncio.create_task(self.ctrl.run())
]
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

async def interrupt(tasks):
await asyncio.sleep(1.5)
while await aioconsole.ainput("使用 exit 退出程序 > ") != 'exit':
pass
for task in tasks:
task.cancel()

await interrupt(tasks)

except asyncio.CancelledError:
pass
Expand Down
5 changes: 4 additions & 1 deletion pkg/core/boot.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ async def make_app() -> app.Application:
},
runtime_info={
"admin_id": "{}".format(system_cfg.data["admin-sessions"]),
"msg_source": platform_cfg.data["platform-adapter"],
"msg_source": [
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
for adapter_cfg in platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
],
},
)
ap.ctr_mgr = center_v2_api
Expand Down
5 changes: 3 additions & 2 deletions pkg/core/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ async def _check_output(self, query: entities.Query, result: pipeline_entities.S
if result.user_notice:
await self.ap.im_mgr.send(
query.message_event,
result.user_notice
result.user_notice,
query.adapter
)
if result.debug_notice:
self.ap.logger.debug(result.debug_notice)
Expand Down Expand Up @@ -150,7 +151,7 @@ async def process_query(self, query: entities.Query):
except Exception as e:
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id}: {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
# traceback.print_exc()
traceback.print_exc()
finally:
self.ap.logger.debug(f"Query {query} processed")

Expand Down
7 changes: 7 additions & 0 deletions pkg/core/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..provider.requester import entities
from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter


class LauncherTypes(enum.Enum):
Expand Down Expand Up @@ -44,6 +45,9 @@ class Query(pydantic.BaseModel):
message_chain: mirai.MessageChain
"""消息链,platform收到的消息链"""

adapter: msadapter.MessageSourceAdapter
"""适配器对象"""

session: typing.Optional[Session] = None
"""会话对象,由前置处理器设置"""

Expand All @@ -68,6 +72,9 @@ class Query(pydantic.BaseModel):
resp_message_chain: typing.Optional[mirai.MessageChain] = None
"""回复消息链,从resp_messages包装而得"""

class Config:
arbitrary_types_allowed = True


class Conversation(pydantic.BaseModel):
"""对话"""
Expand Down
7 changes: 5 additions & 2 deletions pkg/core/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import mirai

from . import entities
from ..platform import adapter as msadapter


class QueryPool:
Expand All @@ -29,7 +30,8 @@ async def add_query(
launcher_id: int,
sender_id: int,
message_event: mirai.MessageEvent,
message_chain: mirai.MessageChain
message_chain: mirai.MessageChain,
adapter: msadapter.MessageSourceAdapter
) -> entities.Query:
async with self.condition:
query = entities.Query(
Expand All @@ -40,7 +42,8 @@ async def add_query(
message_event=message_event,
message_chain=message_chain,
resp_messages=[],
resp_message_chain=None
resp_message_chain=None,
adapter=adapter
)
self.queries.append(query)
self.query_id_counter += 1
Expand Down
2 changes: 1 addition & 1 deletion pkg/pipeline/longtext/longtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def initialize(self):

async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain)))
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
Expand Down
5 changes: 3 additions & 2 deletions pkg/pipeline/longtext/strategies/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mirai.models.base import MiraiBaseModel

from .. import strategy as strategy_model
from ....core import entities as core_entities


class ForwardMessageDiaplay(MiraiBaseModel):
Expand Down Expand Up @@ -37,7 +38,7 @@ def __str__(self):

class ForwardComponentStrategy(strategy_model.LongTextStrategy):

async def process(self, message: str) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
display = ForwardMessageDiaplay(
title="群聊的聊天记录",
brief="[聊天记录]",
Expand All @@ -48,7 +49,7 @@ async def process(self, message: str) -> list[MessageComponent]:

node_list = [
ForwardMessageNode(
sender_id=self.ap.im_mgr.bot_account_id,
sender_id=query.adapter.bot_account_id,
sender_name='QQ用户',
message_chain=MessageChain([message])
)
Expand Down
3 changes: 2 additions & 1 deletion pkg/pipeline/longtext/strategies/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mirai.models.message import MessageComponent

from .. import strategy as strategy_model
from ....core import entities as core_entities


class Text2ImageStrategy(strategy_model.LongTextStrategy):
Expand All @@ -21,7 +22,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")

async def process(self, message: str) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time()))
Expand Down
3 changes: 2 additions & 1 deletion pkg/pipeline/longtext/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mirai.models.message import MessageComponent

from ...core import app
from ...core import entities as core_entities


class LongTextStrategy(metaclass=abc.ABCMeta):
Expand All @@ -18,5 +19,5 @@ async def initialize(self):
pass

@abc.abstractmethod
async def process(self, message: str) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
return []
3 changes: 2 additions & 1 deletion pkg/pipeline/respback/respback.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ async def process(self, query: core_entities.Query, stage_inst_name: str) -> ent

await self.ap.im_mgr.send(
query.message_event,
query.resp_message_chain
query.resp_message_chain,
adapter=query.adapter
)

return entities.StageProcessResult(
Expand Down
2 changes: 1 addition & 1 deletion pkg/pipeline/resprule/resprule.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def process(self, query: core_entities.Query, stage_inst_name: str) -> ent
use_rule = use_rule[str(query.launcher_id)]

for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule)
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
if res.matching:
query.message_chain = res.replacement

Expand Down
5 changes: 3 additions & 2 deletions pkg/pipeline/resprule/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import mirai

from ...core import app
from ...core import app, entities as core_entities
from . import entities


Expand All @@ -24,7 +24,8 @@ async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
"""判断消息是否匹配规则
"""
Expand Down
8 changes: 5 additions & 3 deletions pkg/pipeline/resprule/rules/atbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities


class AtBotRule(rule_model.GroupRespondRule):
Expand All @@ -12,11 +13,12 @@ async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:

if message_chain.has(mirai.At(self.ap.im_mgr.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(self.ap.im_mgr.bot_account_id))
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(query.adapter.bot_account_id))
return entities.RuleJudgeResult(
matching=True,
replacement=message_chain,
Expand Down
4 changes: 3 additions & 1 deletion pkg/pipeline/resprule/rules/prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities


class PrefixRule(rule_model.GroupRespondRule):
Expand All @@ -10,7 +11,8 @@ async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
prefixes = rule_dict['prefix']

Expand Down
4 changes: 3 additions & 1 deletion pkg/pipeline/resprule/rules/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities


class RandomRespRule(rule_model.GroupRespondRule):
Expand All @@ -12,7 +13,8 @@ async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
random_rate = rule_dict['random']

Expand Down
4 changes: 3 additions & 1 deletion pkg/pipeline/resprule/rules/regexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities


class RegExpRule(rule_model.GroupRespondRule):
Expand All @@ -12,7 +13,8 @@ async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
regexps = rule_dict['regexp']

Expand Down
4 changes: 2 additions & 2 deletions pkg/platform/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async def is_muted(self, group_id: int) -> bool:
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None]
):
"""注册事件监听器

Expand All @@ -84,7 +84,7 @@ def register_listener(
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None]
):
"""注销事件监听器

Expand Down
Loading
Loading