From 9cad94e961c3cc4eb69a7348f3cc835b9d978560 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 11 Feb 2024 23:07:38 +0800 Subject: [PATCH 1/4] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=90=8C?= =?UTF-8?q?=E6=97=B6=E8=BF=90=E8=A1=8C=E5=A4=9A=E4=B8=AA=E5=B9=B3=E5=8F=B0?= =?UTF-8?q?=E9=80=82=E9=85=8D=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/audit/center/apigroup.py | 2 +- pkg/core/controller.py | 3 +- pkg/core/entities.py | 7 ++ pkg/core/pool.py | 7 +- pkg/pipeline/longtext/longtext.py | 2 +- pkg/pipeline/longtext/strategies/forward.py | 5 +- pkg/pipeline/longtext/strategies/image.py | 3 +- pkg/pipeline/longtext/strategy.py | 3 +- pkg/pipeline/respback/respback.py | 3 +- pkg/pipeline/resprule/resprule.py | 2 +- pkg/pipeline/resprule/rule.py | 5 +- pkg/pipeline/resprule/rules/atbot.py | 8 +- pkg/pipeline/resprule/rules/prefix.py | 4 +- pkg/pipeline/resprule/rules/random.py | 1 + pkg/pipeline/resprule/rules/regexp.py | 1 + pkg/platform/adapter.py | 4 +- pkg/platform/manager.py | 105 ++++++++++++-------- pkg/platform/sources/aiocqhttp.py | 6 +- pkg/platform/sources/nakuru.py | 7 +- pkg/platform/sources/qqbotpy.py | 6 +- pkg/platform/sources/yirimirai.py | 8 +- templates/platform.json | 34 +++++++ 22 files changed, 150 insertions(+), 76 deletions(-) diff --git a/pkg/audit/center/apigroup.py b/pkg/audit/center/apigroup.py index 9d35c05e..10b6d8dd 100644 --- a/pkg/audit/center/apigroup.py +++ b/pkg/audit/center/apigroup.py @@ -34,7 +34,7 @@ async def _do( headers: dict = {}, **kwargs ): - self._runtime_info['account_id'] = "{}".format(self.ap.im_mgr.bot_account_id) + self._runtime_info['account_id'] = "-1" url = self.prefix + path data = json.dumps(data) diff --git a/pkg/core/controller.py b/pkg/core/controller.py index 12403aff..42ef435c 100644 --- a/pkg/core/controller.py +++ b/pkg/core/controller.py @@ -70,7 +70,8 @@ async def _check_output(self, query: entities.Query, result: pipeline_entities.S if result.user_notice: await self.ap.im_mgr.send( query.message_event, - result.user_notice + result.user_notice, + query.adapter ) if result.debug_notice: self.ap.logger.debug(result.debug_notice) diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 53d515ed..dacb64e0 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -12,6 +12,7 @@ from ..provider.requester import entities from ..provider.sysprompt import entities as sysprompt_entities from ..provider.tools import entities as tools_entities +from ..platform import adapter as msadapter class LauncherTypes(enum.Enum): @@ -44,6 +45,9 @@ class Query(pydantic.BaseModel): message_chain: mirai.MessageChain """消息链,platform收到的消息链""" + adapter: msadapter.MessageSourceAdapter + """适配器对象""" + session: typing.Optional[Session] = None """会话对象,由前置处理器设置""" @@ -68,6 +72,9 @@ class Query(pydantic.BaseModel): resp_message_chain: typing.Optional[mirai.MessageChain] = None """回复消息链,从resp_messages包装而得""" + class Config: + arbitrary_types_allowed = True + class Conversation(pydantic.BaseModel): """对话""" diff --git a/pkg/core/pool.py b/pkg/core/pool.py index a5a26423..5c8000dd 100644 --- a/pkg/core/pool.py +++ b/pkg/core/pool.py @@ -5,6 +5,7 @@ import mirai from . import entities +from ..platform import adapter as msadapter class QueryPool: @@ -29,7 +30,8 @@ async def add_query( launcher_id: int, sender_id: int, message_event: mirai.MessageEvent, - message_chain: mirai.MessageChain + message_chain: mirai.MessageChain, + adapter: msadapter.MessageSourceAdapter ) -> entities.Query: async with self.condition: query = entities.Query( @@ -40,7 +42,8 @@ async def add_query( message_event=message_event, message_chain=message_chain, resp_messages=[], - resp_message_chain=None + resp_message_chain=None, + adapter=adapter ) self.queries.append(query) self.query_id_counter += 1 diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index 85de47c1..ab70732b 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -52,7 +52,7 @@ async def initialize(self): async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']: - query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain))) + query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query)) return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query diff --git a/pkg/pipeline/longtext/strategies/forward.py b/pkg/pipeline/longtext/strategies/forward.py index d1b5c36c..cfab49d9 100644 --- a/pkg/pipeline/longtext/strategies/forward.py +++ b/pkg/pipeline/longtext/strategies/forward.py @@ -7,6 +7,7 @@ from mirai.models.base import MiraiBaseModel from .. import strategy as strategy_model +from ....core import entities as core_entities class ForwardMessageDiaplay(MiraiBaseModel): @@ -37,7 +38,7 @@ def __str__(self): class ForwardComponentStrategy(strategy_model.LongTextStrategy): - async def process(self, message: str) -> list[MessageComponent]: + async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: display = ForwardMessageDiaplay( title="群聊的聊天记录", brief="[聊天记录]", @@ -48,7 +49,7 @@ async def process(self, message: str) -> list[MessageComponent]: node_list = [ ForwardMessageNode( - sender_id=self.ap.im_mgr.bot_account_id, + sender_id=query.adapter.bot_account_id, sender_name='QQ用户', message_chain=MessageChain([message]) ) diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py index 1d350f60..af34f4e6 100644 --- a/pkg/pipeline/longtext/strategies/image.py +++ b/pkg/pipeline/longtext/strategies/image.py @@ -12,6 +12,7 @@ from mirai.models.message import MessageComponent from .. import strategy as strategy_model +from ....core import entities as core_entities class Text2ImageStrategy(strategy_model.LongTextStrategy): @@ -21,7 +22,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): async def initialize(self): self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8") - async def process(self, message: str) -> list[MessageComponent]: + async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: img_path = self.text_to_image( text_str=message, save_as='temp/{}.png'.format(int(time.time())) diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py index 5c6bfb9c..a1f8a94f 100644 --- a/pkg/pipeline/longtext/strategy.py +++ b/pkg/pipeline/longtext/strategy.py @@ -6,6 +6,7 @@ from mirai.models.message import MessageComponent from ...core import app +from ...core import entities as core_entities class LongTextStrategy(metaclass=abc.ABCMeta): @@ -18,5 +19,5 @@ async def initialize(self): pass @abc.abstractmethod - async def process(self, message: str) -> list[MessageComponent]: + async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: return [] diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index 52b23ab6..10b2cbac 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -31,7 +31,8 @@ async def process(self, query: core_entities.Query, stage_inst_name: str) -> ent await self.ap.im_mgr.send( query.message_event, - query.resp_message_chain + query.resp_message_chain, + adapter=query.adapter ) return entities.StageProcessResult( diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py index 894bbce1..8f418729 100644 --- a/pkg/pipeline/resprule/resprule.py +++ b/pkg/pipeline/resprule/resprule.py @@ -47,7 +47,7 @@ async def process(self, query: core_entities.Query, stage_inst_name: str) -> ent use_rule = use_rule[str(query.launcher_id)] for rule_matcher in self.rule_matchers: # 任意一个匹配就放行 - res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule) + res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query) if res.matching: query.message_chain = res.replacement diff --git a/pkg/pipeline/resprule/rule.py b/pkg/pipeline/resprule/rule.py index e530d063..cde9ec3d 100644 --- a/pkg/pipeline/resprule/rule.py +++ b/pkg/pipeline/resprule/rule.py @@ -3,7 +3,7 @@ import mirai -from ...core import app +from ...core import app, entities as core_entities from . import entities @@ -24,7 +24,8 @@ async def match( self, message_text: str, message_chain: mirai.MessageChain, - rule_dict: dict + rule_dict: dict, + query: core_entities.Query ) -> entities.RuleJudgeResult: """判断消息是否匹配规则 """ diff --git a/pkg/pipeline/resprule/rules/atbot.py b/pkg/pipeline/resprule/rules/atbot.py index eefc4891..692bee72 100644 --- a/pkg/pipeline/resprule/rules/atbot.py +++ b/pkg/pipeline/resprule/rules/atbot.py @@ -4,6 +4,7 @@ from .. import rule as rule_model from .. import entities +from ....core import entities as core_entities class AtBotRule(rule_model.GroupRespondRule): @@ -12,11 +13,12 @@ async def match( self, message_text: str, message_chain: mirai.MessageChain, - rule_dict: dict + rule_dict: dict, + query: core_entities.Query ) -> entities.RuleJudgeResult: - if message_chain.has(mirai.At(self.ap.im_mgr.bot_account_id)) and rule_dict['at']: - message_chain.remove(mirai.At(self.ap.im_mgr.bot_account_id)) + if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']: + message_chain.remove(mirai.At(query.adapter.bot_account_id)) return entities.RuleJudgeResult( matching=True, replacement=message_chain, diff --git a/pkg/pipeline/resprule/rules/prefix.py b/pkg/pipeline/resprule/rules/prefix.py index b23f5f8f..1b61c138 100644 --- a/pkg/pipeline/resprule/rules/prefix.py +++ b/pkg/pipeline/resprule/rules/prefix.py @@ -2,6 +2,7 @@ from .. import rule as rule_model from .. import entities +from ....core import entities as core_entities class PrefixRule(rule_model.GroupRespondRule): @@ -10,7 +11,8 @@ async def match( self, message_text: str, message_chain: mirai.MessageChain, - rule_dict: dict + rule_dict: dict, + query: core_entities.Query ) -> entities.RuleJudgeResult: prefixes = rule_dict['prefix'] diff --git a/pkg/pipeline/resprule/rules/random.py b/pkg/pipeline/resprule/rules/random.py index 932d00be..2e845ab0 100644 --- a/pkg/pipeline/resprule/rules/random.py +++ b/pkg/pipeline/resprule/rules/random.py @@ -4,6 +4,7 @@ from .. import rule as rule_model from .. import entities +from ....core import entities as core_entities class RandomRespRule(rule_model.GroupRespondRule): diff --git a/pkg/pipeline/resprule/rules/regexp.py b/pkg/pipeline/resprule/rules/regexp.py index 0d621fe4..18a3ce09 100644 --- a/pkg/pipeline/resprule/rules/regexp.py +++ b/pkg/pipeline/resprule/rules/regexp.py @@ -4,6 +4,7 @@ from .. import rule as rule_model from .. import entities +from ....core import entities as core_entities class RegExpRule(rule_model.GroupRespondRule): diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index 5ebec9eb..38c31fe2 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -71,7 +71,7 @@ async def is_muted(self, group_id: int) -> bool: def register_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None] ): """注册事件监听器 @@ -84,7 +84,7 @@ def register_listener( def unregister_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None] ): """注销事件监听器 diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index d94eb963..432d39ed 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -17,11 +17,8 @@ # 控制QQ消息输入输出的类 class PlatformManager: - adapter: msadapter.MessageSourceAdapter = None - - @property - def bot_account_id(self): - return self.adapter.bot_account_id + # adapter: msadapter.MessageSourceAdapter = None + adapters: list[msadapter.MessageSourceAdapter] = [] # modern ap: app.Application = None @@ -29,27 +26,13 @@ def bot_account_id(self): def __init__(self, ap: app.Application = None): self.ap = ap + self.adapters = [] async def initialize(self): from .sources import yirimirai, nakuru, aiocqhttp, qqbotpy - adapter_cls = None - - for adapter in msadapter.preregistered_adapters: - if adapter.name == self.ap.platform_cfg.data['platform-adapter']: - adapter_cls = adapter - break - if adapter_cls is None: - raise Exception('未知的平台适配器: ' + self.ap.platform_cfg.data['platform-adapter']) - - cfg_key = self.ap.platform_cfg.data['platform-adapter'] + '-config' - self.adapter = adapter_cls( - self.ap.platform_cfg.data[cfg_key], - self.ap - ) - - async def on_friend_message(event: FriendMessage): + async def on_friend_message(event: FriendMessage, adapter: msadapter.MessageSourceAdapter): event_ctx = await self.ap.plugin_mgr.emit_event( event=events.PersonMessageReceived( @@ -68,15 +51,11 @@ async def on_friend_message(event: FriendMessage): launcher_id=event.sender.id, sender_id=event.sender.id, message_event=event, - message_chain=event.message_chain + message_chain=event.message_chain, + adapter=adapter ) - self.adapter.register_listener( - FriendMessage, - on_friend_message - ) - - async def on_stranger_message(event: StrangerMessage): + async def on_stranger_message(event: StrangerMessage, adapter: msadapter.MessageSourceAdapter): event_ctx = await self.ap.plugin_mgr.emit_event( event=events.PersonMessageReceived( @@ -96,16 +75,17 @@ async def on_stranger_message(event: StrangerMessage): sender_id=event.sender.id, message_event=event, message_chain=event.message_chain, + adapter=adapter ) # nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件 - if self.ap.platform_cfg.data['platform-adapter'] == 'yiri-mirai': - self.adapter.register_listener( - StrangerMessage, - on_stranger_message - ) + # if self.ap.platform_cfg.data['platform-adapter'] == 'yiri-mirai': + # self.adapter.register_listener( + # StrangerMessage, + # on_stranger_message + # ) - async def on_group_message(event: GroupMessage): + async def on_group_message(event: GroupMessage, adapter: msadapter.MessageSourceAdapter): event_ctx = await self.ap.plugin_mgr.emit_event( event=events.GroupMessageReceived( @@ -124,15 +104,49 @@ async def on_group_message(event: GroupMessage): launcher_id=event.group.id, sender_id=event.sender.id, message_event=event, - message_chain=event.message_chain + message_chain=event.message_chain, + adapter=adapter ) - - self.adapter.register_listener( - GroupMessage, - on_group_message - ) - - async def send(self, event, msg, check_quote=True, check_at_sender=True): + + for adap_cfg in self.ap.platform_cfg.data['platform-adapters']: + if adap_cfg['enable']: + cfg_copy = adap_cfg.copy() + del cfg_copy['enable'] + adapter_name = cfg_copy['adapter'] + del cfg_copy['adapter'] + + found = False + + for adapter in msadapter.preregistered_adapters: + if adapter.name == adapter_name: + found = True + adapter_cls = adapter + + adapter_inst = adapter_cls( + cfg_copy, + self.ap + ) + self.adapters.append(adapter_inst) + + if adapter_name == 'yiri-mirai': + adapter_inst.register_listener( + StrangerMessage, + on_stranger_message + ) + + adapter_inst.register_listener( + FriendMessage, + on_friend_message + ) + adapter_inst.register_listener( + GroupMessage, + on_group_message + ) + + if not found: + raise Exception('platform.json 中启用了未知的平台适配器: ' + adapter_name) + + async def send(self, event, msg, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True): if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage): @@ -143,7 +157,7 @@ async def send(self, event, msg, check_quote=True, check_at_sender=True): ) ) - await self.adapter.reply_message( + await adapter.reply_message( event, msg, quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False @@ -170,7 +184,10 @@ async def send(self, event, msg, check_quote=True, check_at_sender=True): async def run(self): try: - await self.adapter.run_async() + tasks = [] + for adapter in self.adapters: + tasks.append(adapter.run_async()) + await asyncio.gather(*tasks) except Exception as e: self.ap.logger.error('平台适配器运行出错: ' + str(e)) self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index 0b5b2aff..a3f4240e 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -240,12 +240,12 @@ async def is_muted(self, group_id: int) -> bool: def register_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None], + callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None], ): async def on_message(event: aiocqhttp.Event): self.bot_account_id = event.self_id try: - return await callback(self.event_converter.target2yiri(event)) + return await callback(self.event_converter.target2yiri(event), self) except: traceback.print_exc() @@ -257,7 +257,7 @@ async def on_message(event: aiocqhttp.Event): def unregister_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None], + callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/nakuru.py index 202f3f6c..c9ab23e3 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -257,14 +257,13 @@ def is_muted(self, group_id: int) -> bool: def register_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] ): try: # 包装函数 async def listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.event_converter.yiri2target(event_type)): - print(1111) - await callback(self.event_converter.target2yiri(source)) + await callback(self.event_converter.target2yiri(source), self) # 将包装函数和原函数的对应关系存入列表 self.listener_list.append( @@ -284,7 +283,7 @@ async def listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.even def unregister_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] ): nakuru_event_name = self.event_converter.yiri2target(event_type).__name__ diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/qqbotpy.py index 8264a193..6d74d0ea 100644 --- a/pkg/platform/sources/qqbotpy.py +++ b/pkg/platform/sources/qqbotpy.py @@ -362,14 +362,14 @@ async def is_muted(self, group_id: int) -> bool: def register_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] ): try: async def wrapper(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage]): self.cached_official_messages[str(message.id)] = message - await callback(OfficialEventConverter.target2yiri(message)) + await callback(OfficialEventConverter.target2yiri(message), self) for event_handler in event_handler_mapping[event_type]: setattr(self.bot, event_handler, wrapper) @@ -380,7 +380,7 @@ async def wrapper(message: typing.Union[botpy_message.Message, botpy_message.Dir def unregister_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] ): delattr(self.bot, event_handler_mapping[event_type]) diff --git a/pkg/platform/sources/yirimirai.py b/pkg/platform/sources/yirimirai.py index c43e13fb..7768dcf0 100644 --- a/pkg/platform/sources/yirimirai.py +++ b/pkg/platform/sources/yirimirai.py @@ -87,7 +87,7 @@ async def is_muted(self, group_id: int) -> bool: def register_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] ): """注册事件监听器 @@ -95,12 +95,14 @@ def register_listener( event_type (typing.Type[mirai.Event]): YiriMirai事件类型 callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件 """ - self.bot.on(event_type)(callback) + async def wrapper(event: mirai.Event): + await callback(event, self) + self.bot.on(event_type)(wrapper) def unregister_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] ): """注销事件监听器 diff --git a/templates/platform.json b/templates/platform.json index bef28523..e0d4e0b0 100644 --- a/templates/platform.json +++ b/templates/platform.json @@ -25,6 +25,40 @@ "direct_message" ] }, + "platform-adapters": [ + { + "adapter": "yiri-mirai", + "enable": false, + "host": "127.0.0.1", + "port": 8080, + "verifyKey": "yirimirai", + "qq": 123456789 + }, + { + "adapter": "nakuru", + "enable": false, + "host": "127.0.0.1", + "ws_port": 8080, + "http_port": 5700, + "token": "" + }, + { + "adapter": "aiocqhttp", + "enable": false, + "host": "127.0.0.1", + "port": 8080 + }, + { + "adapter": "qq-botpy", + "enable": false, + "appid": "", + "secret": "", + "intents": [ + "public_guild_messages", + "direct_message" + ] + } + ], "track-function-calls": true, "quote-origin": false, "at-sender": false, From 836df87e18e144c55e0f25b561b3417491e670d6 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 11 Feb 2024 23:11:13 +0800 Subject: [PATCH 2/4] =?UTF-8?q?feat:=20=E5=88=A0=E9=99=A4=E8=BF=87?= =?UTF-8?q?=E6=97=B6=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/boot.py | 5 ++++- templates/platform.json | 26 -------------------------- 2 files changed, 4 insertions(+), 27 deletions(-) diff --git a/pkg/core/boot.py b/pkg/core/boot.py index ef5fe2ec..426be200 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -88,7 +88,10 @@ async def make_app() -> app.Application: }, runtime_info={ "admin_id": "{}".format(system_cfg.data["admin-sessions"]), - "msg_source": platform_cfg.data["platform-adapter"], + "msg_source": [ + adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown' + for adapter_cfg in platform_cfg.data['platform-adapters'] if adapter_cfg['enable'] + ], }, ) ap.ctr_mgr = center_v2_api diff --git a/templates/platform.json b/templates/platform.json index e0d4e0b0..6b4de843 100644 --- a/templates/platform.json +++ b/templates/platform.json @@ -1,30 +1,4 @@ { - "platform-adapter": "yiri-mirai", - "yiri-mirai-config": { - "adapter": "WebSocketAdapter", - "host": "127.0.0.1", - "port": 8080, - "verifyKey": "yirimirai", - "qq": 123456789 - }, - "nakuru-config": { - "host": "127.0.0.1", - "ws_port": 8080, - "http_port": 5700, - "token": "" - }, - "aiocqhttp-config": { - "host": "127.0.0.1", - "port": 8080 - }, - "qq-botpy-config": { - "appid": "", - "secret": "", - "intents": [ - "public_guild_messages", - "direct_message" - ] - }, "platform-adapters": [ { "adapter": "yiri-mirai", From abc19e78b8b9c4fe24aafd23dd3eae0146dda187 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 11 Feb 2024 23:35:05 +0800 Subject: [PATCH 3/4] =?UTF-8?q?feat:=20=E5=91=BD=E4=BB=A4=E8=A1=8C?= =?UTF-8?q?=E9=80=80=E5=87=BA=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/app.py | 14 +++++++++++++- pkg/pipeline/resprule/rules/random.py | 3 ++- pkg/pipeline/resprule/rules/regexp.py | 3 ++- pkg/platform/manager.py | 12 ++++-------- requirements.txt | 1 + 5 files changed, 22 insertions(+), 11 deletions(-) diff --git a/pkg/core/app.py b/pkg/core/app.py index 595b01a8..ee901ba5 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -4,6 +4,8 @@ import asyncio import traceback +import aioconsole + from ..platform import manager as im_mgr from ..provider.session import sessionmgr as llm_session_mgr from ..provider.requester import modelmgr as llm_model_mgr @@ -72,11 +74,21 @@ async def run(self): tasks = [] try: + + tasks = [ asyncio.create_task(self.im_mgr.run()), asyncio.create_task(self.ctrl.run()) ] - await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + async def interrupt(tasks): + await asyncio.sleep(1.5) + while await aioconsole.ainput("使用 exit 退出程序 > ") != 'exit': + pass + for task in tasks: + task.cancel() + + await interrupt(tasks) except asyncio.CancelledError: pass diff --git a/pkg/pipeline/resprule/rules/random.py b/pkg/pipeline/resprule/rules/random.py index 2e845ab0..185e03ec 100644 --- a/pkg/pipeline/resprule/rules/random.py +++ b/pkg/pipeline/resprule/rules/random.py @@ -13,7 +13,8 @@ async def match( self, message_text: str, message_chain: mirai.MessageChain, - rule_dict: dict + rule_dict: dict, + query: core_entities.Query ) -> entities.RuleJudgeResult: random_rate = rule_dict['random'] diff --git a/pkg/pipeline/resprule/rules/regexp.py b/pkg/pipeline/resprule/rules/regexp.py index 18a3ce09..4e39d432 100644 --- a/pkg/pipeline/resprule/rules/regexp.py +++ b/pkg/pipeline/resprule/rules/regexp.py @@ -13,7 +13,8 @@ async def match( self, message_text: str, message_chain: mirai.MessageChain, - rule_dict: dict + rule_dict: dict, + query: core_entities.Query ) -> entities.RuleJudgeResult: regexps = rule_dict['regexp'] diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 432d39ed..f4f423db 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -78,13 +78,6 @@ async def on_stranger_message(event: StrangerMessage, adapter: msadapter.Message adapter=adapter ) - # nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件 - # if self.ap.platform_cfg.data['platform-adapter'] == 'yiri-mirai': - # self.adapter.register_listener( - # StrangerMessage, - # on_stranger_message - # ) - async def on_group_message(event: GroupMessage, adapter: msadapter.MessageSourceAdapter): event_ctx = await self.ap.plugin_mgr.emit_event( @@ -187,7 +180,10 @@ async def run(self): tasks = [] for adapter in self.adapters: tasks.append(adapter.run_async()) - await asyncio.gather(*tasks) + + for task in tasks: + asyncio.create_task(task) + except Exception as e: self.ap.logger.error('平台适配器运行出错: ' + str(e)) self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") diff --git a/requirements.txt b/requirements.txt index de78dcec..649ed9b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ tiktoken PyYaml aiohttp pydantic +aioconsole \ No newline at end of file From 991a0aa5f66e35b7a1c9db6d414f247e5d6493f4 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Mon, 12 Feb 2024 13:37:41 +0800 Subject: [PATCH 4/4] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dnakuru=E6=97=A0?= =?UTF-8?q?=E6=B3=95=E8=BF=90=E8=A1=8C=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/controller.py | 2 +- pkg/platform/manager.py | 14 +++++++++++++- pkg/platform/sources/nakuru.py | 16 +++++++++------- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/pkg/core/controller.py b/pkg/core/controller.py index 42ef435c..7173939b 100644 --- a/pkg/core/controller.py +++ b/pkg/core/controller.py @@ -151,7 +151,7 @@ async def process_query(self, query: entities.Query): except Exception as e: self.ap.logger.error(f"处理请求时出错 query_id={query.query_id}: {e}") self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") - # traceback.print_exc() + traceback.print_exc() finally: self.ap.logger.debug(f"Query {query} processed") diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index f4f423db..139e6356 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -101,8 +101,12 @@ async def on_group_message(event: GroupMessage, adapter: msadapter.MessageSource adapter=adapter ) + index = 0 + for adap_cfg in self.ap.platform_cfg.data['platform-adapters']: if adap_cfg['enable']: + self.ap.logger.info(f'初始化平台适配器 {index}: {adap_cfg["adapter"]}') + index += 1 cfg_copy = adap_cfg.copy() del cfg_copy['enable'] adapter_name = cfg_copy['adapter'] @@ -179,7 +183,14 @@ async def run(self): try: tasks = [] for adapter in self.adapters: - tasks.append(adapter.run_async()) + async def exception_wrapper(adapter): + try: + await adapter.run_async() + except Exception as e: + self.ap.logger.error('平台适配器运行出错: ' + str(e)) + self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + + tasks.append(exception_wrapper(adapter)) for task in tasks: asyncio.create_task(task) @@ -187,3 +198,4 @@ async def run(self): except Exception as e: self.ap.logger.error('平台适配器运行出错: ' + str(e)) self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/nakuru.py index c9ab23e3..0a419a06 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -1,4 +1,5 @@ -from __future__ import annotations +# 加了之后会导致:https://github.com/Lxns-Network/nakuru-project/issues/25 +# from __future__ import annotations import asyncio import typing @@ -12,7 +13,6 @@ from .. import adapter as adapter_model from ...pipeline.longtext.strategies import forward -from ...core import app class NakuruProjectMessageConverter(adapter_model.MessageConverter): @@ -170,11 +170,11 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter): listener_list: list[dict] - ap: app.Application + # ap: app.Application cfg: dict - def __init__(self, cfg: dict, ap: app.Application): + def __init__(self, cfg: dict, ap): """初始化nakuru-project的对象""" cfg['port'] = cfg['ws_port'] del cfg['ws_port'] @@ -261,8 +261,10 @@ def register_listener( ): try: + source_cls = NakuruProjectEventConverter.yiri2target(event_type) + # 包装函数 - async def listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.event_converter.yiri2target(event_type)): + async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): await callback(self.event_converter.target2yiri(source), self) # 将包装函数和原函数的对应关系存入列表 @@ -275,7 +277,7 @@ async def listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.even ) # 注册监听器 - self.bot.receiver(self.event_converter.yiri2target(event_type).__name__)(listener_wrapper) + self.bot.receiver(source_cls.__name__)(listener_wrapper) except Exception as e: traceback.print_exc() raise e @@ -325,7 +327,7 @@ async def run_async(self): await self.bot._run() self.ap.logger.info("运行 Nakuru 适配器") while True: - await asyncio.sleep(100) + await asyncio.sleep(1) def kill(self) -> bool: return False \ No newline at end of file