From 8d084427d2b7c63fe62ff173d80ae7da8218a70e Mon Sep 17 00:00:00 2001
From: RockChinQ <1010553892@qq.com>
Date: Fri, 26 Jan 2024 15:51:49 +0800
Subject: [PATCH 01/10] refactor: base architecture for the request-handling
 control flow
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 pkg/boot/log.py | 54 -----
 pkg/{boot => core}/__init__.py | 0
 pkg/{boot => core}/app.py | 19 +-
 pkg/{boot => core}/boot.py | 22 +-
 .../misc.py => core/bootutils/__init__.py} | 0
 pkg/{boot => core/bootutils}/config.py | 4 +-
 pkg/{boot => core/bootutils}/deps.py | 0
 pkg/{boot => core/bootutils}/files.py | 0
 pkg/core/bootutils/log.py | 56 +++++
 pkg/core/bootutils/misc.py | 0
 pkg/core/controller.py | 84 ++++++++
 pkg/core/entities.py | 41 ++++
 pkg/core/pool.py | 52 +++++
 pkg/openai/manager.py | 2 +-
 pkg/pipeline/__init__.py | 0
 pkg/pipeline/bansess/__init__.py | 0
 pkg/pipeline/bansess/bansess.py | 76 +++++++
 pkg/pipeline/cntfilter/__init__.py | 0
 pkg/pipeline/cntfilter/cntfilter.py | 128 ++++++++++++
 pkg/pipeline/cntfilter/entities.py | 64 ++++++
 pkg/pipeline/cntfilter/filter.py | 34 +++
 pkg/pipeline/cntfilter/filters/__init__.py | 0
 .../cntfilter/filters/baiduexamine.py | 61 ++++++
 pkg/pipeline/cntfilter/filters/banwords.py | 44 ++++
 pkg/pipeline/cntfilter/filters/cntignore.py | 43 ++++
 pkg/pipeline/entities.py | 38 ++++
 pkg/pipeline/longtext/__init__.py | 0
 pkg/pipeline/longtext/longtext.py | 57 +++++
 pkg/pipeline/longtext/strategies/__init__.py | 0
 pkg/pipeline/longtext/strategies/forward.py | 62 ++++++
 pkg/pipeline/longtext/strategies/image.py | 197 ++++++++++++++++++
 pkg/pipeline/longtext/strategy.py | 22 ++
 pkg/pipeline/resprule/__init__.py | 0
 pkg/pipeline/resprule/entities.py | 9 +
 pkg/pipeline/resprule/resprule.py | 62 ++++++
 pkg/pipeline/resprule/rule.py | 31 +++
 pkg/pipeline/resprule/rules/__init__.py | 0
 pkg/pipeline/resprule/rules/atbot.py | 28 +++
 pkg/pipeline/resprule/rules/prefix.py | 29 +++
 pkg/pipeline/resprule/rules/random.py | 22 ++
 pkg/pipeline/resprule/rules/regexp.py | 31 +++
 pkg/pipeline/stage.py | 43 ++++
 pkg/pipeline/stagemgr.py | 47 +++++
 pkg/qqbot/bansess/bansess.py | 2 +-
 pkg/qqbot/cntfilter/cntfilter.py | 2 +-
 pkg/qqbot/cntfilter/filter.py | 2 +-
 pkg/qqbot/longtext/longtext.py | 2 +-
 pkg/qqbot/longtext/strategy.py | 2 +-
 pkg/qqbot/manager.py | 92 ++------
 pkg/qqbot/process.py | 2 +-
 pkg/qqbot/ratelim/algo.py | 2 +-
 pkg/qqbot/ratelim/ratelim.py | 2 +-
 pkg/qqbot/resprule/resprule.py | 2 +-
 pkg/qqbot/resprule/rule.py | 2 +-
 start.py | 2 +-
 55 files changed, 1430 insertions(+), 146 deletions(-)
 delete mode 100644 pkg/boot/log.py
 rename pkg/{boot => core}/__init__.py (100%)
 rename pkg/{boot => core}/app.py (65%)
 rename pkg/{boot => core}/boot.py (88%)
 rename pkg/{boot/misc.py => core/bootutils/__init__.py} (100%)
 rename pkg/{boot => core/bootutils}/config.py (85%)
 rename pkg/{boot => core/bootutils}/deps.py (100%)
 rename pkg/{boot => core/bootutils}/files.py (100%)
 create mode 100644 pkg/core/bootutils/log.py
 create mode 100644 pkg/core/bootutils/misc.py
 create mode 100644 pkg/core/controller.py
 create mode 100644 pkg/core/entities.py
 create mode 100644 pkg/core/pool.py
 create mode 100644 pkg/pipeline/__init__.py
 create mode 100644 pkg/pipeline/bansess/__init__.py
 create mode 100644 pkg/pipeline/bansess/bansess.py
 create mode 100644 pkg/pipeline/cntfilter/__init__.py
 create mode 100644 pkg/pipeline/cntfilter/cntfilter.py
 create mode 100644
pkg/pipeline/cntfilter/entities.py create mode 100644 pkg/pipeline/cntfilter/filter.py create mode 100644 pkg/pipeline/cntfilter/filters/__init__.py create mode 100644 pkg/pipeline/cntfilter/filters/baiduexamine.py create mode 100644 pkg/pipeline/cntfilter/filters/banwords.py create mode 100644 pkg/pipeline/cntfilter/filters/cntignore.py create mode 100644 pkg/pipeline/entities.py create mode 100644 pkg/pipeline/longtext/__init__.py create mode 100644 pkg/pipeline/longtext/longtext.py create mode 100644 pkg/pipeline/longtext/strategies/__init__.py create mode 100644 pkg/pipeline/longtext/strategies/forward.py create mode 100644 pkg/pipeline/longtext/strategies/image.py create mode 100644 pkg/pipeline/longtext/strategy.py create mode 100644 pkg/pipeline/resprule/__init__.py create mode 100644 pkg/pipeline/resprule/entities.py create mode 100644 pkg/pipeline/resprule/resprule.py create mode 100644 pkg/pipeline/resprule/rule.py create mode 100644 pkg/pipeline/resprule/rules/__init__.py create mode 100644 pkg/pipeline/resprule/rules/atbot.py create mode 100644 pkg/pipeline/resprule/rules/prefix.py create mode 100644 pkg/pipeline/resprule/rules/random.py create mode 100644 pkg/pipeline/resprule/rules/regexp.py create mode 100644 pkg/pipeline/stage.py create mode 100644 pkg/pipeline/stagemgr.py diff --git a/pkg/boot/log.py b/pkg/boot/log.py deleted file mode 100644 index e0a15daa..00000000 --- a/pkg/boot/log.py +++ /dev/null @@ -1,54 +0,0 @@ -import logging -import os -import sys -import time - -import colorlog - - -log_colors_config = { - 'DEBUG': 'green', # cyan white - 'INFO': 'white', - 'WARNING': 'yellow', - 'ERROR': 'red', - 'CRITICAL': 'cyan', -} - - -async def init_logging() -> logging.Logger: - - level = logging.INFO - - if 'DEBUG' in os.environ and os.environ['DEBUG'] in ['true', '1']: - level = logging.DEBUG - - log_file_name = "logs/qcg-%s.log" % time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) - - qcg_logger = logging.getLogger("qcg") - - qcg_logger.setLevel(level) - - log_handlers: logging.Handler = [ - logging.StreamHandler(sys.stdout), - logging.FileHandler(log_file_name) - ] - - for handler in log_handlers: - handler.setLevel(level) - handler.setFormatter( - colorlog.ColoredFormatter( - fmt="[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - log_colors=log_colors_config - ) - ) - qcg_logger.addHandler(handler) - - logging.basicConfig(level=level, # 设置日志输出格式 - format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s", - # 日志输出的格式 - # -8表示占位符,让输出左对齐,输出长度都为8位 - datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式 - ) - - return qcg_logger \ No newline at end of file diff --git a/pkg/boot/__init__.py b/pkg/core/__init__.py similarity index 100% rename from pkg/boot/__init__.py rename to pkg/core/__init__.py diff --git a/pkg/boot/app.py b/pkg/core/app.py similarity index 65% rename from pkg/boot/app.py rename to pkg/core/app.py index df9e92b9..8c0a0c58 100644 --- a/pkg/boot/app.py +++ b/pkg/core/app.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import asyncio from ..qqbot import manager as qqbot_mgr from ..openai import manager as openai_mgr @@ -8,6 +9,8 @@ from ..database import manager as database_mgr from ..utils.center import v2 as center_mgr from ..plugin import host as plugin_host +from . 
import pool, controller +from ..pipeline import stagemgr class Application: @@ -23,16 +26,24 @@ class Application: ctr_mgr: center_mgr.V2CenterAPI = None + query_pool: pool.QueryPool = None + + ctrl: controller.Controller = None + + stage_mgr: stagemgr.StageManager = None + logger: logging.Logger = None def __init__(self): pass - async def initialize(self): - await self.im_mgr.initialize() - async def run(self): # TODO make it async plugin_host.initialize_plugins() - await self.im_mgr.run() \ No newline at end of file + tasks = [ + asyncio.create_task(self.im_mgr.run()), + asyncio.create_task(self.ctrl.run()) + ] + + await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) diff --git a/pkg/boot/boot.py b/pkg/core/boot.py similarity index 88% rename from pkg/boot/boot.py rename to pkg/core/boot.py index b7cc0f38..10fc51b3 100644 --- a/pkg/boot/boot.py +++ b/pkg/core/boot.py @@ -3,12 +3,15 @@ import os import sys -from . import files -from . import deps -from . import log -from . import config +from .bootutils import files +from .bootutils import deps +from .bootutils import log +from .bootutils import config from . import app +from . import pool +from . import controller +from ..pipeline import stagemgr from ..audit import identifier from ..database import manager as db_mgr from ..openai import manager as llm_mgr @@ -86,6 +89,8 @@ async def make_app() -> app.Application: ap.cfg_mgr = cfg_mgr ap.tips_mgr = tips_mgr + ap.query_pool = pool.QueryPool() + center_v2_api = center_v2.V2CenterAPI( basic_info={ "host_id": identifier.identifier['host_id'], @@ -111,8 +116,16 @@ async def make_app() -> app.Application: llm_session.load_sessions() im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap) + await im_mgr_inst.initialize() ap.im_mgr = im_mgr_inst + stage_mgr = stagemgr.StageManager(ap) + await stage_mgr.initialize() + ap.stage_mgr = stage_mgr + + ctrl = controller.Controller(ap) + ap.ctrl = ctrl + # TODO make it async plugin_host.load_plugins() # plugin_host.initialize_plugins() @@ -122,5 +135,4 @@ async def make_app() -> app.Application: async def main(): app_inst = await make_app() - await app_inst.initialize() await app_inst.run() diff --git a/pkg/boot/misc.py b/pkg/core/bootutils/__init__.py similarity index 100% rename from pkg/boot/misc.py rename to pkg/core/bootutils/__init__.py diff --git a/pkg/boot/config.py b/pkg/core/bootutils/config.py similarity index 85% rename from pkg/boot/config.py rename to pkg/core/bootutils/config.py index 3f796214..f1471ae5 100644 --- a/pkg/boot/config.py +++ b/pkg/core/bootutils/config.py @@ -1,7 +1,7 @@ import json -from ..config import manager as config_mgr -from ..config.impls import pymodule +from ...config import manager as config_mgr +from ...config.impls import pymodule load_python_module_config = config_mgr.load_python_module_config diff --git a/pkg/boot/deps.py b/pkg/core/bootutils/deps.py similarity index 100% rename from pkg/boot/deps.py rename to pkg/core/bootutils/deps.py diff --git a/pkg/boot/files.py b/pkg/core/bootutils/files.py similarity index 100% rename from pkg/boot/files.py rename to pkg/core/bootutils/files.py diff --git a/pkg/core/bootutils/log.py b/pkg/core/bootutils/log.py new file mode 100644 index 00000000..4bc0e4de --- /dev/null +++ b/pkg/core/bootutils/log.py @@ -0,0 +1,56 @@ +import logging +import os +import sys +import time + +import colorlog + + +log_colors_config = { + "DEBUG": "green", # cyan white + "INFO": "white", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "cyan", +} + + +async def 
init_logging() -> logging.Logger: + level = logging.INFO + + if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]: + level = logging.DEBUG + + log_file_name = "logs/qcg-%s.log" % time.strftime( + "%Y-%m-%d-%H-%M-%S", time.localtime() + ) + + qcg_logger = logging.getLogger("qcg") + + qcg_logger.setLevel(level) + + color_formatter = colorlog.ColoredFormatter( + fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + log_colors=log_colors_config, + ) + + stream_handler = logging.StreamHandler(sys.stdout) + + log_handlers: logging.Handler = [stream_handler, logging.FileHandler(log_file_name)] + + for handler in log_handlers: + handler.setLevel(level) + handler.setFormatter(color_formatter) + qcg_logger.addHandler(handler) + + logging.basicConfig( + level=logging.INFO, # 设置日志输出格式 + format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s", + # 日志输出的格式 + # -8表示占位符,让输出左对齐,输出长度都为8位 + datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式 + handlers=[logging.NullHandler()], + ) + + return qcg_logger diff --git a/pkg/core/bootutils/misc.py b/pkg/core/bootutils/misc.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/core/controller.py b/pkg/core/controller.py new file mode 100644 index 00000000..2470cbbd --- /dev/null +++ b/pkg/core/controller.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import asyncio +import traceback + +from . import app, entities +from ..pipeline import entities as pipeline_entities + +DEFAULT_QUERY_CONCURRENCY = 10 + + +class Controller: + """总控制器 + """ + ap: app.Application + + semaphore: asyncio.Semaphore = None + """请求并发控制信号量""" + + def __init__(self, ap: app.Application): + self.ap = ap + self.semaphore = asyncio.Semaphore(DEFAULT_QUERY_CONCURRENCY) + + async def consumer(self): + """事件处理循环 + """ + while True: + selected_query: entities.Query = None + + # 取请求 + async with self.ap.query_pool: + queries: list[entities.Query] = self.ap.query_pool.queries + + if queries: + selected_query = queries.pop(0) # FCFS + else: + await self.ap.query_pool.condition.wait() + continue + + if selected_query: + async def _process_query(selected_query): + async with self.semaphore: + await self.process_query(selected_query) + + asyncio.create_task(_process_query(selected_query)) + + async def process_query(self, query: entities.Query): + """处理请求 + """ + self.ap.logger.debug(f"Processing query {query}") + + try: + for stage_container in self.ap.stage_mgr.stage_containers: + res = await stage_container.inst.process(query, stage_container.inst_name) + + self.ap.logger.debug(f"Stage {stage_container.inst_name} res {res}") + + if res.user_notice: + await self.ap.im_mgr.send( + query.message_event, + res.user_notice + ) + if res.debug_notice: + self.ap.logger.debug(res.debug_notice) + if res.console_notice: + self.ap.logger.info(res.console_notice) + + if res.result_type == pipeline_entities.ResultType.INTERRUPT: + self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + break + elif res.result_type == pipeline_entities.ResultType.CONTINUE: + query = res.new_query + continue + + except Exception as e: + self.ap.logger.error(f"处理请求时出错 {query}: {e}") + traceback.print_exc() + finally: + self.ap.logger.debug(f"Query {query} processed") + + async def run(self): + """运行控制器 + """ + await self.consumer() diff --git a/pkg/core/entities.py b/pkg/core/entities.py new file mode 100644 index 00000000..505112ff --- /dev/null +++ 
b/pkg/core/entities.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import enum +import typing + +import pydantic +import mirai + + +class LauncherTypes(enum.Enum): + + PERSON = 'person' + """私聊""" + + GROUP = 'group' + """群聊""" + + +class Query(pydantic.BaseModel): + """一次请求的信息封装""" + + query_id: int + """请求ID""" + + launcher_type: LauncherTypes + """会话类型""" + + launcher_id: int + """会话ID""" + + sender_id: int + """发送者ID""" + + message_event: mirai.MessageEvent + """事件""" + + message_chain: mirai.MessageChain + """消息链""" + + resp_message_chain: typing.Optional[mirai.MessageChain] = None + """回复消息链""" diff --git a/pkg/core/pool.py b/pkg/core/pool.py new file mode 100644 index 00000000..3d949292 --- /dev/null +++ b/pkg/core/pool.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import asyncio + +import mirai + +from . import entities + + +class QueryPool: + + query_id_counter: int = 0 + + pool_lock: asyncio.Lock + + queries: list[entities.Query] + + condition: asyncio.Condition + + def __init__(self): + self.query_id_counter = 0 + self.pool_lock = asyncio.Lock() + self.queries = [] + self.condition = asyncio.Condition(self.pool_lock) + + async def add_query( + self, + launcher_type: entities.LauncherTypes, + launcher_id: int, + sender_id: int, + message_event: mirai.MessageEvent, + message_chain: mirai.MessageChain + ) -> entities.Query: + async with self.condition: + query = entities.Query( + query_id=self.query_id_counter, + launcher_type=launcher_type, + launcher_id=launcher_id, + sender_id=sender_id, + message_event=message_event, + message_chain=message_chain + ) + self.queries.append(query) + self.query_id_counter += 1 + self.condition.notify_all() + + async def __aenter__(self): + await self.pool_lock.acquire() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.pool_lock.release() diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index 3fd53be6..e070a29f 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -10,7 +10,7 @@ from ..audit import gatherer from ..openai import modelmgr from ..openai.api import model as api_model -from ..boot import app +from ..core import app class OpenAIInteract: diff --git a/pkg/pipeline/__init__.py b/pkg/pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/bansess/__init__.py b/pkg/pipeline/bansess/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py new file mode 100644 index 00000000..a0f63c36 --- /dev/null +++ b/pkg/pipeline/bansess/bansess.py @@ -0,0 +1,76 @@ +from __future__ import annotations +import re + +from .. 
import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class('BanSessionCheckStage') +class BanSessionCheckStage(stage.PipelineStage): + + banlist_mgr: cfg_mgr.ConfigManager + + async def initialize(self): + self.banlist_mgr = await cfg_mgr.load_python_module_config( + "banlist.py", + "res/templates/banlist-template.py" + ) + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str + ) -> entities.StageProcessResult: + + if not self.banlist_mgr.data['enable']: + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + result = False + + if query.launcher_type == 'group': + if not self.banlist_mgr.data['enable_group']: # 未启用群聊响应 + result = True + # 检查是否显式声明发起人QQ要被person忽略 + elif query.sender_id in self.banlist_mgr.data['person']: + result = True + else: + for group_rule in self.banlist_mgr.data['group']: + if type(group_rule) == int: + if group_rule == query.launcher_id: + result = True + elif type(group_rule) == str: + if group_rule.startswith('!'): + reg_str = group_rule[1:] + if re.match(reg_str, str(query.launcher_id)): + result = False + break + else: + if re.match(group_rule, str(query.launcher_id)): + result = True + elif query.launcher_type == 'person': + if not self.banlist_mgr.data['enable_private']: + result = True + else: + for person_rule in self.banlist_mgr.data['person']: + if type(person_rule) == int: + if person_rule == query.launcher_id: + result = True + elif type(person_rule) == str: + if person_rule.startswith('!'): + reg_str = person_rule[1:] + if re.match(reg_str, str(query.launcher_id)): + result = False + break + else: + if re.match(person_rule, str(query.launcher_id)): + result = True + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT, + new_query=query, + debug_notice=f'根据禁用列表忽略消息: {query.launcher_type}_{query.launcher_id}' if result else '' + ) diff --git a/pkg/pipeline/cntfilter/__init__.py b/pkg/pipeline/cntfilter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py new file mode 100644 index 00000000..0025b00a --- /dev/null +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import mirai + +from ...core import app + +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr +from . 
import filter, entities as filter_entities +from .filters import cntignore, banwords, baiduexamine + + +@stage.stage_class('PostContentFilterStage') +@stage.stage_class('PreContentFilterStage') +class ContentFilterStage(stage.PipelineStage): + + filter_chain: list[filter.ContentFilter] + + def __init__(self, ap: app.Application): + self.filter_chain = [] + super().__init__(ap) + + async def initialize(self): + self.filter_chain.append(cntignore.ContentIgnore(self.ap)) + + if self.ap.cfg_mgr.data['sensitive_word_filter']: + self.filter_chain.append(banwords.BanWordFilter(self.ap)) + + if self.ap.cfg_mgr.data['baidu_check']: + self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap)) + + for filter in self.filter_chain: + await filter.initialize() + + async def _pre_process( + self, + message: str, + query: core_entities.Query, + ) -> entities.StageProcessResult: + """请求llm前处理消息 + 只要有一个不通过就不放行,只放行 PASS 的消息 + """ + if not self.ap.cfg_mgr.data['income_msg_check']: + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + for filter in self.filter_chain: + if filter_entities.EnableStage.PRE in filter.enable_stages: + result = await filter.process(message) + + if result.level in [ + filter_entities.ResultLevel.BLOCK, + filter_entities.ResultLevel.MASKED + ]: + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query, + user_notice=result.user_notice, + console_notice=result.console_notice + ) + elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个 + message = result.replacement + + query.message_chain = mirai.MessageChain( + mirai.Plain(message) + ) + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + async def _post_process( + self, + message: str, + query: core_entities.Query, + ) -> entities.StageProcessResult: + """请求llm后处理响应 + 只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter + """ + for filter in self.filter_chain: + if filter_entities.EnableStage.POST in filter.enable_stages: + result = await filter.process(message) + + if result.level == filter_entities.ResultLevel.BLOCK: + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query, + user_notice=result.user_notice, + console_notice=result.console_notice + ) + elif result.level in [ + filter_entities.ResultLevel.PASS, + filter_entities.ResultLevel.MASKED + ]: + message = result.replacement + + query.message_chain = mirai.MessageChain( + mirai.Plain(message) + ) + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str + ) -> entities.StageProcessResult: + """处理 + """ + if stage_inst_name == 'PreContentFilterStage': + return await self._pre_process( + str(query.message_chain).strip(), + query + ) + elif stage_inst_name == 'PostContentFilterStage': + return await self._post_process( + str(query.message_chain).strip(), + query + ) + else: + raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}') diff --git a/pkg/pipeline/cntfilter/entities.py b/pkg/pipeline/cntfilter/entities.py new file mode 100644 index 00000000..7ab05675 --- /dev/null +++ b/pkg/pipeline/cntfilter/entities.py @@ -0,0 +1,64 @@ + +import typing +import enum + +import pydantic + + +class ResultLevel(enum.Enum): + """结果等级""" + PASS = enum.auto() + """通过""" + + WARN = enum.auto() + """警告""" + + MASKED = enum.auto() 
+ """已掩去""" + + BLOCK = enum.auto() + """阻止""" + + +class EnableStage(enum.Enum): + """启用阶段""" + PRE = enum.auto() + """预处理""" + + POST = enum.auto() + """后处理""" + + +class FilterResult(pydantic.BaseModel): + level: ResultLevel + + replacement: str + """替换后的消息""" + + user_notice: str + """不通过时,用户提示消息""" + + console_notice: str + """不通过时,控制台提示消息""" + + +class ManagerResultLevel(enum.Enum): + """处理器结果等级""" + CONTINUE = enum.auto() + """继续""" + + INTERRUPT = enum.auto() + """中断""" + +class FilterManagerResult(pydantic.BaseModel): + + level: ManagerResultLevel + + replacement: str + """替换后的消息""" + + user_notice: str + """用户提示消息""" + + console_notice: str + """控制台提示消息""" diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py new file mode 100644 index 00000000..57792145 --- /dev/null +++ b/pkg/pipeline/cntfilter/filter.py @@ -0,0 +1,34 @@ +# 内容过滤器的抽象类 +from __future__ import annotations +import abc + +from ...core import app +from . import entities + + +class ContentFilter(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + @property + def enable_stages(self): + """启用的阶段 + """ + return [ + entities.EnableStage.PRE, + entities.EnableStage.POST + ] + + async def initialize(self): + """初始化过滤器 + """ + pass + + @abc.abstractmethod + async def process(self, message: str) -> entities.FilterResult: + """处理消息 + """ + raise NotImplementedError diff --git a/pkg/pipeline/cntfilter/filters/__init__.py b/pkg/pipeline/cntfilter/filters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py new file mode 100644 index 00000000..a658897b --- /dev/null +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import aiohttp + +from .. import entities +from .. 
import filter as filter_model + + +BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}" +BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" + + +class BaiduCloudExamine(filter_model.ContentFilter): + """百度云内容审核""" + + async def _get_token(self) -> str: + async with aiohttp.ClientSession() as session: + async with session.post( + BAIDU_EXAMINE_TOKEN_URL, + params={ + "grant_type": "client_credentials", + "client_id": self.ap.cfg_mgr.data['baidu_api_key'], + "client_secret": self.ap.cfg_mgr.data['baidu_secret_key'] + } + ) as resp: + return (await resp.json())['access_token'] + + async def process(self, message: str) -> entities.FilterResult: + + async with aiohttp.ClientSession() as session: + async with session.post( + BAIDU_EXAMINE_URL.format(await self._get_token()), + headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'}, + data=f"text={message}".encode('utf-8') + ) as resp: + result = await resp.json() + + if "error_code" in result: + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement=message, + user_notice='', + console_notice=f"百度云判定出错,错误信息:{result['error_msg']}" + ) + else: + conclusion = result["conclusion"] + + if conclusion in ("合规"): + return entities.FilterResult( + level=entities.ResultLevel.PASS, + replacement=message, + user_notice='', + console_notice=f"百度云判定结果:{conclusion}" + ) + else: + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement=message, + user_notice=self.ap.cfg_mgr.data['inappropriate_message_tips'], + console_notice=f"百度云判定结果:{conclusion}" + ) \ No newline at end of file diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py new file mode 100644 index 00000000..9451c7b8 --- /dev/null +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -0,0 +1,44 @@ +from __future__ import annotations +import re + +from .. import filter as filter_model +from .. import entities +from ....config import manager as cfg_mgr + + +class BanWordFilter(filter_model.ContentFilter): + """根据内容禁言""" + + sensitive: cfg_mgr.ConfigManager + + async def initialize(self): + self.sensitive = await cfg_mgr.load_json_config( + "sensitive.json", + "res/templates/sensitive-template.json" + ) + + async def process(self, message: str) -> entities.FilterResult: + found = False + + for word in self.sensitive.data['words']: + match = re.findall(word, message) + + if len(match) > 0: + found = True + + for i in range(len(match)): + if self.sensitive.data['mask_word'] == "": + message = message.replace( + match[i], self.sensitive.data['mask'] * len(match[i]) + ) + else: + message = message.replace( + match[i], self.sensitive.data['mask_word'] + ) + + return entities.FilterResult( + level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS, + replacement=message, + user_notice='[bot] 消息中存在不合适的内容, 请更换措辞' if found else '', + console_notice='' + ) \ No newline at end of file diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py new file mode 100644 index 00000000..81408868 --- /dev/null +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -0,0 +1,43 @@ +from __future__ import annotations +import re + +from .. import entities +from .. 
import filter as filter_model + + +class ContentIgnore(filter_model.ContentFilter): + """根据内容忽略消息""" + + @property + def enable_stages(self): + return [ + entities.EnableStage.PRE, + ] + + async def process(self, message: str) -> entities.FilterResult: + if 'prefix' in self.ap.cfg_mgr.data['ignore_rules']: + for rule in self.ap.cfg_mgr.data['ignore_rules']['prefix']: + if message.startswith(rule): + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement='', + user_notice='', + console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息' + ) + + if 'regexp' in self.ap.cfg_mgr.data['ignore_rules']: + for rule in self.ap.cfg_mgr.data['ignore_rules']['regexp']: + if re.search(rule, message): + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement='', + user_notice='', + console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息' + ) + + return entities.FilterResult( + level=entities.ResultLevel.PASS, + replacement=message, + user_notice='', + console_notice='' + ) \ No newline at end of file diff --git a/pkg/pipeline/entities.py b/pkg/pipeline/entities.py new file mode 100644 index 00000000..e687c082 --- /dev/null +++ b/pkg/pipeline/entities.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import enum +import typing + +import pydantic +import mirai +import mirai.models.message as mirai_message + +from ..core import entities + + +class ResultType(enum.Enum): + + CONTINUE = enum.auto() + """继续流水线""" + + INTERRUPT = enum.auto() + """中断流水线""" + + +class StageProcessResult(pydantic.BaseModel): + + result_type: ResultType + + new_query: entities.Query + + user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = [] + """只要设置了就会发送给用户""" + + admin_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = [] + """只要设置了就会发送给管理员""" + + console_notice: typing.Optional[str] = '' + """只要设置了就会输出到控制台""" + + debug_notice: typing.Optional[str] = '' + diff --git a/pkg/pipeline/longtext/__init__.py b/pkg/pipeline/longtext/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py new file mode 100644 index 00000000..11144891 --- /dev/null +++ b/pkg/pipeline/longtext/longtext.py @@ -0,0 +1,57 @@ +from __future__ import annotations +import os +import traceback + +from PIL import Image, ImageDraw, ImageFont +from mirai.models.message import MessageComponent, Plain, MessageChain + +from ...core import app +from . import strategy +from .strategies import image, forward +from .. 
import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class("LongTextProcessStage") +class LongTextProcessStage(stage.PipelineStage): + + strategy_impl: strategy.LongTextStrategy + + async def initialize(self): + config = self.ap.cfg_mgr.data + if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image': + use_font = config['font_path'] + try: + # 检查是否存在 + if not os.path.exists(use_font): + # 若是windows系统,使用微软雅黑 + if os.name == "nt": + use_font = "C:/Windows/Fonts/msyh.ttc" + if not os.path.exists(use_font): + self.ap.logger.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") + config['blob_message_strategy'] = "forward" + else: + self.ap.logger.info("使用Windows自带字体:" + use_font) + self.ap.cfg_mgr.data['font_path'] = use_font + else: + self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") + self.ap.cfg_mgr.data['blob_message_strategy'] = "forward" + except: + traceback.print_exc() + self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) + self.ap.cfg_mgr.data['blob_message_strategy'] = "forward" + + if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image': + self.strategy_impl = image.Text2ImageStrategy(self.ap) + elif self.ap.cfg_mgr.data['blob_message_strategy'] == 'forward': + self.strategy_impl = forward.ForwardComponentStrategy(self.ap) + await self.strategy_impl.initialize() + + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + if len(str(query.resp_message_chain)) > self.ap.cfg_mgr.data['blob_message_threshold']: + query.message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain))) + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) \ No newline at end of file diff --git a/pkg/pipeline/longtext/strategies/__init__.py b/pkg/pipeline/longtext/strategies/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/longtext/strategies/forward.py b/pkg/pipeline/longtext/strategies/forward.py new file mode 100644 index 00000000..d1b5c36c --- /dev/null +++ b/pkg/pipeline/longtext/strategies/forward.py @@ -0,0 +1,62 @@ +# 转发消息组件 +from __future__ import annotations +import typing + +from mirai.models import MessageChain +from mirai.models.message import MessageComponent, ForwardMessageNode +from mirai.models.base import MiraiBaseModel + +from .. 
import strategy as strategy_model + + +class ForwardMessageDiaplay(MiraiBaseModel): + title: str = "群聊的聊天记录" + brief: str = "[聊天记录]" + source: str = "聊天记录" + preview: typing.List[str] = [] + summary: str = "查看x条转发消息" + + +class Forward(MessageComponent): + """合并转发。""" + type: str = "Forward" + """消息组件类型。""" + display: ForwardMessageDiaplay + """显示信息""" + node_list: typing.List[ForwardMessageNode] + """转发消息节点列表。""" + def __init__(self, *args, **kwargs): + if len(args) == 1: + self.node_list = args[0] + super().__init__(**kwargs) + super().__init__(*args, **kwargs) + + def __str__(self): + return '[聊天记录]' + + +class ForwardComponentStrategy(strategy_model.LongTextStrategy): + + async def process(self, message: str) -> list[MessageComponent]: + display = ForwardMessageDiaplay( + title="群聊的聊天记录", + brief="[聊天记录]", + source="聊天记录", + preview=["QQ用户: "+message], + summary="查看1条转发消息" + ) + + node_list = [ + ForwardMessageNode( + sender_id=self.ap.im_mgr.bot_account_id, + sender_name='QQ用户', + message_chain=MessageChain([message]) + ) + ] + + forward = Forward( + display=display, + node_list=node_list + ) + + return [forward] diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py new file mode 100644 index 00000000..4f789098 --- /dev/null +++ b/pkg/pipeline/longtext/strategies/image.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import typing +import os +import base64 +import time +import re + +from PIL import Image, ImageDraw, ImageFont + +from mirai.models import MessageChain, Image as ImageComponent +from mirai.models.message import MessageComponent + +from .. import strategy as strategy_model + + +class Text2ImageStrategy(strategy_model.LongTextStrategy): + + text_render_font: ImageFont.FreeTypeFont + + async def initialize(self): + self.text_render_font = ImageFont.truetype(self.ap.cfg_mgr.data['font_path'], 32, encoding="utf-8") + + async def process(self, message: str) -> list[MessageComponent]: + img_path = self.text_to_image( + text_str=message, + save_as='temp/{}.png'.format(int(time.time())) + ) + + compressed_path, size = self.compress_image( + img_path, + outfile="temp/{}_compressed.png".format(int(time.time())) + ) + + with open(compressed_path, 'rb') as f: + img = f.read() + + b64 = base64.b64encode(img) + + # 删除图片 + os.remove(img_path) + + if os.path.exists(compressed_path): + os.remove(compressed_path) + + return [ + ImageComponent( + base64=b64.decode('utf-8'), + ) + ] + + def indexNumber(self, path=''): + """ + 查找字符串中数字所在串中的位置 + :param path:目标字符串 + :return:: : [['1', 16], ['2', 35], ['1', 51]] + """ + kv = [] + nums = [] + beforeDatas = re.findall('[\d]+', path) + for num in beforeDatas: + indexV = [] + times = path.count(num) + if times > 1: + if num not in nums: + indexs = re.finditer(num, path) + for index in indexs: + iV = [] + i = index.span()[0] + iV.append(num) + iV.append(i) + kv.append(iV) + nums.append(num) + else: + index = path.find(num) + indexV.append(num) + indexV.append(index) + kv.append(indexV) + # 根据数字位置排序 + indexSort = [] + resultIndex = [] + for vi in kv: + indexSort.append(vi[1]) + indexSort.sort() + for i in indexSort: + for v in kv: + if i == v[1]: + resultIndex.append(v) + return resultIndex + + + def get_size(self, file): + # 获取文件大小:KB + size = os.path.getsize(file) + return size / 1024 + + + def get_outfile(self, infile, outfile): + if outfile: + return outfile + dir, suffix = os.path.splitext(infile) + outfile = '{}-out{}'.format(dir, suffix) + return outfile + + + def compress_image(self, 
infile, outfile='', kb=100, step=20, quality=90): + """不改变图片尺寸压缩到指定大小 + :param infile: 压缩源文件 + :param outfile: 压缩文件保存地址 + :param mb: 压缩目标,KB + :param step: 每次调整的压缩比率 + :param quality: 初始压缩比率 + :return: 压缩文件地址,压缩文件大小 + """ + o_size = self.get_size(infile) + if o_size <= kb: + return infile, o_size + outfile = self.get_outfile(infile, outfile) + while o_size > kb: + im = Image.open(infile) + im.save(outfile, quality=quality) + if quality - step < 0: + break + quality -= step + o_size = self.get_size(outfile) + return outfile, self.get_size(outfile) + + + def text_to_image(self, text_str: str, save_as="temp.png", width=800): + + text_str = text_str.replace("\t", " ") + + # 分行 + lines = text_str.split('\n') + + # 计算并分割 + final_lines = [] + + text_width = width-80 + + self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width)) + for line in lines: + # 如果长了就分割 + line_width = self.text_render_font.getlength(line) + self.ap.logger.debug("line_width: {}".format(line_width)) + if line_width < text_width: + final_lines.append(line) + continue + else: + rest_text = line + while True: + # 分割最前面的一行 + point = int(len(rest_text) * (text_width / line_width)) + + # 检查断点是否在数字中间 + numbers = self.indexNumber(rest_text) + + for number in numbers: + if number[1] < point < number[1] + len(number[0]) and number[1] != 0: + point = number[1] + break + + final_lines.append(rest_text[:point]) + rest_text = rest_text[point:] + line_width = self.text_render_font.getlength(rest_text) + if line_width < text_width: + final_lines.append(rest_text) + break + else: + continue + # 准备画布 + img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)) + draw = ImageDraw.Draw(img, mode='RGBA') + + self.ap.logger.debug("正在绘制图片...") + # 绘制正文 + line_number = 0 + offset_x = 20 + offset_y = 30 + for final_line in final_lines: + draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=self.text_render_font) + # 遍历此行,检查是否有emoji + idx_in_line = 0 + for ch in final_line: + # 检查字符占位宽 + char_code = ord(ch) + if char_code >= 127: + idx_in_line += 1 + else: + idx_in_line += 0.5 + + line_number += 1 + + self.ap.logger.debug("正在保存图片...") + img.save(save_as) + + return save_as diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py new file mode 100644 index 00000000..5c6bfb9c --- /dev/null +++ b/pkg/pipeline/longtext/strategy.py @@ -0,0 +1,22 @@ +from __future__ import annotations +import abc +import typing + +import mirai +from mirai.models.message import MessageComponent + +from ...core import app + + +class LongTextStrategy(metaclass=abc.ABCMeta): + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def process(self, message: str) -> list[MessageComponent]: + return [] diff --git a/pkg/pipeline/resprule/__init__.py b/pkg/pipeline/resprule/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/resprule/entities.py b/pkg/pipeline/resprule/entities.py new file mode 100644 index 00000000..ffee3081 --- /dev/null +++ b/pkg/pipeline/resprule/entities.py @@ -0,0 +1,9 @@ +import pydantic +import mirai + + +class RuleJudgeResult(pydantic.BaseModel): + + matching: bool = False + + replacement: mirai.MessageChain = None diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py new file mode 100644 index 00000000..6335a7d4 --- /dev/null +++ b/pkg/pipeline/resprule/resprule.py @@ -0,0 +1,62 @@ +from 
__future__ import annotations + +import mirai + +from ...core import app +from . import entities as rule_entities, rule +from .rules import atbot, prefix, regexp, random + +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class("GroupRespondRuleCheckStage") +class GroupRespondRuleCheckStage(stage.PipelineStage): + """群组响应规则检查器 + """ + + rule_matchers: list[rule.GroupRespondRule] + + async def initialize(self): + """初始化检查器 + """ + self.rule_matchers = [ + atbot.AtBotRule(self.ap), + prefix.PrefixRule(self.ap), + regexp.RegExpRule(self.ap), + random.RandomRespRule(self.ap), + ] + + for rule_matcher in self.rule_matchers: + await rule_matcher.initialize() + + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + + if query.launcher_type != 'group': + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + rules = self.ap.cfg_mgr.data['response_rules'] + + use_rule = rules['default'] + + if str(query.launcher_id) in use_rule: + use_rule = use_rule[str(query.launcher_id)] + + for rule_matcher in self.rule_matchers: # 任意一个匹配就放行 + res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule) + if res.matching: + query.message_chain = res.replacement + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query, + ) + + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) diff --git a/pkg/pipeline/resprule/rule.py b/pkg/pipeline/resprule/rule.py new file mode 100644 index 00000000..e530d063 --- /dev/null +++ b/pkg/pipeline/resprule/rule.py @@ -0,0 +1,31 @@ +from __future__ import annotations +import abc + +import mirai + +from ...core import app +from . import entities + + +class GroupRespondRule(metaclass=abc.ABCMeta): + """群组响应规则的抽象类 + """ + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + """判断消息是否匹配规则 + """ + raise NotImplementedError diff --git a/pkg/pipeline/resprule/rules/__init__.py b/pkg/pipeline/resprule/rules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/resprule/rules/atbot.py b/pkg/pipeline/resprule/rules/atbot.py new file mode 100644 index 00000000..eefc4891 --- /dev/null +++ b/pkg/pipeline/resprule/rules/atbot.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import mirai + +from .. import rule as rule_model +from .. import entities + + +class AtBotRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + + if message_chain.has(mirai.At(self.ap.im_mgr.bot_account_id)) and rule_dict['at']: + message_chain.remove(mirai.At(self.ap.im_mgr.bot_account_id)) + return entities.RuleJudgeResult( + matching=True, + replacement=message_chain, + ) + + return entities.RuleJudgeResult( + matching=False, + replacement = message_chain + ) diff --git a/pkg/pipeline/resprule/rules/prefix.py b/pkg/pipeline/resprule/rules/prefix.py new file mode 100644 index 00000000..31ead5ab --- /dev/null +++ b/pkg/pipeline/resprule/rules/prefix.py @@ -0,0 +1,29 @@ +import mirai + +from .. 
import rule as rule_model +from .. import entities + + +class PrefixRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + prefixes = rule_dict['prefix'] + + for prefix in prefixes: + if message_text.startswith(prefix): + return entities.RuleJudgeResult( + matching=True, + replacement=mirai.MessageChain([ + mirai.Plain(message_text[len(prefix):]) + ]), + ) + + return entities.RuleJudgeResult( + matching=False, + replacement=message_chain + ) diff --git a/pkg/pipeline/resprule/rules/random.py b/pkg/pipeline/resprule/rules/random.py new file mode 100644 index 00000000..1e8354b5 --- /dev/null +++ b/pkg/pipeline/resprule/rules/random.py @@ -0,0 +1,22 @@ +import random + +import mirai + +from .. import rule as rule_model +from .. import entities + + +class RandomRespRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + random_rate = rule_dict['random_rate'] + + return entities.RuleJudgeResult( + matching=random.random() < random_rate, + replacement=message_chain + ) \ No newline at end of file diff --git a/pkg/pipeline/resprule/rules/regexp.py b/pkg/pipeline/resprule/rules/regexp.py new file mode 100644 index 00000000..0d621fe4 --- /dev/null +++ b/pkg/pipeline/resprule/rules/regexp.py @@ -0,0 +1,31 @@ +import re + +import mirai + +from .. import rule as rule_model +from .. import entities + + +class RegExpRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + regexps = rule_dict['regexp'] + + for regexp in regexps: + match = re.match(regexp, message_text) + + if match: + return entities.RuleJudgeResult( + matching=True, + replacement=message_chain, + ) + + return entities.RuleJudgeResult( + matching=False, + replacement=message_chain + ) diff --git a/pkg/pipeline/stage.py b/pkg/pipeline/stage.py new file mode 100644 index 00000000..84a0339d --- /dev/null +++ b/pkg/pipeline/stage.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import abc + +from ..core import app, entities as core_entities +from . import entities + + +_stage_classes: dict[str, PipelineStage] = {} + + +def stage_class(name: str): + + def decorator(cls): + _stage_classes[name] = cls + return cls + + return decorator + + +class PipelineStage(metaclass=abc.ABCMeta): + """流水线阶段 + """ + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + """初始化 + """ + pass + + @abc.abstractmethod + async def process( + self, + query: core_entities.Query, + stage_inst_name: str, + ) -> entities.StageProcessResult: + """处理 + """ + raise NotImplementedError diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py new file mode 100644 index 00000000..f5407a2e --- /dev/null +++ b/pkg/pipeline/stagemgr.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import pydantic + +from ..core import app +from . 
import stage +from .resprule import resprule +from .bansess import bansess +from .cntfilter import cntfilter +from .longtext import longtext + + +class StageInstContainer(): + """阶段实例容器 + """ + + inst_name: str + + inst: stage.PipelineStage + + def __init__(self, inst_name: str, inst: stage.PipelineStage): + self.inst_name = inst_name + self.inst = inst + + +class StageManager: + ap: app.Application + + stage_containers: list[StageInstContainer] + + def __init__(self, ap: app.Application): + self.ap = ap + + self.stage_containers = [] + + async def initialize(self): + """初始化 + """ + + for name, cls in stage._stage_classes.items(): + self.stage_containers.append(StageInstContainer( + inst_name=name, + inst=cls(self.ap) + )) + + for stage_containers in self.stage_containers: + await stage_containers.inst.initialize() diff --git a/pkg/qqbot/bansess/bansess.py b/pkg/qqbot/bansess/bansess.py index d8ef4958..74ffd3f7 100644 --- a/pkg/qqbot/bansess/bansess.py +++ b/pkg/qqbot/bansess/bansess.py @@ -3,7 +3,7 @@ from __future__ import annotations import re -from ...boot import app +from ...core import app from ...config import manager as cfg_mgr diff --git a/pkg/qqbot/cntfilter/cntfilter.py b/pkg/qqbot/cntfilter/cntfilter.py index 2d690b57..4c7305c0 100644 --- a/pkg/qqbot/cntfilter/cntfilter.py +++ b/pkg/qqbot/cntfilter/cntfilter.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ...boot import app +from ...core import app from . import entities from . import filter from .filters import cntignore, banwords, baiduexamine diff --git a/pkg/qqbot/cntfilter/filter.py b/pkg/qqbot/cntfilter/filter.py index 4d4cd79f..57792145 100644 --- a/pkg/qqbot/cntfilter/filter.py +++ b/pkg/qqbot/cntfilter/filter.py @@ -2,7 +2,7 @@ from __future__ import annotations import abc -from ...boot import app +from ...core import app from . import entities diff --git a/pkg/qqbot/longtext/longtext.py b/pkg/qqbot/longtext/longtext.py index 21267880..697f65e4 100644 --- a/pkg/qqbot/longtext/longtext.py +++ b/pkg/qqbot/longtext/longtext.py @@ -5,7 +5,7 @@ from PIL import Image, ImageDraw, ImageFont from mirai.models.message import MessageComponent, Plain -from ...boot import app +from ...core import app from . 
import strategy from .strategies import image, forward diff --git a/pkg/qqbot/longtext/strategy.py b/pkg/qqbot/longtext/strategy.py index ef4cc1a5..5c6bfb9c 100644 --- a/pkg/qqbot/longtext/strategy.py +++ b/pkg/qqbot/longtext/strategy.py @@ -5,7 +5,7 @@ import mirai from mirai.models.message import MessageComponent -from ...boot import app +from ...core import app class LongTextStrategy(metaclass=abc.ABCMeta): diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 5239604f..b16450e8 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -24,7 +24,7 @@ from .longtext import longtext from .ratelim import ratelim -from ..boot import app +from ..core import app, entities as core_entities # 控制QQ消息输入输出的类 @@ -91,45 +91,29 @@ async def initialize(self): # Caution: 注册新的事件处理器之后,请务必在unsubscribe_all中编写相应的取消订阅代码 async def on_friend_message(event: FriendMessage): - async def friend_message_handler(): - # 触发事件 - args = { - "launcher_type": "person", - "launcher_id": event.sender.id, - "sender_id": event.sender.id, - "message_chain": event.message_chain, - } - plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args) - - if plugin_event.is_prevented_default(): - return - - await self.on_person_message(event) + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain + ) - asyncio.create_task(friend_message_handler()) self.adapter.register_listener( FriendMessage, on_friend_message ) async def on_stranger_message(event: StrangerMessage): + + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain + ) - async def stranger_message_handler(): - # 触发事件 - args = { - "launcher_type": "person", - "launcher_id": event.sender.id, - "sender_id": event.sender.id, - "message_chain": event.message_chain, - } - plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args) - - if plugin_event.is_prevented_default(): - return - - await self.on_person_message(event) - - asyncio.create_task(stranger_message_handler()) # nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件 if config['msg_source_adapter'] == 'yirimirai': self.adapter.register_listener( @@ -139,49 +123,19 @@ async def stranger_message_handler(): async def on_group_message(event: GroupMessage): - async def group_message_handler(event: GroupMessage): - # 触发事件 - args = { - "launcher_type": "group", - "launcher_id": event.group.id, - "sender_id": event.sender.id, - "message_chain": event.message_chain, - } - plugin_event = plugin_host.emit(plugin_models.GroupMessageReceived, **args) - - if plugin_event.is_prevented_default(): - return - - await self.on_group_message(event) - - asyncio.create_task(group_message_handler(event)) + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.GROUP, + launcher_id=event.group.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain + ) self.adapter.register_listener( GroupMessage, on_group_message ) - def unsubscribe_all(): - """取消所有订阅 - - 用于在热重载流程中卸载所有事件处理器 - """ - self.adapter.unregister_listener( - FriendMessage, - on_friend_message - ) - if config['msg_source_adapter'] == 'yirimirai': - self.adapter.unregister_listener( - StrangerMessage, - on_stranger_message - ) - self.adapter.unregister_listener( - GroupMessage, - on_group_message 
- ) - - self.unsubscribe_all = unsubscribe_all - async def send(self, event, msg, check_quote=True, check_at_sender=True): config = context.get_config_manager().data diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index e1673583..65de8d52 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -14,7 +14,7 @@ from ..plugin import host as plugin_host from ..plugin import models as plugin_models import tips as tips_custom -from ..boot import app +from ..core import app from .cntfilter import entities processing = [] diff --git a/pkg/qqbot/ratelim/algo.py b/pkg/qqbot/ratelim/algo.py index 10bbdd3a..b6d9ba7b 100644 --- a/pkg/qqbot/ratelim/algo.py +++ b/pkg/qqbot/ratelim/algo.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -from ...boot import app +from ...core import app class ReteLimitAlgo(metaclass=abc.ABCMeta): diff --git a/pkg/qqbot/ratelim/ratelim.py b/pkg/qqbot/ratelim/ratelim.py index ab23d714..68fe0316 100644 --- a/pkg/qqbot/ratelim/ratelim.py +++ b/pkg/qqbot/ratelim/ratelim.py @@ -2,7 +2,7 @@ from . import algo from .algos import fixedwin -from ...boot import app +from ...core import app class RateLimiter: diff --git a/pkg/qqbot/resprule/resprule.py b/pkg/qqbot/resprule/resprule.py index f0c51921..9ea8321d 100644 --- a/pkg/qqbot/resprule/resprule.py +++ b/pkg/qqbot/resprule/resprule.py @@ -2,7 +2,7 @@ import mirai -from ...boot import app +from ...core import app from . import entities, rule from .rules import atbot, prefix, regexp, random diff --git a/pkg/qqbot/resprule/rule.py b/pkg/qqbot/resprule/rule.py index 67af0204..e530d063 100644 --- a/pkg/qqbot/resprule/rule.py +++ b/pkg/qqbot/resprule/rule.py @@ -3,7 +3,7 @@ import mirai -from ...boot import app +from ...core import app from . import entities diff --git a/start.py b/start.py index f22012ee..b56ea9e9 100644 --- a/start.py +++ b/start.py @@ -1,6 +1,6 @@ import asyncio -from pkg.boot import boot +from pkg.core import boot if __name__ == '__main__': From 1900ddacbbb7bab535e0da71af8dd54061769642 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Fri, 26 Jan 2024 15:54:24 +0800 Subject: [PATCH 02/10] =?UTF-8?q?chore:=20=E5=88=A0=E9=99=A4=20qqbot=20?= =?UTF-8?q?=E5=8C=85=E4=B8=AD=E7=9A=84=E6=B5=81=E7=A8=8B=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/qqbot/bansess/__init__.py | 0 pkg/qqbot/bansess/bansess.py | 70 ------- pkg/qqbot/cntfilter/__init__.py | 0 pkg/qqbot/cntfilter/cntfilter.py | 93 --------- pkg/qqbot/cntfilter/entities.py | 64 ------- pkg/qqbot/cntfilter/filter.py | 34 ---- pkg/qqbot/cntfilter/filters/__init__.py | 0 pkg/qqbot/cntfilter/filters/baiduexamine.py | 61 ------ pkg/qqbot/cntfilter/filters/banwords.py | 44 ----- pkg/qqbot/cntfilter/filters/cntignore.py | 43 ----- pkg/qqbot/longtext/__init__.py | 0 pkg/qqbot/longtext/longtext.py | 56 ------ pkg/qqbot/longtext/strategies/__init__.py | 0 pkg/qqbot/longtext/strategies/forward.py | 62 ------ pkg/qqbot/longtext/strategies/image.py | 197 -------------------- pkg/qqbot/longtext/strategy.py | 22 --- pkg/qqbot/resprule/__init__.py | 0 pkg/qqbot/resprule/entities.py | 9 - pkg/qqbot/resprule/resprule.py | 58 ------ pkg/qqbot/resprule/rule.py | 31 --- pkg/qqbot/resprule/rules/__init__.py | 0 pkg/qqbot/resprule/rules/atbot.py | 28 --- pkg/qqbot/resprule/rules/prefix.py | 29 --- pkg/qqbot/resprule/rules/random.py | 22 --- pkg/qqbot/resprule/rules/regexp.py | 31 --- 25 files changed, 954 deletions(-) delete mode 100644 pkg/qqbot/bansess/__init__.py 
delete mode 100644 pkg/qqbot/bansess/bansess.py delete mode 100644 pkg/qqbot/cntfilter/__init__.py delete mode 100644 pkg/qqbot/cntfilter/cntfilter.py delete mode 100644 pkg/qqbot/cntfilter/entities.py delete mode 100644 pkg/qqbot/cntfilter/filter.py delete mode 100644 pkg/qqbot/cntfilter/filters/__init__.py delete mode 100644 pkg/qqbot/cntfilter/filters/baiduexamine.py delete mode 100644 pkg/qqbot/cntfilter/filters/banwords.py delete mode 100644 pkg/qqbot/cntfilter/filters/cntignore.py delete mode 100644 pkg/qqbot/longtext/__init__.py delete mode 100644 pkg/qqbot/longtext/longtext.py delete mode 100644 pkg/qqbot/longtext/strategies/__init__.py delete mode 100644 pkg/qqbot/longtext/strategies/forward.py delete mode 100644 pkg/qqbot/longtext/strategies/image.py delete mode 100644 pkg/qqbot/longtext/strategy.py delete mode 100644 pkg/qqbot/resprule/__init__.py delete mode 100644 pkg/qqbot/resprule/entities.py delete mode 100644 pkg/qqbot/resprule/resprule.py delete mode 100644 pkg/qqbot/resprule/rule.py delete mode 100644 pkg/qqbot/resprule/rules/__init__.py delete mode 100644 pkg/qqbot/resprule/rules/atbot.py delete mode 100644 pkg/qqbot/resprule/rules/prefix.py delete mode 100644 pkg/qqbot/resprule/rules/random.py delete mode 100644 pkg/qqbot/resprule/rules/regexp.py diff --git a/pkg/qqbot/bansess/__init__.py b/pkg/qqbot/bansess/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/qqbot/bansess/bansess.py b/pkg/qqbot/bansess/bansess.py deleted file mode 100644 index 74ffd3f7..00000000 --- a/pkg/qqbot/bansess/bansess.py +++ /dev/null @@ -1,70 +0,0 @@ -# 处理对会话的禁用配置 -# 过去的 banlist -from __future__ import annotations -import re - -from ...core import app -from ...config import manager as cfg_mgr - - -class SessionBanManager: - - ap: app.Application = None - - banlist_mgr: cfg_mgr.ConfigManager - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - self.banlist_mgr = await cfg_mgr.load_python_module_config( - "banlist.py", - "res/templates/banlist-template.py" - ) - - async def is_banned( - self, launcher_type: str, launcher_id: int, sender_id: int - ) -> bool: - if not self.banlist_mgr.data['enable']: - return False - - result = False - - if launcher_type == 'group': - if not self.banlist_mgr.data['enable_group']: # 未启用群聊响应 - result = True - # 检查是否显式声明发起人QQ要被person忽略 - elif sender_id in self.banlist_mgr.data['person']: - result = True - else: - for group_rule in self.banlist_mgr.data['group']: - if type(group_rule) == int: - if group_rule == launcher_id: - result = True - elif type(group_rule) == str: - if group_rule.startswith('!'): - reg_str = group_rule[1:] - if re.match(reg_str, str(launcher_id)): - result = False - break - else: - if re.match(group_rule, str(launcher_id)): - result = True - elif launcher_type == 'person': - if not self.banlist_mgr.data['enable_private']: - result = True - else: - for person_rule in self.banlist_mgr.data['person']: - if type(person_rule) == int: - if person_rule == launcher_id: - result = True - elif type(person_rule) == str: - if person_rule.startswith('!'): - reg_str = person_rule[1:] - if re.match(reg_str, str(launcher_id)): - result = False - break - else: - if re.match(person_rule, str(launcher_id)): - result = True - return result diff --git a/pkg/qqbot/cntfilter/__init__.py b/pkg/qqbot/cntfilter/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/qqbot/cntfilter/cntfilter.py b/pkg/qqbot/cntfilter/cntfilter.py deleted file mode 100644 index 4c7305c0..00000000 
--- a/pkg/qqbot/cntfilter/cntfilter.py +++ /dev/null @@ -1,93 +0,0 @@ -from __future__ import annotations - -from ...core import app -from . import entities -from . import filter -from .filters import cntignore, banwords, baiduexamine - - -class ContentFilterManager: - - ao: app.Application - - filter_chain: list[filter.ContentFilter] - - def __init__(self, ap: app.Application) -> None: - self.ap = ap - self.filter_chain = [] - - async def initialize(self): - self.filter_chain.append(cntignore.ContentIgnore(self.ap)) - - if self.ap.cfg_mgr.data['sensitive_word_filter']: - self.filter_chain.append(banwords.BanWordFilter(self.ap)) - - if self.ap.cfg_mgr.data['baidu_check']: - self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap)) - - for filter in self.filter_chain: - await filter.initialize() - - async def pre_process(self, message: str) -> entities.FilterManagerResult: - """请求llm前处理消息 - 只要有一个不通过就不放行,只放行 PASS 的消息 - """ - if not self.ap.cfg_mgr.data['income_msg_check']: # 不检查收到的消息,直接放行 - return entities.FilterManagerResult( - level=entities.ManagerResultLevel.CONTINUE, - replacement=message, - user_notice='', - console_notice='' - ) - else: - for filter in self.filter_chain: - if entities.EnableStage.PRE in filter.enable_stages: - result = await filter.process(message) - - if result.level in [ - entities.ResultLevel.BLOCK, - entities.ResultLevel.MASKED - ]: - return entities.FilterManagerResult( - level=entities.ManagerResultLevel.INTERRUPT, - replacement=result.replacement, - user_notice=result.user_notice, - console_notice=result.console_notice - ) - elif result.level == entities.ResultLevel.PASS: - message = result.replacement - - return entities.FilterManagerResult( - level=entities.ManagerResultLevel.CONTINUE, - replacement=message, - user_notice='', - console_notice='' - ) - - async def post_process(self, message: str) -> entities.FilterManagerResult: - """请求llm后处理响应 - 只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter - """ - for filter in self.filter_chain: - if entities.EnableStage.POST in filter.enable_stages: - result = await filter.process(message) - - if result.level == entities.ResultLevel.BLOCK: - return entities.FilterManagerResult( - level=entities.ManagerResultLevel.INTERRUPT, - replacement=result.replacement, - user_notice=result.user_notice, - console_notice=result.console_notice - ) - elif result.level in [ - entities.ResultLevel.PASS, - entities.ResultLevel.MASKED - ]: - message = result.replacement - - return entities.FilterManagerResult( - level=entities.ManagerResultLevel.CONTINUE, - replacement=message, - user_notice='', - console_notice='' - ) diff --git a/pkg/qqbot/cntfilter/entities.py b/pkg/qqbot/cntfilter/entities.py deleted file mode 100644 index 7ab05675..00000000 --- a/pkg/qqbot/cntfilter/entities.py +++ /dev/null @@ -1,64 +0,0 @@ - -import typing -import enum - -import pydantic - - -class ResultLevel(enum.Enum): - """结果等级""" - PASS = enum.auto() - """通过""" - - WARN = enum.auto() - """警告""" - - MASKED = enum.auto() - """已掩去""" - - BLOCK = enum.auto() - """阻止""" - - -class EnableStage(enum.Enum): - """启用阶段""" - PRE = enum.auto() - """预处理""" - - POST = enum.auto() - """后处理""" - - -class FilterResult(pydantic.BaseModel): - level: ResultLevel - - replacement: str - """替换后的消息""" - - user_notice: str - """不通过时,用户提示消息""" - - console_notice: str - """不通过时,控制台提示消息""" - - -class ManagerResultLevel(enum.Enum): - """处理器结果等级""" - CONTINUE = enum.auto() - """继续""" - - INTERRUPT = enum.auto() - """中断""" - -class 
FilterManagerResult(pydantic.BaseModel): - - level: ManagerResultLevel - - replacement: str - """替换后的消息""" - - user_notice: str - """用户提示消息""" - - console_notice: str - """控制台提示消息""" diff --git a/pkg/qqbot/cntfilter/filter.py b/pkg/qqbot/cntfilter/filter.py deleted file mode 100644 index 57792145..00000000 --- a/pkg/qqbot/cntfilter/filter.py +++ /dev/null @@ -1,34 +0,0 @@ -# 内容过滤器的抽象类 -from __future__ import annotations -import abc - -from ...core import app -from . import entities - - -class ContentFilter(metaclass=abc.ABCMeta): - - ap: app.Application - - def __init__(self, ap: app.Application): - self.ap = ap - - @property - def enable_stages(self): - """启用的阶段 - """ - return [ - entities.EnableStage.PRE, - entities.EnableStage.POST - ] - - async def initialize(self): - """初始化过滤器 - """ - pass - - @abc.abstractmethod - async def process(self, message: str) -> entities.FilterResult: - """处理消息 - """ - raise NotImplementedError diff --git a/pkg/qqbot/cntfilter/filters/__init__.py b/pkg/qqbot/cntfilter/filters/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/qqbot/cntfilter/filters/baiduexamine.py b/pkg/qqbot/cntfilter/filters/baiduexamine.py deleted file mode 100644 index a658897b..00000000 --- a/pkg/qqbot/cntfilter/filters/baiduexamine.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -import aiohttp - -from .. import entities -from .. import filter as filter_model - - -BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}" -BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" - - -class BaiduCloudExamine(filter_model.ContentFilter): - """百度云内容审核""" - - async def _get_token(self) -> str: - async with aiohttp.ClientSession() as session: - async with session.post( - BAIDU_EXAMINE_TOKEN_URL, - params={ - "grant_type": "client_credentials", - "client_id": self.ap.cfg_mgr.data['baidu_api_key'], - "client_secret": self.ap.cfg_mgr.data['baidu_secret_key'] - } - ) as resp: - return (await resp.json())['access_token'] - - async def process(self, message: str) -> entities.FilterResult: - - async with aiohttp.ClientSession() as session: - async with session.post( - BAIDU_EXAMINE_URL.format(await self._get_token()), - headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'}, - data=f"text={message}".encode('utf-8') - ) as resp: - result = await resp.json() - - if "error_code" in result: - return entities.FilterResult( - level=entities.ResultLevel.BLOCK, - replacement=message, - user_notice='', - console_notice=f"百度云判定出错,错误信息:{result['error_msg']}" - ) - else: - conclusion = result["conclusion"] - - if conclusion in ("合规"): - return entities.FilterResult( - level=entities.ResultLevel.PASS, - replacement=message, - user_notice='', - console_notice=f"百度云判定结果:{conclusion}" - ) - else: - return entities.FilterResult( - level=entities.ResultLevel.BLOCK, - replacement=message, - user_notice=self.ap.cfg_mgr.data['inappropriate_message_tips'], - console_notice=f"百度云判定结果:{conclusion}" - ) \ No newline at end of file diff --git a/pkg/qqbot/cntfilter/filters/banwords.py b/pkg/qqbot/cntfilter/filters/banwords.py deleted file mode 100644 index 9451c7b8..00000000 --- a/pkg/qqbot/cntfilter/filters/banwords.py +++ /dev/null @@ -1,44 +0,0 @@ -from __future__ import annotations -import re - -from .. import filter as filter_model -from .. 
import entities -from ....config import manager as cfg_mgr - - -class BanWordFilter(filter_model.ContentFilter): - """根据内容禁言""" - - sensitive: cfg_mgr.ConfigManager - - async def initialize(self): - self.sensitive = await cfg_mgr.load_json_config( - "sensitive.json", - "res/templates/sensitive-template.json" - ) - - async def process(self, message: str) -> entities.FilterResult: - found = False - - for word in self.sensitive.data['words']: - match = re.findall(word, message) - - if len(match) > 0: - found = True - - for i in range(len(match)): - if self.sensitive.data['mask_word'] == "": - message = message.replace( - match[i], self.sensitive.data['mask'] * len(match[i]) - ) - else: - message = message.replace( - match[i], self.sensitive.data['mask_word'] - ) - - return entities.FilterResult( - level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS, - replacement=message, - user_notice='[bot] 消息中存在不合适的内容, 请更换措辞' if found else '', - console_notice='' - ) \ No newline at end of file diff --git a/pkg/qqbot/cntfilter/filters/cntignore.py b/pkg/qqbot/cntfilter/filters/cntignore.py deleted file mode 100644 index 81408868..00000000 --- a/pkg/qqbot/cntfilter/filters/cntignore.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations -import re - -from .. import entities -from .. import filter as filter_model - - -class ContentIgnore(filter_model.ContentFilter): - """根据内容忽略消息""" - - @property - def enable_stages(self): - return [ - entities.EnableStage.PRE, - ] - - async def process(self, message: str) -> entities.FilterResult: - if 'prefix' in self.ap.cfg_mgr.data['ignore_rules']: - for rule in self.ap.cfg_mgr.data['ignore_rules']['prefix']: - if message.startswith(rule): - return entities.FilterResult( - level=entities.ResultLevel.BLOCK, - replacement='', - user_notice='', - console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息' - ) - - if 'regexp' in self.ap.cfg_mgr.data['ignore_rules']: - for rule in self.ap.cfg_mgr.data['ignore_rules']['regexp']: - if re.search(rule, message): - return entities.FilterResult( - level=entities.ResultLevel.BLOCK, - replacement='', - user_notice='', - console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息' - ) - - return entities.FilterResult( - level=entities.ResultLevel.PASS, - replacement=message, - user_notice='', - console_notice='' - ) \ No newline at end of file diff --git a/pkg/qqbot/longtext/__init__.py b/pkg/qqbot/longtext/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/qqbot/longtext/longtext.py b/pkg/qqbot/longtext/longtext.py deleted file mode 100644 index 697f65e4..00000000 --- a/pkg/qqbot/longtext/longtext.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import annotations -import os -import traceback - -from PIL import Image, ImageDraw, ImageFont -from mirai.models.message import MessageComponent, Plain - -from ...core import app -from . 
import strategy -from .strategies import image, forward - - -class LongTextProcessor: - - ap: app.Application - - strategy_impl: strategy.LongTextStrategy - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - config = self.ap.cfg_mgr.data - if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image': - use_font = config['font_path'] - try: - # 检查是否存在 - if not os.path.exists(use_font): - # 若是windows系统,使用微软雅黑 - if os.name == "nt": - use_font = "C:/Windows/Fonts/msyh.ttc" - if not os.path.exists(use_font): - self.ap.logger.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") - config['blob_message_strategy'] = "forward" - else: - self.ap.logger.info("使用Windows自带字体:" + use_font) - self.ap.cfg_mgr.data['font_path'] = use_font - else: - self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") - self.ap.cfg_mgr.data['blob_message_strategy'] = "forward" - except: - traceback.print_exc() - self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) - self.ap.cfg_mgr.data['blob_message_strategy'] = "forward" - - if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image': - self.strategy_impl = image.Text2ImageStrategy(self.ap) - elif self.ap.cfg_mgr.data['blob_message_strategy'] == 'forward': - self.strategy_impl = forward.ForwardComponentStrategy(self.ap) - await self.strategy_impl.initialize() - - async def check_and_process(self, message: str) -> list[MessageComponent]: - if len(message) > self.ap.cfg_mgr.data['blob_message_threshold']: - return await self.strategy_impl.process(message) - else: - return [Plain(message)] \ No newline at end of file diff --git a/pkg/qqbot/longtext/strategies/__init__.py b/pkg/qqbot/longtext/strategies/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/qqbot/longtext/strategies/forward.py b/pkg/qqbot/longtext/strategies/forward.py deleted file mode 100644 index d1b5c36c..00000000 --- a/pkg/qqbot/longtext/strategies/forward.py +++ /dev/null @@ -1,62 +0,0 @@ -# 转发消息组件 -from __future__ import annotations -import typing - -from mirai.models import MessageChain -from mirai.models.message import MessageComponent, ForwardMessageNode -from mirai.models.base import MiraiBaseModel - -from .. 
import strategy as strategy_model - - -class ForwardMessageDiaplay(MiraiBaseModel): - title: str = "群聊的聊天记录" - brief: str = "[聊天记录]" - source: str = "聊天记录" - preview: typing.List[str] = [] - summary: str = "查看x条转发消息" - - -class Forward(MessageComponent): - """合并转发。""" - type: str = "Forward" - """消息组件类型。""" - display: ForwardMessageDiaplay - """显示信息""" - node_list: typing.List[ForwardMessageNode] - """转发消息节点列表。""" - def __init__(self, *args, **kwargs): - if len(args) == 1: - self.node_list = args[0] - super().__init__(**kwargs) - super().__init__(*args, **kwargs) - - def __str__(self): - return '[聊天记录]' - - -class ForwardComponentStrategy(strategy_model.LongTextStrategy): - - async def process(self, message: str) -> list[MessageComponent]: - display = ForwardMessageDiaplay( - title="群聊的聊天记录", - brief="[聊天记录]", - source="聊天记录", - preview=["QQ用户: "+message], - summary="查看1条转发消息" - ) - - node_list = [ - ForwardMessageNode( - sender_id=self.ap.im_mgr.bot_account_id, - sender_name='QQ用户', - message_chain=MessageChain([message]) - ) - ] - - forward = Forward( - display=display, - node_list=node_list - ) - - return [forward] diff --git a/pkg/qqbot/longtext/strategies/image.py b/pkg/qqbot/longtext/strategies/image.py deleted file mode 100644 index 4f789098..00000000 --- a/pkg/qqbot/longtext/strategies/image.py +++ /dev/null @@ -1,197 +0,0 @@ -from __future__ import annotations - -import typing -import os -import base64 -import time -import re - -from PIL import Image, ImageDraw, ImageFont - -from mirai.models import MessageChain, Image as ImageComponent -from mirai.models.message import MessageComponent - -from .. import strategy as strategy_model - - -class Text2ImageStrategy(strategy_model.LongTextStrategy): - - text_render_font: ImageFont.FreeTypeFont - - async def initialize(self): - self.text_render_font = ImageFont.truetype(self.ap.cfg_mgr.data['font_path'], 32, encoding="utf-8") - - async def process(self, message: str) -> list[MessageComponent]: - img_path = self.text_to_image( - text_str=message, - save_as='temp/{}.png'.format(int(time.time())) - ) - - compressed_path, size = self.compress_image( - img_path, - outfile="temp/{}_compressed.png".format(int(time.time())) - ) - - with open(compressed_path, 'rb') as f: - img = f.read() - - b64 = base64.b64encode(img) - - # 删除图片 - os.remove(img_path) - - if os.path.exists(compressed_path): - os.remove(compressed_path) - - return [ - ImageComponent( - base64=b64.decode('utf-8'), - ) - ] - - def indexNumber(self, path=''): - """ - 查找字符串中数字所在串中的位置 - :param path:目标字符串 - :return:: : [['1', 16], ['2', 35], ['1', 51]] - """ - kv = [] - nums = [] - beforeDatas = re.findall('[\d]+', path) - for num in beforeDatas: - indexV = [] - times = path.count(num) - if times > 1: - if num not in nums: - indexs = re.finditer(num, path) - for index in indexs: - iV = [] - i = index.span()[0] - iV.append(num) - iV.append(i) - kv.append(iV) - nums.append(num) - else: - index = path.find(num) - indexV.append(num) - indexV.append(index) - kv.append(indexV) - # 根据数字位置排序 - indexSort = [] - resultIndex = [] - for vi in kv: - indexSort.append(vi[1]) - indexSort.sort() - for i in indexSort: - for v in kv: - if i == v[1]: - resultIndex.append(v) - return resultIndex - - - def get_size(self, file): - # 获取文件大小:KB - size = os.path.getsize(file) - return size / 1024 - - - def get_outfile(self, infile, outfile): - if outfile: - return outfile - dir, suffix = os.path.splitext(infile) - outfile = '{}-out{}'.format(dir, suffix) - return outfile - - - def compress_image(self, infile, 
outfile='', kb=100, step=20, quality=90): - """不改变图片尺寸压缩到指定大小 - :param infile: 压缩源文件 - :param outfile: 压缩文件保存地址 - :param mb: 压缩目标,KB - :param step: 每次调整的压缩比率 - :param quality: 初始压缩比率 - :return: 压缩文件地址,压缩文件大小 - """ - o_size = self.get_size(infile) - if o_size <= kb: - return infile, o_size - outfile = self.get_outfile(infile, outfile) - while o_size > kb: - im = Image.open(infile) - im.save(outfile, quality=quality) - if quality - step < 0: - break - quality -= step - o_size = self.get_size(outfile) - return outfile, self.get_size(outfile) - - - def text_to_image(self, text_str: str, save_as="temp.png", width=800): - - text_str = text_str.replace("\t", " ") - - # 分行 - lines = text_str.split('\n') - - # 计算并分割 - final_lines = [] - - text_width = width-80 - - self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width)) - for line in lines: - # 如果长了就分割 - line_width = self.text_render_font.getlength(line) - self.ap.logger.debug("line_width: {}".format(line_width)) - if line_width < text_width: - final_lines.append(line) - continue - else: - rest_text = line - while True: - # 分割最前面的一行 - point = int(len(rest_text) * (text_width / line_width)) - - # 检查断点是否在数字中间 - numbers = self.indexNumber(rest_text) - - for number in numbers: - if number[1] < point < number[1] + len(number[0]) and number[1] != 0: - point = number[1] - break - - final_lines.append(rest_text[:point]) - rest_text = rest_text[point:] - line_width = self.text_render_font.getlength(rest_text) - if line_width < text_width: - final_lines.append(rest_text) - break - else: - continue - # 准备画布 - img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)) - draw = ImageDraw.Draw(img, mode='RGBA') - - self.ap.logger.debug("正在绘制图片...") - # 绘制正文 - line_number = 0 - offset_x = 20 - offset_y = 30 - for final_line in final_lines: - draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=self.text_render_font) - # 遍历此行,检查是否有emoji - idx_in_line = 0 - for ch in final_line: - # 检查字符占位宽 - char_code = ord(ch) - if char_code >= 127: - idx_in_line += 1 - else: - idx_in_line += 0.5 - - line_number += 1 - - self.ap.logger.debug("正在保存图片...") - img.save(save_as) - - return save_as diff --git a/pkg/qqbot/longtext/strategy.py b/pkg/qqbot/longtext/strategy.py deleted file mode 100644 index 5c6bfb9c..00000000 --- a/pkg/qqbot/longtext/strategy.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations -import abc -import typing - -import mirai -from mirai.models.message import MessageComponent - -from ...core import app - - -class LongTextStrategy(metaclass=abc.ABCMeta): - ap: app.Application - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - pass - - @abc.abstractmethod - async def process(self, message: str) -> list[MessageComponent]: - return [] diff --git a/pkg/qqbot/resprule/__init__.py b/pkg/qqbot/resprule/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/qqbot/resprule/entities.py b/pkg/qqbot/resprule/entities.py deleted file mode 100644 index ffee3081..00000000 --- a/pkg/qqbot/resprule/entities.py +++ /dev/null @@ -1,9 +0,0 @@ -import pydantic -import mirai - - -class RuleJudgeResult(pydantic.BaseModel): - - matching: bool = False - - replacement: mirai.MessageChain = None diff --git a/pkg/qqbot/resprule/resprule.py b/pkg/qqbot/resprule/resprule.py deleted file mode 100644 index 9ea8321d..00000000 --- a/pkg/qqbot/resprule/resprule.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - 
-import mirai - -from ...core import app -from . import entities, rule -from .rules import atbot, prefix, regexp, random - - -class GroupRespondRuleChecker: - """群组响应规则检查器 - """ - - ap: app.Application - - rule_matchers: list[rule.GroupRespondRule] - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - """初始化检查器 - """ - self.rule_matchers = [ - atbot.AtBotRule(self.ap), - prefix.PrefixRule(self.ap), - regexp.RegExpRule(self.ap), - random.RandomRespRule(self.ap), - ] - - for rule_matcher in self.rule_matchers: - await rule_matcher.initialize() - - async def check( - self, - message_text: str, - message_chain: mirai.MessageChain, - launcher_id: int, - sender_id: int, - ) -> entities.RuleJudgeResult: - """检查消息是否匹配规则 - """ - rules = self.ap.cfg_mgr.data['response_rules'] - - use_rule = rules['default'] - - if str(launcher_id) in use_rule: - use_rule = use_rule[str(launcher_id)] - - for rule_matcher in self.rule_matchers: - res = await rule_matcher.match(message_text, message_chain, use_rule) - if res.matching: - return res - - return entities.RuleJudgeResult( - matching=False, - replacement=message_chain - ) diff --git a/pkg/qqbot/resprule/rule.py b/pkg/qqbot/resprule/rule.py deleted file mode 100644 index e530d063..00000000 --- a/pkg/qqbot/resprule/rule.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations -import abc - -import mirai - -from ...core import app -from . import entities - - -class GroupRespondRule(metaclass=abc.ABCMeta): - """群组响应规则的抽象类 - """ - - ap: app.Application - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - pass - - @abc.abstractmethod - async def match( - self, - message_text: str, - message_chain: mirai.MessageChain, - rule_dict: dict - ) -> entities.RuleJudgeResult: - """判断消息是否匹配规则 - """ - raise NotImplementedError diff --git a/pkg/qqbot/resprule/rules/__init__.py b/pkg/qqbot/resprule/rules/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/qqbot/resprule/rules/atbot.py b/pkg/qqbot/resprule/rules/atbot.py deleted file mode 100644 index eefc4891..00000000 --- a/pkg/qqbot/resprule/rules/atbot.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -import mirai - -from .. import rule as rule_model -from .. import entities - - -class AtBotRule(rule_model.GroupRespondRule): - - async def match( - self, - message_text: str, - message_chain: mirai.MessageChain, - rule_dict: dict - ) -> entities.RuleJudgeResult: - - if message_chain.has(mirai.At(self.ap.im_mgr.bot_account_id)) and rule_dict['at']: - message_chain.remove(mirai.At(self.ap.im_mgr.bot_account_id)) - return entities.RuleJudgeResult( - matching=True, - replacement=message_chain, - ) - - return entities.RuleJudgeResult( - matching=False, - replacement = message_chain - ) diff --git a/pkg/qqbot/resprule/rules/prefix.py b/pkg/qqbot/resprule/rules/prefix.py deleted file mode 100644 index 31ead5ab..00000000 --- a/pkg/qqbot/resprule/rules/prefix.py +++ /dev/null @@ -1,29 +0,0 @@ -import mirai - -from .. import rule as rule_model -from .. 
import entities - - -class PrefixRule(rule_model.GroupRespondRule): - - async def match( - self, - message_text: str, - message_chain: mirai.MessageChain, - rule_dict: dict - ) -> entities.RuleJudgeResult: - prefixes = rule_dict['prefix'] - - for prefix in prefixes: - if message_text.startswith(prefix): - return entities.RuleJudgeResult( - matching=True, - replacement=mirai.MessageChain([ - mirai.Plain(message_text[len(prefix):]) - ]), - ) - - return entities.RuleJudgeResult( - matching=False, - replacement=message_chain - ) diff --git a/pkg/qqbot/resprule/rules/random.py b/pkg/qqbot/resprule/rules/random.py deleted file mode 100644 index 1e8354b5..00000000 --- a/pkg/qqbot/resprule/rules/random.py +++ /dev/null @@ -1,22 +0,0 @@ -import random - -import mirai - -from .. import rule as rule_model -from .. import entities - - -class RandomRespRule(rule_model.GroupRespondRule): - - async def match( - self, - message_text: str, - message_chain: mirai.MessageChain, - rule_dict: dict - ) -> entities.RuleJudgeResult: - random_rate = rule_dict['random_rate'] - - return entities.RuleJudgeResult( - matching=random.random() < random_rate, - replacement=message_chain - ) \ No newline at end of file diff --git a/pkg/qqbot/resprule/rules/regexp.py b/pkg/qqbot/resprule/rules/regexp.py deleted file mode 100644 index 0d621fe4..00000000 --- a/pkg/qqbot/resprule/rules/regexp.py +++ /dev/null @@ -1,31 +0,0 @@ -import re - -import mirai - -from .. import rule as rule_model -from .. import entities - - -class RegExpRule(rule_model.GroupRespondRule): - - async def match( - self, - message_text: str, - message_chain: mirai.MessageChain, - rule_dict: dict - ) -> entities.RuleJudgeResult: - regexps = rule_dict['regexp'] - - for regexp in regexps: - match = re.match(regexp, message_text) - - if match: - return entities.RuleJudgeResult( - matching=True, - replacement=message_chain, - ) - - return entities.RuleJudgeResult( - matching=False, - replacement=message_chain - ) From 411034902aadb916f651263fbb211cfe98bcb383 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sat, 27 Jan 2024 00:05:55 +0800 Subject: [PATCH 03/10] =?UTF-8?q?feat:=20=E5=90=AF=E5=8A=A8=E6=97=B6?= =?UTF-8?q?=E5=B1=95=E7=A4=BAasciiart?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- start.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/start.py b/start.py index b56ea9e9..ead77490 100644 --- a/start.py +++ b/start.py @@ -1,7 +1,15 @@ import asyncio -from pkg.core import boot +asciiart = r""" + ___ ___ _ _ ___ ___ _____ + / _ \ / __| |_ __ _| |_ / __| _ \_ _| +| (_) | (__| ' \/ _` | _| (_ | _/ | | + \__\_\\___|_||_\__,_|\__|\___|_| |_| +""" if __name__ == '__main__': + print(asciiart) + + from pkg.core import boot asyncio.run(boot.main()) From 850a4eeb7c166c3d0d363c5fc3b516ca6cc4e853 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sat, 27 Jan 2024 00:06:38 +0800 Subject: [PATCH 04/10] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84openai?= =?UTF-8?q?=E5=8C=85=E5=9F=BA=E7=A1=80=E7=BB=84=E4=BB=B6=E6=9E=B6=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/config/manager.py | 2 + pkg/core/app.py | 9 ++ pkg/core/boot.py | 18 ++- pkg/core/bootutils/config.py | 2 + pkg/core/controller.py | 152 ++++++++++++++++------ pkg/openai/entities.py | 31 +++++ pkg/openai/requester/__init__.py | 0 pkg/openai/requester/api.py | 31 +++++ pkg/openai/requester/apis/__init__.py | 0 
pkg/openai/requester/apis/chatcmpl.py | 32 +++++ pkg/openai/requester/entities.py | 23 ++++ pkg/openai/requester/modelmgr.py | 40 ++++++ pkg/openai/requester/token.py | 25 ++++ pkg/openai/session/__init__.py | 0 pkg/openai/session/entities.py | 50 +++++++ pkg/openai/session/sessionmgr.py | 50 +++++++ pkg/openai/sysprompt/__init__.py | 0 pkg/openai/sysprompt/entities.py | 14 ++ pkg/openai/sysprompt/loader.py | 32 +++++ pkg/openai/sysprompt/loaders/__init__.py | 0 pkg/openai/sysprompt/loaders/scenario.py | 43 ++++++ pkg/openai/sysprompt/loaders/single.py | 42 ++++++ pkg/openai/sysprompt/sysprompt.py | 43 ++++++ pkg/pipeline/process/__init__.py | 0 pkg/pipeline/process/handler.py | 25 ++++ pkg/pipeline/process/handlers/__init__.py | 0 pkg/pipeline/process/handlers/chat.py | 38 ++++++ pkg/pipeline/process/handlers/command.py | 35 +++++ pkg/pipeline/process/process.py | 38 ++++++ pkg/pipeline/respback/__init__.py | 0 pkg/pipeline/respback/respback.py | 29 +++++ pkg/pipeline/stage.py | 6 +- pkg/pipeline/stagemgr.py | 16 +++ pkg/qqbot/manager.py | 16 --- pkg/qqbot/process.py | 2 +- 35 files changed, 782 insertions(+), 62 deletions(-) create mode 100644 pkg/openai/entities.py create mode 100644 pkg/openai/requester/__init__.py create mode 100644 pkg/openai/requester/api.py create mode 100644 pkg/openai/requester/apis/__init__.py create mode 100644 pkg/openai/requester/apis/chatcmpl.py create mode 100644 pkg/openai/requester/entities.py create mode 100644 pkg/openai/requester/modelmgr.py create mode 100644 pkg/openai/requester/token.py create mode 100644 pkg/openai/session/__init__.py create mode 100644 pkg/openai/session/entities.py create mode 100644 pkg/openai/session/sessionmgr.py create mode 100644 pkg/openai/sysprompt/__init__.py create mode 100644 pkg/openai/sysprompt/entities.py create mode 100644 pkg/openai/sysprompt/loader.py create mode 100644 pkg/openai/sysprompt/loaders/__init__.py create mode 100644 pkg/openai/sysprompt/loaders/scenario.py create mode 100644 pkg/openai/sysprompt/loaders/single.py create mode 100644 pkg/openai/sysprompt/sysprompt.py create mode 100644 pkg/pipeline/process/__init__.py create mode 100644 pkg/pipeline/process/handler.py create mode 100644 pkg/pipeline/process/handlers/__init__.py create mode 100644 pkg/pipeline/process/handlers/chat.py create mode 100644 pkg/pipeline/process/handlers/command.py create mode 100644 pkg/pipeline/process/process.py create mode 100644 pkg/pipeline/respback/__init__.py create mode 100644 pkg/pipeline/respback/respback.py diff --git a/pkg/config/manager.py b/pkg/config/manager.py index e343b0c2..7e52d7b0 100644 --- a/pkg/config/manager.py +++ b/pkg/config/manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from . 
import model as file_model from ..utils import context from .impls import pymodule, json as json_file diff --git a/pkg/core/app.py b/pkg/core/app.py index 8c0a0c58..c0ce12fa 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -5,6 +5,9 @@ from ..qqbot import manager as qqbot_mgr from ..openai import manager as openai_mgr +from ..openai.session import sessionmgr as llm_session_mgr +from ..openai.requester import modelmgr as llm_model_mgr +from ..openai.sysprompt import sysprompt as llm_prompt_mgr from ..config import manager as config_mgr from ..database import manager as database_mgr from ..utils.center import v2 as center_mgr @@ -18,6 +21,12 @@ class Application: llm_mgr: openai_mgr.OpenAIInteract = None + sess_mgr: llm_session_mgr.SessionManager = None + + model_mgr: llm_model_mgr.ModelManager = None + + prompt_mgr: llm_prompt_mgr.PromptManager = None + cfg_mgr: config_mgr.ConfigManager = None tips_mgr: config_mgr.ConfigManager = None diff --git a/pkg/core/boot.py b/pkg/core/boot.py index 10fc51b3..a74615ec 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -15,7 +15,9 @@ from ..audit import identifier from ..database import manager as db_mgr from ..openai import manager as llm_mgr -from ..openai import session as llm_session +from ..openai.session import sessionmgr as llm_session_mgr +from ..openai.requester import modelmgr as llm_model_mgr +from ..openai.sysprompt import sysprompt as llm_prompt_mgr from ..openai import dprompt as llm_dprompt from ..qqbot import manager as im_mgr from ..qqbot.cmds import aamgr as im_cmd_aamgr @@ -112,8 +114,18 @@ async def make_app() -> app.Application: llm_mgr_inst = llm_mgr.OpenAIInteract(ap) ap.llm_mgr = llm_mgr_inst - # TODO make it async - llm_session.load_sessions() + + llm_model_mgr_inst = llm_model_mgr.ModelManager(ap) + await llm_model_mgr_inst.initialize() + ap.model_mgr = llm_model_mgr_inst + + llm_session_mgr_inst = llm_session_mgr.SessionManager(ap) + await llm_session_mgr_inst.initialize() + ap.sess_mgr = llm_session_mgr_inst + + llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap) + await llm_prompt_mgr_inst.initialize() + ap.prompt_mgr = llm_prompt_mgr_inst im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap) await im_mgr_inst.initialize() diff --git a/pkg/core/bootutils/config.py b/pkg/core/bootutils/config.py index f1471ae5..0addff08 100644 --- a/pkg/core/bootutils/config.py +++ b/pkg/core/bootutils/config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from ...config import manager as config_mgr diff --git a/pkg/core/controller.py b/pkg/core/controller.py index 2470cbbd..ada46f73 100644 --- a/pkg/core/controller.py +++ b/pkg/core/controller.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import typing import traceback from . 
import app, entities @@ -24,25 +25,115 @@ def __init__(self, ap: app.Application): async def consumer(self): """事件处理循环 """ - while True: - selected_query: entities.Query = None - - # 取请求 - async with self.ap.query_pool: - queries: list[entities.Query] = self.ap.query_pool.queries - - if queries: - selected_query = queries.pop(0) # FCFS - else: - await self.ap.query_pool.condition.wait() - continue - - if selected_query: - async def _process_query(selected_query): - async with self.semaphore: - await self.process_query(selected_query) - - asyncio.create_task(_process_query(selected_query)) + try: + while True: + selected_query: entities.Query = None + + # 取请求 + async with self.ap.query_pool: + queries: list[entities.Query] = self.ap.query_pool.queries + + for query in queries: + session = await self.ap.sess_mgr.get_session(query) + self.ap.logger.debug(f"Checking query {query} session {session}") + + if not session.semaphore.locked(): + selected_query = query + await session.semaphore.acquire() + + break + + if selected_query: # 找到了 + queries.remove(selected_query) + else: # 没找到 说明:没有请求 或者 所有query对应的session都已达到并发上限 + await self.ap.query_pool.condition.wait() + continue + + if selected_query: + async def _process_query(selected_query): + async with self.semaphore: # 总并发上限 + await self.process_query(selected_query) + + async with self.ap.query_pool: + (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() + # 通知其他协程,有新的请求可以处理了 + self.ap.query_pool.condition.notify_all() + + asyncio.create_task(_process_query(selected_query)) + except Exception as e: + self.ap.logger.error(f"事件处理循环出错: {e}") + traceback.print_exc() + + async def _check_output(self, result: pipeline_entities.StageProcessResult): + """检查输出 + """ + if result.user_notice: + await self.ap.im_mgr.send( + result.user_notice + ) + if result.debug_notice: + self.ap.logger.debug(result.debug_notice) + if result.console_notice: + self.ap.logger.info(result.console_notice) + + async def _execute_from_stage( + self, + stage_index: int, + query: entities.Query, + ): + """从指定阶段开始执行 + + 如何看懂这里为什么这么写? + 去问 GPT-4: + Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None], + 如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result, + 调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器 + Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage: + + A B C D E F G + + 如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是: + + A B C D E F G + + 现在假设C返回的是AsyncGenerator,那么执行顺序是: + + A B C D E F G C D E F G C D E F G ... + Q3: 但是如果不止一个stage会返回生成器呢? 
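+
+        (The code below answers Q3: every result yielded by a generator stage triggers a
+        recursive _execute_from_stage(i + 1, query) over the remaining stages, so a second
+        generator stage further down the chain simply adds another level of the same
+        recursion.)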
+ """ + i = stage_index + + while i < len(self.ap.stage_mgr.stage_containers): + stage_container = self.ap.stage_mgr.stage_containers[i] + + result = await stage_container.inst.process(query, stage_container.inst_name) + + + if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果 + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}") + await self._check_output(result) + + if result.result_type == pipeline_entities.ResultType.INTERRUPT: + self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + break + elif result.result_type == pipeline_entities.ResultType.CONTINUE: + query = result.new_query + elif isinstance(result, typing.AsyncGenerator): # 生成器 + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen") + + async for sub_result in result: + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}") + await self._check_output(sub_result) + + if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT: + self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + break + elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE: + query = sub_result.new_query + await self._execute_from_stage(i + 1, query) + break + + i += 1 async def process_query(self, query: entities.Query): """处理请求 @@ -50,28 +141,7 @@ async def process_query(self, query: entities.Query): self.ap.logger.debug(f"Processing query {query}") try: - for stage_container in self.ap.stage_mgr.stage_containers: - res = await stage_container.inst.process(query, stage_container.inst_name) - - self.ap.logger.debug(f"Stage {stage_container.inst_name} res {res}") - - if res.user_notice: - await self.ap.im_mgr.send( - query.message_event, - res.user_notice - ) - if res.debug_notice: - self.ap.logger.debug(res.debug_notice) - if res.console_notice: - self.ap.logger.info(res.console_notice) - - if res.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") - break - elif res.result_type == pipeline_entities.ResultType.CONTINUE: - query = res.new_query - continue - + await self._execute_from_stage(0, query) except Exception as e: self.ap.logger.error(f"处理请求时出错 {query}: {e}") traceback.print_exc() diff --git a/pkg/openai/entities.py b/pkg/openai/entities.py new file mode 100644 index 00000000..58f48d95 --- /dev/null +++ b/pkg/openai/entities.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import typing +import enum +import pydantic + + +class MessageRole(enum.Enum): + + SYSTEM = 'system' + + USER = 'user' + + ASSISTANT = 'assistant' + + FUNCTION = 'function' + + +class FunctionCall(pydantic.BaseModel): + name: str + + args: dict[str, typing.Any] + + +class Message(pydantic.BaseModel): + + role: MessageRole + + content: typing.Optional[str] = None + + function_call: typing.Optional[FunctionCall] = None diff --git a/pkg/openai/requester/__init__.py b/pkg/openai/requester/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/openai/requester/api.py b/pkg/openai/requester/api.py new file mode 100644 index 00000000..5dd0abf2 --- /dev/null +++ b/pkg/openai/requester/api.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import abc +import typing + +from ...core import app +from ...core import entities as core_entities +from .. 
import entities as llm_entities +from ..session import entities as session_entities + +class LLMAPIRequester(metaclass=abc.ABCMeta): + """LLM API请求器 + """ + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def request( + self, + query: core_entities.Query, + conversation: session_entities.Conversation, + ) -> typing.AsyncGenerator[llm_entities.Message, None]: + """请求 + """ + raise NotImplementedError diff --git a/pkg/openai/requester/apis/__init__.py b/pkg/openai/requester/apis/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/openai/requester/apis/chatcmpl.py b/pkg/openai/requester/apis/chatcmpl.py new file mode 100644 index 00000000..5b6d2297 --- /dev/null +++ b/pkg/openai/requester/apis/chatcmpl.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import asyncio +import typing + +import openai + +from .. import api +from ....core import entities as core_entities +from ... import entities as llm_entities +from ...session import entities as session_entities + + +class OpenAIChatCompletion(api.LLMAPIRequester): + + client: openai.Client + + async def initialize(self): + self.client = openai.Client( + base_url=self.ap.cfg_mgr.data['openai_config']['reverse_proxy'], + timeout=self.ap.cfg_mgr.data['process_message_timeout'] + ) + + async def request(self, query: core_entities.Query, conversation: session_entities.Conversation) -> typing.AsyncGenerator[llm_entities.Message, None]: + """请求 + """ + await asyncio.sleep(10) + + yield llm_entities.Message( + role=llm_entities.MessageRole.ASSISTANT, + content="hello" + ) diff --git a/pkg/openai/requester/entities.py b/pkg/openai/requester/entities.py new file mode 100644 index 00000000..adc86677 --- /dev/null +++ b/pkg/openai/requester/entities.py @@ -0,0 +1,23 @@ +import typing + +import pydantic + +from . import api +from . import token + + +class LLMModelInfo(pydantic.BaseModel): + """模型""" + + name: str + + provider: str + + token_mgr: token.TokenManager + + requester: api.LLMAPIRequester + + function_call_supported: typing.Optional[bool] = False + + class Config: + arbitrary_types_allowed = True diff --git a/pkg/openai/requester/modelmgr.py b/pkg/openai/requester/modelmgr.py new file mode 100644 index 00000000..cc606b03 --- /dev/null +++ b/pkg/openai/requester/modelmgr.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from . import entities +from ...core import app + +from .apis import chatcmpl +from . 
import token + + +class ModelManager: + + ap: app.Application + + model_list: list[entities.LLMModelInfo] + + def __init__(self, ap: app.Application): + self.ap = ap + self.model_list = [] + + async def initialize(self): + openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap) + openai_token_mgr = token.TokenManager(self.ap, self.ap.cfg_mgr.data['openai_config']['api_key'].values()) + + self.model_list.append( + entities.LLMModelInfo( + name="gpt-3.5-turbo", + provider="openai", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + function_call_supported=True + ) + ) + + async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: + """通过名称获取模型 + """ + for model in self.model_list: + if model.name == name: + return model + raise ValueError(f"Model {name} not found") \ No newline at end of file diff --git a/pkg/openai/requester/token.py b/pkg/openai/requester/token.py new file mode 100644 index 00000000..9277c1a6 --- /dev/null +++ b/pkg/openai/requester/token.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import typing + +import pydantic + + +class TokenManager(): + + provider: str + + tokens: list[str] + + using_token_index: typing.Optional[int] = 0 + + def __init__(self, provider: str, tokens: list[str]): + self.provider = provider + self.tokens = tokens + self.using_token_index = 0 + + def get_token(self) -> str: + return self.tokens[self.using_token_index] + + def next_token(self): + self.using_token_index = (self.using_token_index + 1) % len(self.tokens) diff --git a/pkg/openai/session/__init__.py b/pkg/openai/session/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/openai/session/entities.py b/pkg/openai/session/entities.py new file mode 100644 index 00000000..49ddb845 --- /dev/null +++ b/pkg/openai/session/entities.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import datetime +import asyncio +import typing + +import pydantic + +from ..sysprompt import entities as sysprompt_entities +from .. import entities as llm_entities +from ..requester import entities +from ...core import entities as core_entities + + +class Conversation(pydantic.BaseModel): + """对话""" + + prompt: sysprompt_entities.Prompt + + messages: list[llm_entities.Message] + + create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + use_model: entities.LLMModelInfo + + +class Session(pydantic.BaseModel): + """会话""" + launcher_type: core_entities.LauncherTypes + + launcher_id: int + + sender_id: typing.Optional[int] = 0 + + use_prompt_name: typing.Optional[str] = 'default' + + using_conversation: typing.Optional[Conversation] = None + + conversations: typing.Optional[list[Conversation]] = [] + + create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + semaphore: typing.Optional[asyncio.Semaphore] = None + + class Config: + arbitrary_types_allowed = True diff --git a/pkg/openai/session/sessionmgr.py b/pkg/openai/session/sessionmgr.py new file mode 100644 index 00000000..8aff6e02 --- /dev/null +++ b/pkg/openai/session/sessionmgr.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import asyncio + +from ...core import app, entities as core_entities +from . 
import entities + + +class SessionManager: + + ap: app.Application + + session_list: list[entities.Session] + + def __init__(self, ap: app.Application): + self.ap = ap + self.session_list = [] + + async def initialize(self): + pass + + async def get_session(self, query: core_entities.Query) -> entities.Session: + """获取会话 + """ + for session in self.session_list: + if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id: + return session + + session = entities.Session( + launcher_type=query.launcher_type, + launcher_id=query.launcher_id, + semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000) + ) + self.session_list.append(session) + return session + + async def get_conversation(self, session: entities.Session) -> entities.Conversation: + if not session.conversations: + session.conversations = [] + + if session.using_conversation is None: + conversation = entities.Conversation( + prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name), + messages=[], + use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']), + ) + session.conversations.append(conversation) + session.using_conversation = conversation + + return session.using_conversation diff --git a/pkg/openai/sysprompt/__init__.py b/pkg/openai/sysprompt/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/openai/sysprompt/entities.py b/pkg/openai/sysprompt/entities.py new file mode 100644 index 00000000..43cd3bf7 --- /dev/null +++ b/pkg/openai/sysprompt/entities.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import typing +import pydantic + +from ...openai import entities + + +class Prompt(pydantic.BaseModel): + """供AI使用的Prompt""" + + name: str + + messages: list[entities.Message] diff --git a/pkg/openai/sysprompt/loader.py b/pkg/openai/sysprompt/loader.py new file mode 100644 index 00000000..ca9e8730 --- /dev/null +++ b/pkg/openai/sysprompt/loader.py @@ -0,0 +1,32 @@ +from __future__ import annotations +import abc + +from ...core import app +from . import entities + + +class PromptLoader(metaclass=abc.ABCMeta): + """Prompt加载器抽象类 + """ + + ap: app.Application + + prompts: list[entities.Prompt] + + def __init__(self, ap: app.Application): + self.ap = ap + self.prompts = [] + + async def initialize(self): + pass + + @abc.abstractmethod + async def load(self): + """加载Prompt + """ + raise NotImplementedError + + def get_prompts(self) -> list[entities.Prompt]: + """获取Prompt列表 + """ + return self.prompts diff --git a/pkg/openai/sysprompt/loaders/__init__.py b/pkg/openai/sysprompt/loaders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/openai/sysprompt/loaders/scenario.py b/pkg/openai/sysprompt/loaders/scenario.py new file mode 100644 index 00000000..e0c2bd33 --- /dev/null +++ b/pkg/openai/sysprompt/loaders/scenario.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import json +import os + +from .. import loader +from .. 
import entities +from ....openai import entities as llm_entities + + +class ScenarioPromptLoader(loader.PromptLoader): + """加载scenario目录下的json""" + + async def load(self): + """加载Prompt + """ + for file in os.listdir("scenarios"): + with open("scenarios/{}".format(file), "r", encoding="utf-8") as f: + file_str = f.read() + file_name = file.split(".")[0] + file_json = json.loads(file_str) + messages = [] + for msg in file_json["prompt"]: + role = llm_entities.MessageRole.SYSTEM + if "role" in msg: + if msg["role"] == "user": + role = llm_entities.MessageRole.USER + elif msg["role"] == "system": + role = llm_entities.MessageRole.SYSTEM + elif msg["role"] == "function": + role = llm_entities.MessageRole.FUNCTION + messages.append( + llm_entities.Message( + role=role, + content=msg['content'], + ) + ) + prompt = entities.Prompt( + name=file_name, + messages=messages + ) + self.prompts.append(prompt) + \ No newline at end of file diff --git a/pkg/openai/sysprompt/loaders/single.py b/pkg/openai/sysprompt/loaders/single.py new file mode 100644 index 00000000..ad37d878 --- /dev/null +++ b/pkg/openai/sysprompt/loaders/single.py @@ -0,0 +1,42 @@ +from __future__ import annotations +import os + +from .. import loader +from .. import entities +from ....openai import entities as llm_entities + + +class SingleSystemPromptLoader(loader.PromptLoader): + """配置文件中的单条system prompt的prompt加载器 + """ + + async def load(self): + """加载Prompt + """ + + for name, cnt in self.ap.cfg_mgr.data['default_prompt'].items(): + prompt = entities.Prompt( + name=name, + messages=[ + llm_entities.Message( + role=llm_entities.MessageRole.SYSTEM, + content=cnt + ) + ] + ) + self.prompts.append(prompt) + + for file in os.listdir("prompts"): + with open("prompts/{}".format(file), "r", encoding="utf-8") as f: + file_str = f.read() + file_name = file.split(".")[0] + prompt = entities.Prompt( + name=file_name, + messages=[ + llm_entities.Message( + role=llm_entities.MessageRole.SYSTEM, + content=file_str + ) + ] + ) + self.prompts.append(prompt) diff --git a/pkg/openai/sysprompt/sysprompt.py b/pkg/openai/sysprompt/sysprompt.py new file mode 100644 index 00000000..050f6639 --- /dev/null +++ b/pkg/openai/sysprompt/sysprompt.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from ...core import app +from . 
import loader +from .loaders import single, scenario + + +class PromptManager: + + ap: app.Application + + loader_inst: loader.PromptLoader + + default_prompt: str = 'default' + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + + loader_map = { + "normal": single.SingleSystemPromptLoader, + "full_scenario": scenario.ScenarioPromptLoader + } + + loader_cls = loader_map[self.ap.cfg_mgr.data['preset_mode']] + + self.loader_inst: loader.PromptLoader = loader_cls(self.ap) + + await self.loader_inst.initialize() + await self.loader_inst.load() + + def get_all_prompts(self) -> list[loader.entities.Prompt]: + """获取所有Prompt + """ + return self.loader_inst.get_prompts() + + async def get_prompt(self, name: str) -> loader.entities.Prompt: + """获取Prompt + """ + for prompt in self.get_all_prompts(): + if prompt.name == name: + return prompt diff --git a/pkg/pipeline/process/__init__.py b/pkg/pipeline/process/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/process/handler.py b/pkg/pipeline/process/handler.py new file mode 100644 index 00000000..6d19e039 --- /dev/null +++ b/pkg/pipeline/process/handler.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import abc + +from ...core import app +from ...core import entities as core_entities +from .. import entities + + +class MessageHandler(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def handle( + self, + query: core_entities.Query, + ) -> entities.StageProcessResult: + raise NotImplementedError diff --git a/pkg/pipeline/process/handlers/__init__.py b/pkg/pipeline/process/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py new file mode 100644 index 00000000..ebe958bf --- /dev/null +++ b/pkg/pipeline/process/handlers/chat.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import typing + +import mirai + +from .. import handler +from ... import entities +from ....core import entities as core_entities + + +class ChatMessageHandler(handler.MessageHandler): + + async def handle( + self, + query: core_entities.Query, + ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: + """处理 + """ + # 取session + # 取conversation + # 调API + # 生成器 + session = await self.ap.sess_mgr.get_session(query) + + conversation = await self.ap.sess_mgr.get_conversation(session) + + async for result in conversation.use_model.requester.request(query, conversation): + query.resp_message_chain = mirai.MessageChain([mirai.Plain(str(result))]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + + + diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py new file mode 100644 index 00000000..c5fecb67 --- /dev/null +++ b/pkg/pipeline/process/handlers/command.py @@ -0,0 +1,35 @@ +from __future__ import annotations +import typing + +import mirai + +from .. import handler +from ... 
import entities +from ....core import entities as core_entities + + +class CommandHandler(handler.MessageHandler): + + async def handle( + self, + query: core_entities.Query, + ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: + """处理 + """ + query.resp_message_chain = mirai.MessageChain([ + mirai.Plain('CommandHandler') + ]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + query.resp_message_chain = mirai.MessageChain([ + mirai.Plain('The Second Message') + ]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) \ No newline at end of file diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py new file mode 100644 index 00000000..29051431 --- /dev/null +++ b/pkg/pipeline/process/process.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from ...core import app, entities as core_entities +from . import handler +from .handlers import chat, command +from .. import entities +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class("MessageProcessor") +class Processor(stage.PipelineStage): + + cmd_handler: handler.MessageHandler + + chat_handler: handler.MessageHandler + + async def initialize(self): + self.cmd_handler = command.CommandHandler(self.ap) + self.chat_handler = chat.ChatMessageHandler(self.ap) + + await self.cmd_handler.initialize() + await self.chat_handler.initialize() + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str, + ) -> entities.StageProcessResult: + """处理 + """ + message_text = str(query.message_chain).strip() + + if message_text.startswith('!') or message_text.startswith('!'): + return self.cmd_handler.handle(query) + else: + return self.chat_handler.handle(query) diff --git a/pkg/pipeline/respback/__init__.py b/pkg/pipeline/respback/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py new file mode 100644 index 00000000..4dbddaa5 --- /dev/null +++ b/pkg/pipeline/respback/respback.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import mirai + +from ...core import app + +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class("SendResponseBackStage") +class SendResponseBackStage(stage.PipelineStage): + """发送响应消息 + """ + + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + """处理 + """ + + await self.ap.im_mgr.send( + query.message_event, + query.resp_message_chain + ) + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) \ No newline at end of file diff --git a/pkg/pipeline/stage.py b/pkg/pipeline/stage.py index 84a0339d..56c092b5 100644 --- a/pkg/pipeline/stage.py +++ b/pkg/pipeline/stage.py @@ -1,6 +1,7 @@ from __future__ import annotations import abc +import typing from ..core import app, entities as core_entities from . 
import entities @@ -37,7 +38,10 @@ async def process( self, query: core_entities.Query, stage_inst_name: str, - ) -> entities.StageProcessResult: + ) -> typing.Union[ + entities.StageProcessResult, + typing.AsyncGenerator[entities.StageProcessResult, None], + ]: """处理 """ raise NotImplementedError diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py index f5407a2e..1ff36329 100644 --- a/pkg/pipeline/stagemgr.py +++ b/pkg/pipeline/stagemgr.py @@ -7,7 +7,20 @@ from .resprule import resprule from .bansess import bansess from .cntfilter import cntfilter +from .process import process from .longtext import longtext +from .respback import respback + + +stage_order = [ + "GroupRespondRuleCheckStage", + "BanSessionCheckStage", + "PreContentFilterStage", + "MessageProcessor", + "PostContentFilterStage", + "LongTextProcessStage", + "SendResponseBackStage", +] class StageInstContainer(): @@ -45,3 +58,6 @@ async def initialize(self): for stage_containers in self.stage_containers: await stage_containers.inst.initialize() + + # 按照 stage_order 排序 + self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name)) diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index b16450e8..7794663a 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -18,10 +18,6 @@ from ..plugin import models as plugin_models import tips as tips_custom from ..qqbot import adapter as msadapter -from .resprule import resprule -from .bansess import bansess -from .cntfilter import cntfilter -from .longtext import longtext from .ratelim import ratelim from ..core import app, entities as core_entities @@ -41,30 +37,18 @@ class QQBotManager: # modern ap: app.Application = None - bansess_mgr: bansess.SessionBanManager = None - cntfilter_mgr: cntfilter.ContentFilterManager = None - longtext_pcs: longtext.LongTextProcessor = None - resprule_chkr: resprule.GroupRespondRuleChecker = None ratelimiter: ratelim.RateLimiter = None def __init__(self, first_time_init=True, ap: app.Application = None): config = context.get_config_manager().data self.ap = ap - self.bansess_mgr = bansess.SessionBanManager(ap) - self.cntfilter_mgr = cntfilter.ContentFilterManager(ap) - self.longtext_pcs = longtext.LongTextProcessor(ap) - self.resprule_chkr = resprule.GroupRespondRuleChecker(ap) self.ratelimiter = ratelim.RateLimiter(ap) self.timeout = config['process_message_timeout'] self.retry = config['retry_times'] async def initialize(self): - await self.bansess_mgr.initialize() - await self.cntfilter_mgr.initialize() - await self.longtext_pcs.initialize() - await self.resprule_chkr.initialize() await self.ratelimiter.initialize() config = context.get_config_manager().data diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index 65de8d52..a8359be5 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -15,7 +15,7 @@ from ..plugin import models as plugin_models import tips as tips_custom from ..core import app -from .cntfilter import entities +# from .cntfilter import entities processing = [] From f10af09bd27681d910bef6cc80a4444df2829e00 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sat, 27 Jan 2024 21:50:40 +0800 Subject: [PATCH 05/10] =?UTF-8?q?refactor:=20AI=E5=AF=B9=E8=AF=9D=E5=9F=BA?= =?UTF-8?q?=E6=9C=AC=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/app.py | 18 ++- pkg/core/boot.py | 8 +- pkg/openai/entities.py | 24 ++-- pkg/openai/requester/apis/chatcmpl.py | 134 ++++++++++++++++++++--- 
pkg/openai/requester/modelmgr.py | 3 +- pkg/openai/session/entities.py | 3 + pkg/openai/session/sessionmgr.py | 3 +- pkg/openai/sysprompt/loaders/scenario.py | 9 +- pkg/openai/sysprompt/loaders/single.py | 4 +- pkg/openai/tools/__init__.py | 0 pkg/openai/tools/entities.py | 35 ++++++ pkg/openai/tools/toolmgr.py | 99 +++++++++++++++++ pkg/pipeline/longtext/longtext.py | 4 +- pkg/pipeline/process/handlers/chat.py | 6 +- 14 files changed, 306 insertions(+), 44 deletions(-) create mode 100644 pkg/openai/tools/__init__.py create mode 100644 pkg/openai/tools/entities.py create mode 100644 pkg/openai/tools/toolmgr.py diff --git a/pkg/core/app.py b/pkg/core/app.py index c0ce12fa..c9d06e15 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -8,6 +8,7 @@ from ..openai.session import sessionmgr as llm_session_mgr from ..openai.requester import modelmgr as llm_model_mgr from ..openai.sysprompt import sysprompt as llm_prompt_mgr +from ..openai.tools import toolmgr as llm_tool_mgr from ..config import manager as config_mgr from ..database import manager as database_mgr from ..utils.center import v2 as center_mgr @@ -27,6 +28,8 @@ class Application: prompt_mgr: llm_prompt_mgr.PromptManager = None + tool_mgr: llm_tool_mgr.ToolManager = None + cfg_mgr: config_mgr.ConfigManager = None tips_mgr: config_mgr.ConfigManager = None @@ -46,10 +49,21 @@ class Application: def __init__(self): pass - async def run(self): - # TODO make it async + async def initialize(self): plugin_host.initialize_plugins() + # 把现有的所有内容函数加到toolmgr里 + for func in plugin_host.__callable_functions__: + print(func) + self.tool_mgr.register_legacy_function( + name=func['name'], + description=func['description'], + parameters=func['parameters'], + func=plugin_host.__function_inst_map__[func['name']] + ) + + async def run(self): + tasks = [ asyncio.create_task(self.im_mgr.run()), asyncio.create_task(self.ctrl.run()) diff --git a/pkg/core/boot.py b/pkg/core/boot.py index a74615ec..c06cc6cd 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -18,6 +18,7 @@ from ..openai.session import sessionmgr as llm_session_mgr from ..openai.requester import modelmgr as llm_model_mgr from ..openai.sysprompt import sysprompt as llm_prompt_mgr +from ..openai.tools import toolmgr as llm_tool_mgr from ..openai import dprompt as llm_dprompt from ..qqbot import manager as im_mgr from ..qqbot.cmds import aamgr as im_cmd_aamgr @@ -127,6 +128,10 @@ async def make_app() -> app.Application: await llm_prompt_mgr_inst.initialize() ap.prompt_mgr = llm_prompt_mgr_inst + llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap) + await llm_tool_mgr_inst.initialize() + ap.tool_mgr = llm_tool_mgr_inst + im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap) await im_mgr_inst.initialize() ap.im_mgr = im_mgr_inst @@ -140,7 +145,8 @@ async def make_app() -> app.Application: # TODO make it async plugin_host.load_plugins() - # plugin_host.initialize_plugins() + + await ap.initialize() return ap diff --git a/pkg/openai/entities.py b/pkg/openai/entities.py index 58f48d95..2dd5804b 100644 --- a/pkg/openai/entities.py +++ b/pkg/openai/entities.py @@ -5,27 +5,29 @@ import pydantic -class MessageRole(enum.Enum): - - SYSTEM = 'system' - - USER = 'user' +class FunctionCall(pydantic.BaseModel): + name: str - ASSISTANT = 'assistant' + arguments: str - FUNCTION = 'function' +class ToolCall(pydantic.BaseModel): + id: str -class FunctionCall(pydantic.BaseModel): - name: str + type: str - args: dict[str, typing.Any] + function: FunctionCall class Message(pydantic.BaseModel): + role: str - 
role: MessageRole + name: typing.Optional[str] = None content: typing.Optional[str] = None function_call: typing.Optional[FunctionCall] = None + + tool_calls: typing.Optional[list[ToolCall]] = None + + tool_call_id: typing.Optional[str] = None diff --git a/pkg/openai/requester/apis/chatcmpl.py b/pkg/openai/requester/apis/chatcmpl.py index 5b6d2297..24ff2d7e 100644 --- a/pkg/openai/requester/apis/chatcmpl.py +++ b/pkg/openai/requester/apis/chatcmpl.py @@ -2,8 +2,10 @@ import asyncio import typing +import json import openai +import openai.types.chat.chat_completion as chat_completion from .. import api from ....core import entities as core_entities @@ -12,21 +14,127 @@ class OpenAIChatCompletion(api.LLMAPIRequester): - - client: openai.Client + client: openai.AsyncClient async def initialize(self): - self.client = openai.Client( - base_url=self.ap.cfg_mgr.data['openai_config']['reverse_proxy'], - timeout=self.ap.cfg_mgr.data['process_message_timeout'] + self.client = openai.AsyncClient( + api_key="", + base_url=self.ap.cfg_mgr.data["openai_config"]["reverse_proxy"], + timeout=self.ap.cfg_mgr.data["process_message_timeout"], ) - async def request(self, query: core_entities.Query, conversation: session_entities.Conversation) -> typing.AsyncGenerator[llm_entities.Message, None]: - """请求 - """ - await asyncio.sleep(10) + async def _req( + self, + args: dict, + ) -> chat_completion.ChatCompletion: + self.ap.logger.debug(f"req chat_completion with args {args}") + return await self.client.chat.completions.create(**args) - yield llm_entities.Message( - role=llm_entities.MessageRole.ASSISTANT, - content="hello" - ) + async def _make_msg( + self, + chat_completion: chat_completion.ChatCompletion, + ) -> llm_entities.Message: + chatcmpl_message = chat_completion.choices[0].message.dict() + + message = llm_entities.Message(**chatcmpl_message) + + return message + + async def _closure( + self, + req_messages: list[dict], + conversation: session_entities.Conversation, + user_text: str = None, + function_ret: str = None, + ) -> llm_entities.Message: + self.client.api_key = conversation.use_model.token_mgr.get_token() + + args = self.ap.cfg_mgr.data["completion_api_params"].copy() + args["model"] = conversation.use_model.name + + tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation) + # tools = [ + # { + # "type": "function", + # "function": { + # "name": "get_current_weather", + # "description": "Get the current weather in a given location", + # "parameters": { + # "type": "object", + # "properties": { + # "location": { + # "type": "string", + # "description": "The city and state, e.g. 
San Francisco, CA", + # }, + # "unit": { + # "type": "string", + # "enum": ["celsius", "fahrenheit"], + # }, + # }, + # "required": ["location"], + # }, + # }, + # } + # ] + if tools: + args["tools"] = tools + + # 设置此次请求中的messages + messages = req_messages + args["messages"] = messages + + # 发送请求 + resp = await self._req(args) + + # 处理请求结果 + message = await self._make_msg(resp) + + return message + + async def request( + self, query: core_entities.Query, conversation: session_entities.Conversation + ) -> typing.AsyncGenerator[llm_entities.Message, None]: + """请求""" + + pending_tool_calls = [] + + req_messages = [ + m.dict(exclude_none=True) for m in conversation.prompt.messages + ] + [m.dict(exclude_none=True) for m in conversation.messages] + + req_messages.append({"role": "user", "content": str(query.message_chain)}) + + msg = await self._closure(req_messages, conversation) + + yield msg + + pending_tool_calls = msg.tool_calls + + req_messages.append(msg.dict(exclude_none=True)) + + while pending_tool_calls: + for tool_call in pending_tool_calls: + func = tool_call.function + + parameters = json.loads(func.arguments) + + func_ret = await self.ap.tool_mgr.execute_func_call( + query, func.name, parameters + ) + + msg = llm_entities.Message( + role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id + ) + + yield msg + + req_messages.append(msg.dict(exclude_none=True)) + + # 处理完所有调用,继续请求 + msg = await self._closure(req_messages, conversation) + + yield msg + + pending_tool_calls = msg.tool_calls + + req_messages.append(msg.dict(exclude_none=True)) diff --git a/pkg/openai/requester/modelmgr.py b/pkg/openai/requester/modelmgr.py index cc606b03..7e6a3b52 100644 --- a/pkg/openai/requester/modelmgr.py +++ b/pkg/openai/requester/modelmgr.py @@ -19,7 +19,8 @@ def __init__(self, ap: app.Application): async def initialize(self): openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap) - openai_token_mgr = token.TokenManager(self.ap, self.ap.cfg_mgr.data['openai_config']['api_key'].values()) + await openai_chat_completion.initialize() + openai_token_mgr = token.TokenManager(self.ap, list(self.ap.cfg_mgr.data['openai_config']['api_key'].values())) self.model_list.append( entities.LLMModelInfo( diff --git a/pkg/openai/session/entities.py b/pkg/openai/session/entities.py index 49ddb845..cbeb72a3 100644 --- a/pkg/openai/session/entities.py +++ b/pkg/openai/session/entities.py @@ -10,6 +10,7 @@ from .. 
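The request() coroutine above implements the usual OpenAI function-calling loop: get a completion, execute every tool call the model asked for, append the results as role "tool" messages, and complete again until no tool calls remain. A condensed sketch of that pattern follows; call_model and run_tool are placeholder callables, not names from this patch:

import json

async def tool_call_loop(req_messages: list[dict], call_model, run_tool):
    msg = await call_model(req_messages)                  # first completion
    req_messages.append(msg.dict(exclude_none=True))
    while msg.tool_calls:                                 # model requested tool execution
        for tool_call in msg.tool_calls:
            result = await run_tool(tool_call.function.name,
                                    json.loads(tool_call.function.arguments))
            req_messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": json.dumps(result, ensure_ascii=False),
            })
        msg = await call_model(req_messages)              # continue with tool results attached
        req_messages.append(msg.dict(exclude_none=True))
    return msg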
import entities as llm_entities from ..requester import entities from ...core import entities as core_entities +from ..tools import entities as tools_entities class Conversation(pydantic.BaseModel): @@ -25,6 +26,8 @@ class Conversation(pydantic.BaseModel): use_model: entities.LLMModelInfo + use_funcs: typing.Optional[list[tools_entities.LLMFunction]] + class Session(pydantic.BaseModel): """会话""" diff --git a/pkg/openai/session/sessionmgr.py b/pkg/openai/session/sessionmgr.py index 8aff6e02..a1d5d4d9 100644 --- a/pkg/openai/session/sessionmgr.py +++ b/pkg/openai/session/sessionmgr.py @@ -29,7 +29,7 @@ async def get_session(self, query: core_entities.Query) -> entities.Session: session = entities.Session( launcher_type=query.launcher_type, launcher_id=query.launcher_id, - semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000) + semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000), ) self.session_list.append(session) return session @@ -43,6 +43,7 @@ async def get_conversation(self, session: entities.Session) -> entities.Conversa prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name), messages=[], use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']), + use_funcs=await self.ap.tool_mgr.get_all_functions(), ) session.conversations.append(conversation) session.using_conversation = conversation diff --git a/pkg/openai/sysprompt/loaders/scenario.py b/pkg/openai/sysprompt/loaders/scenario.py index e0c2bd33..4d54f30f 100644 --- a/pkg/openai/sysprompt/loaders/scenario.py +++ b/pkg/openai/sysprompt/loaders/scenario.py @@ -21,14 +21,9 @@ async def load(self): file_json = json.loads(file_str) messages = [] for msg in file_json["prompt"]: - role = llm_entities.MessageRole.SYSTEM + role = 'system' if "role" in msg: - if msg["role"] == "user": - role = llm_entities.MessageRole.USER - elif msg["role"] == "system": - role = llm_entities.MessageRole.SYSTEM - elif msg["role"] == "function": - role = llm_entities.MessageRole.FUNCTION + role = msg['role'] messages.append( llm_entities.Message( role=role, diff --git a/pkg/openai/sysprompt/loaders/single.py b/pkg/openai/sysprompt/loaders/single.py index ad37d878..1fff5a69 100644 --- a/pkg/openai/sysprompt/loaders/single.py +++ b/pkg/openai/sysprompt/loaders/single.py @@ -19,7 +19,7 @@ async def load(self): name=name, messages=[ llm_entities.Message( - role=llm_entities.MessageRole.SYSTEM, + role='system', content=cnt ) ] @@ -34,7 +34,7 @@ async def load(self): name=file_name, messages=[ llm_entities.Message( - role=llm_entities.MessageRole.SYSTEM, + role='system', content=file_str ) ] diff --git a/pkg/openai/tools/__init__.py b/pkg/openai/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/openai/tools/entities.py b/pkg/openai/tools/entities.py new file mode 100644 index 00000000..b79627e5 --- /dev/null +++ b/pkg/openai/tools/entities.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import abc +import typing +import asyncio + +import pydantic + + +class LLMFunction(pydantic.BaseModel): + """函数""" + + name: str + """函数名""" + + human_desc: str + + description: str + """给LLM识别的函数描述""" + + enable: typing.Optional[bool] = True + + parameters: dict + + func: typing.Callable + """供调用的python异步方法 + + 此异步方法第一个参数接收当前请求的query对象,可以从其中取出session等信息。 + query参数不在parameters中,但在调用时会自动传入。 + 但在当前版本中,插件提供的内容函数都是同步的,且均为请求无关的,故在此版本的实现(以及考虑了向后兼容性的版本)中, + 对插件的内容函数进行封装并存到这里来。 + """ + + 
class Config: + arbitrary_types_allowed = True diff --git a/pkg/openai/tools/toolmgr.py b/pkg/openai/tools/toolmgr.py new file mode 100644 index 00000000..cc160e39 --- /dev/null +++ b/pkg/openai/tools/toolmgr.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import typing + +from ...core import app, entities as core_entities +from . import entities +from ..session import entities as session_entities + + +class ToolManager: + """LLM工具管理器 + """ + + ap: app.Application + + all_functions: list[entities.LLMFunction] + + def __init__(self, ap: app.Application): + self.ap = ap + self.all_functions = [] + + async def initialize(self): + pass + + def register_legacy_function(self, name: str, description: str, parameters: dict, func: callable): + """注册函数 + """ + async def wrapper(query, **kwargs): + return func(**kwargs) + function = entities.LLMFunction( + name=name, + description=description, + human_desc='', + enable=True, + parameters=parameters, + func=wrapper + ) + self.all_functions.append(function) + + async def register_function(self, function: entities.LLMFunction): + """添加函数 + """ + self.all_functions.append(function) + + async def get_function(self, name: str) -> entities.LLMFunction: + """获取函数 + """ + for function in self.all_functions: + if function.name == name: + return function + return None + + async def get_all_functions(self) -> list[entities.LLMFunction]: + """获取所有函数 + """ + return self.all_functions + + async def generate_tools_for_openai(self, conversation: session_entities.Conversation) -> str: + """生成函数列表 + """ + tools = [] + + for function in conversation.use_funcs: + if function.enable: + function_schema = { + "type": "function", + "function": { + "name": function.name, + "description": function.description, + "parameters": function.parameters + } + } + tools.append(function_schema) + + return tools + + async def execute_func_call( + self, + query: core_entities.Query, + name: str, + parameters: dict + ) -> typing.Any: + """执行函数调用 + """ + + # return "i'm not sure for the args "+str(parameters) + + function = await self.get_function(name) + if function is None: + return None + + parameters = parameters.copy() + + parameters = { + "query": query, + **parameters + } + + return await function.func(**parameters) diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index 11144891..72c36cdf 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -50,8 +50,8 @@ async def initialize(self): async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: if len(str(query.resp_message_chain)) > self.ap.cfg_mgr.data['blob_message_threshold']: - query.message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain))) + query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain))) return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query - ) \ No newline at end of file + ) diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index ebe958bf..629c2b11 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -26,13 +26,11 @@ async def handle( conversation = await self.ap.sess_mgr.get_conversation(session) async for result in conversation.use_model.requester.request(query, conversation): + conversation.messages.append(result) + query.resp_message_chain = 
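register_legacy_function above wraps a synchronous, request-independent plugin function in an async closure, so execute_func_call can always await the stored func and inject the current query as the first argument. A small usage sketch; get_time is a made-up function, not one provided by the plugin host:

import time

def get_time(fmt: str = "%H:%M") -> str:
    # legacy synchronous content function, knows nothing about the query
    return time.strftime(fmt)

# Assumed wiring against the ToolManager above (reachable as ap.tool_mgr):
#   ap.tool_mgr.register_legacy_function(
#       name="get_time",
#       description="Return the current time",
#       parameters={"type": "object", "properties": {"fmt": {"type": "string"}}},
#       func=get_time,
#   )
#   ret = await ap.tool_mgr.execute_func_call(query, "get_time", {"fmt": "%H:%M:%S"})
#   # the wrapper receives query but drops it before calling get_time(fmt=...)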
mirai.MessageChain([mirai.Plain(str(result))]) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query ) - - - - From 2a0cf573035ede31a592bea137dd4735592047b3 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 28 Jan 2024 00:16:42 +0800 Subject: [PATCH 06/10] =?UTF-8?q?refactor:=20=E5=91=BD=E4=BB=A4=E5=A4=84?= =?UTF-8?q?=E7=90=86=E5=9F=BA=E7=A1=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/__init__.py | 0 pkg/command/cmdmgr.py | 104 +++++++++++++++++++++++ pkg/command/entities.py | 43 ++++++++++ pkg/command/errors.py | 12 +++ pkg/command/operator.py | 71 ++++++++++++++++ pkg/command/operators/__init__.py | 0 pkg/command/operators/func.py | 23 +++++ pkg/core/app.py | 4 +- pkg/core/boot.py | 12 ++- pkg/pipeline/process/handlers/command.py | 45 ++++++---- 10 files changed, 289 insertions(+), 25 deletions(-) create mode 100644 pkg/command/__init__.py create mode 100644 pkg/command/cmdmgr.py create mode 100644 pkg/command/entities.py create mode 100644 pkg/command/errors.py create mode 100644 pkg/command/operator.py create mode 100644 pkg/command/operators/__init__.py create mode 100644 pkg/command/operators/func.py diff --git a/pkg/command/__init__.py b/pkg/command/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py new file mode 100644 index 00000000..ae183e90 --- /dev/null +++ b/pkg/command/cmdmgr.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import typing + +from ..core import app, entities as core_entities +from ..openai import entities as llm_entities +from ..openai.session import entities as session_entities +from . import entities, operator, errors + +from .operators import func + + +class CommandManager: + """命令管理器 + """ + + ap: app.Application + + cmd_list: list[operator.CommandOperator] + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + # 实例化所有类 + self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators] + + # 设置所有类的子节点 + for cmd in self.cmd_list: + cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__] + + # 初始化所有类 + for cmd in self.cmd_list: + await cmd.initialize() + + async def _execute( + self, + context: entities.ExecuteContext, + operator_list: list[operator.CommandOperator], + operator: operator.CommandOperator = None + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行命令 + """ + found = False + if len(context.crt_params) > 0: + for operator in operator_list: + if context.crt_params[0] == operator.name \ + or context.crt_params[0] in operator.alias: + found = True + context.crt_command = context.params[0] + context.crt_params = context.params[1:] + + async for ret in self._execute( + context, + operator.children, + operator + ): + yield ret + + if not found: + if operator is None: + yield entities.CommandReturn( + error=errors.CommandNotFoundError(context.crt_command) + ) + else: + if operator.lowest_privilege > context.privilege: + yield entities.CommandReturn( + error=errors.CommandPrivilegeError(context.crt_command) + ) + else: + async for ret in operator.execute(context): + yield ret + + + async def execute( + self, + command_text: str, + query: core_entities.Query, + session: session_entities.Session + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行命令 + """ + + privilege = 1 + if query.sender_id == self.ap.cfg_mgr.data['admin_qq'] \ + or 
query.sender_id in self.ap.cfg_mgr['admin_qq']: + privilege = 2 + + ctx = entities.ExecuteContext( + query=query, + session=session, + command_text=command_text, + command='', + crt_command='', + params=command_text.split(' '), + crt_params=command_text.split(' '), + privilege=privilege + ) + + async for ret in self._execute( + ctx, + self.cmd_list + ): + yield ret diff --git a/pkg/command/entities.py b/pkg/command/entities.py new file mode 100644 index 00000000..7fba96e5 --- /dev/null +++ b/pkg/command/entities.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import typing + +import pydantic +import mirai + +from ..core import app, entities as core_entities +from ..openai.session import entities as session_entities +from . import errors, operator + + +class CommandReturn(pydantic.BaseModel): + + text: typing.Optional[str] + """文本 + """ + + image: typing.Optional[mirai.Image] + + error: typing.Optional[errors.CommandError]= None + + class Config: + arbitrary_types_allowed = True + + +class ExecuteContext(pydantic.BaseModel): + + query: core_entities.Query + + session: session_entities.Session + + command_text: str + + command: str + + crt_command: str + + params: list[str] + + crt_params: list[str] + + privilege: int diff --git a/pkg/command/errors.py b/pkg/command/errors.py new file mode 100644 index 00000000..42c5a8b3 --- /dev/null +++ b/pkg/command/errors.py @@ -0,0 +1,12 @@ + + +class CommandError(Exception): + pass + + +class CommandNotFoundError(CommandError): + pass + + +class CommandPrivilegeError(CommandError): + pass \ No newline at end of file diff --git a/pkg/command/operator.py b/pkg/command/operator.py new file mode 100644 index 00000000..319da55b --- /dev/null +++ b/pkg/command/operator.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import typing +import abc + +from ..core import app, entities as core_entities +from ..openai.session import entities as session_entities +from . 
import entities + + +preregistered_operators: list[typing.Type[CommandOperator]] = [] + + +def operator_class( + name: str, + alias: list[str], + help: str, + privilege: int=1, # 1为普通用户,2为管理员 + parent_class: typing.Type[CommandOperator] = None +) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: + def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]: + cls.name = name + cls.alias = alias + cls.help = help + cls.parent_class = parent_class + + preregistered_operators.append(cls) + + return cls + + return decorator + + +class CommandOperator(metaclass=abc.ABCMeta): + """命令算子 + """ + + ap: app.Application + + name: str + """名称,搜索到时若符合则使用""" + + alias: list[str] + """同name""" + + help: str + """此节点的帮助信息""" + + parent_class: typing.Type[CommandOperator] + """父节点类。标记以供管理器在初始化时编织父子关系。""" + + lowest_privilege: int = 0 + """最低权限。若权限低于此值,则不予执行。""" + + children: list[CommandOperator] + """子节点。解析命令时,若节点有子节点,则以下一个参数去匹配子节点, + 若有匹配中的,转移到子节点中执行,若没有匹配中的或没有子节点,执行此节点。""" + + def __init__(self, ap: app.Application): + self.ap = ap + self.children = [] + + async def initialize(self): + pass + + @abc.abstractmethod + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + pass diff --git a/pkg/command/operators/__init__.py b/pkg/command/operators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/command/operators/func.py b/pkg/command/operators/func.py new file mode 100644 index 00000000..c888831f --- /dev/null +++ b/pkg/command/operators/func.py @@ -0,0 +1,23 @@ +from __future__ import annotations +from typing import AsyncGenerator + +from .. import operator, entities, cmdmgr +from ...plugin import host as plugin_host + + +@operator.operator_class(name="func", alias=[], help="查看所有以注册的内容函数") +class FuncOperator(operator.CommandOperator): + async def execute( + self, + context: entities.ExecuteContext + ) -> AsyncGenerator[entities.CommandReturn, None]: + reply_str = "当前已加载的内容函数: \n\n" + + index = 1 + for func in plugin_host.__callable_functions__: + reply_str += "{}. {}{}:\n{}\n\n".format(index, ("(已禁用) " if not func['enabled'] else ""), func['name'], func['description']) + index += 1 + + yield entities.CommandReturn( + text=reply_str + ) \ No newline at end of file diff --git a/pkg/core/app.py b/pkg/core/app.py index c9d06e15..f0693957 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -12,6 +12,7 @@ from ..config import manager as config_mgr from ..database import manager as database_mgr from ..utils.center import v2 as center_mgr +from ..command import cmdmgr from ..plugin import host as plugin_host from . 
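The operator_class decorator above only records the decorated class in preregistered_operators; CommandManager later instantiates every registered class and attaches each one to the children of its parent_class. A hedged sketch of a hypothetical command with one subcommand, assuming the package layout introduced in this patch:

from pkg.command import operator, entities


@operator.operator_class(name="demo", alias=["d"], help="demo command")
class DemoOperator(operator.CommandOperator):

    async def execute(self, context: entities.ExecuteContext):
        yield entities.CommandReturn(text="demo: root")


@operator.operator_class(name="echo", alias=[], help="echo the remaining params", parent_class=DemoOperator)
class DemoEchoOperator(operator.CommandOperator):

    async def execute(self, context: entities.ExecuteContext):
        # for "!demo echo hello world", crt_params is ["hello", "world"] here
        yield entities.CommandReturn(text=" ".join(context.crt_params))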
import pool, controller from ..pipeline import stagemgr @@ -22,6 +23,8 @@ class Application: llm_mgr: openai_mgr.OpenAIInteract = None + cmd_mgr: cmdmgr.CommandManager = None + sess_mgr: llm_session_mgr.SessionManager = None model_mgr: llm_model_mgr.ModelManager = None @@ -54,7 +57,6 @@ async def initialize(self): # 把现有的所有内容函数加到toolmgr里 for func in plugin_host.__callable_functions__: - print(func) self.tool_mgr.register_legacy_function( name=func['name'], description=func['description'], diff --git a/pkg/core/boot.py b/pkg/core/boot.py index c06cc6cd..2b03a153 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -19,9 +19,8 @@ from ..openai.requester import modelmgr as llm_model_mgr from ..openai.sysprompt import sysprompt as llm_prompt_mgr from ..openai.tools import toolmgr as llm_tool_mgr -from ..openai import dprompt as llm_dprompt from ..qqbot import manager as im_mgr -from ..qqbot.cmds import aamgr as im_cmd_aamgr +from ..command import cmdmgr from ..plugin import host as plugin_host from ..utils.center import v2 as center_v2 from ..utils import updater @@ -81,11 +80,6 @@ async def make_app() -> app.Application: if cfg_mgr.data['admin_qq'] == 0: qcg_logger.warning("未设置管理员QQ号,将无法使用管理员命令,请在 config.py 中修改 admin_qq") - # TODO make it async - llm_dprompt.register_all() - im_cmd_aamgr.register_all() - im_cmd_aamgr.apply_privileges() - # 构建组建实例 ap = app.Application() ap.logger = qcg_logger @@ -116,6 +110,10 @@ async def make_app() -> app.Application: llm_mgr_inst = llm_mgr.OpenAIInteract(ap) ap.llm_mgr = llm_mgr_inst + cmd_mgr_inst = cmdmgr.CommandManager(ap) + await cmd_mgr_inst.initialize() + ap.cmd_mgr = cmd_mgr_inst + llm_model_mgr_inst = llm_model_mgr.ModelManager(ap) await llm_model_mgr_inst.initialize() ap.model_mgr = llm_model_mgr_inst diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index c5fecb67..cf3e0740 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -16,20 +16,31 @@ async def handle( ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: """处理 """ - query.resp_message_chain = mirai.MessageChain([ - mirai.Plain('CommandHandler') - ]) - - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query - ) - - query.resp_message_chain = mirai.MessageChain([ - mirai.Plain('The Second Message') - ]) - - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query - ) \ No newline at end of file + session = await self.ap.sess_mgr.get_session(query) + + command_text = str(query.message_chain).strip()[1:] + + async for ret in self.ap.cmd_mgr.execute( + command_text=command_text, + query=query, + session=session + ): + if ret.error is not None: + query.resp_message_chain = mirai.MessageChain([ + mirai.Plain(str(ret.error)) + ]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + if ret.text is not None: + query.resp_message_chain = mirai.MessageChain([ + mirai.Plain(ret.text) + ]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) From 1368ee22b2c15ab5bcae98efca1e7c8104582a15 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 28 Jan 2024 18:21:43 +0800 Subject: [PATCH 07/10] =?UTF-8?q?refactor:=20=E5=91=BD=E4=BB=A4=E5=9F=BA?= =?UTF-8?q?=E6=9C=AC=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 
pkg/command/cmdmgr.py | 24 ++- pkg/command/errors.py | 27 ++- pkg/command/operator.py | 8 +- pkg/command/operators/cfg.py | 98 ++++++++++ pkg/command/operators/cmd.py | 50 +++++ pkg/command/operators/default.py | 62 ++++++ pkg/command/operators/delc.py | 62 ++++++ pkg/command/operators/func.py | 18 +- pkg/command/operators/help.py | 23 +++ pkg/command/operators/last.py | 36 ++++ pkg/command/operators/list.py | 56 ++++++ pkg/command/operators/next.py | 35 ++++ pkg/command/operators/plugin.py | 239 +++++++++++++++++++++++ pkg/command/operators/prompt.py | 29 +++ pkg/command/operators/resend.py | 34 ++++ pkg/command/operators/reset.py | 23 +++ pkg/command/operators/version.py | 28 +++ pkg/openai/requester/apis/chatcmpl.py | 2 +- pkg/openai/sysprompt/sysprompt.py | 9 +- pkg/pipeline/process/handlers/chat.py | 8 + pkg/pipeline/process/handlers/command.py | 22 ++- 21 files changed, 859 insertions(+), 34 deletions(-) create mode 100644 pkg/command/operators/cfg.py create mode 100644 pkg/command/operators/cmd.py create mode 100644 pkg/command/operators/default.py create mode 100644 pkg/command/operators/delc.py create mode 100644 pkg/command/operators/help.py create mode 100644 pkg/command/operators/last.py create mode 100644 pkg/command/operators/list.py create mode 100644 pkg/command/operators/next.py create mode 100644 pkg/command/operators/plugin.py create mode 100644 pkg/command/operators/prompt.py create mode 100644 pkg/command/operators/resend.py create mode 100644 pkg/command/operators/reset.py create mode 100644 pkg/command/operators/version.py diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index ae183e90..cff5969c 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -7,7 +7,7 @@ from ..openai.session import entities as session_entities from . 
import entities, operator, errors -from .operators import func +from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version class CommandManager: @@ -41,31 +41,35 @@ async def _execute( ) -> typing.AsyncGenerator[entities.CommandReturn, None]: """执行命令 """ + found = False if len(context.crt_params) > 0: - for operator in operator_list: - if context.crt_params[0] == operator.name \ - or context.crt_params[0] in operator.alias: + for oper in operator_list: + if (context.crt_params[0] == oper.name \ + or context.crt_params[0] in oper.alias) \ + and (oper.parent_class is None or oper.parent_class == operator.__class__): found = True - context.crt_command = context.params[0] - context.crt_params = context.params[1:] + + context.crt_command = context.crt_params[0] + context.crt_params = context.crt_params[1:] async for ret in self._execute( context, - operator.children, - operator + oper.children, + oper ): yield ret + break if not found: if operator is None: yield entities.CommandReturn( - error=errors.CommandNotFoundError(context.crt_command) + error=errors.CommandNotFoundError(context.crt_params[0]) ) else: if operator.lowest_privilege > context.privilege: yield entities.CommandReturn( - error=errors.CommandPrivilegeError(context.crt_command) + error=errors.CommandPrivilegeError(operator.name) ) else: async for ret in operator.execute(context): diff --git a/pkg/command/errors.py b/pkg/command/errors.py index 42c5a8b3..5bc253f6 100644 --- a/pkg/command/errors.py +++ b/pkg/command/errors.py @@ -1,12 +1,33 @@ class CommandError(Exception): - pass + + def __init__(self, message: str = None): + self.message = message + + def __str__(self): + return self.message class CommandNotFoundError(CommandError): - pass + + def __init__(self, message: str = None): + super().__init__("未知命令: "+message) class CommandPrivilegeError(CommandError): - pass \ No newline at end of file + + def __init__(self, message: str = None): + super().__init__("权限不足: "+message) + + +class ParamNotEnoughError(CommandError): + + def __init__(self, message: str = None): + super().__init__("参数不足: "+message) + + +class CommandOperationError(CommandError): + + def __init__(self, message: str = None): + super().__init__("操作失败: "+message) diff --git a/pkg/command/operator.py b/pkg/command/operator.py index 319da55b..af1a5d6e 100644 --- a/pkg/command/operator.py +++ b/pkg/command/operator.py @@ -13,8 +13,9 @@ def operator_class( name: str, - alias: list[str], help: str, + usage: str = None, + alias: list[str] = [], privilege: int=1, # 1为普通用户,2为管理员 parent_class: typing.Type[CommandOperator] = None ) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: @@ -22,6 +23,7 @@ def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator] cls.name = name cls.alias = alias cls.help = help + cls.usage = usage cls.parent_class = parent_class preregistered_operators.append(cls) @@ -46,7 +48,9 @@ class CommandOperator(metaclass=abc.ABCMeta): help: str """此节点的帮助信息""" - parent_class: typing.Type[CommandOperator] + usage: str = None + + parent_class: typing.Type[CommandOperator] | None = None """父节点类。标记以供管理器在初始化时编织父子关系。""" lowest_privilege: int = 0 diff --git a/pkg/command/operators/cfg.py b/pkg/command/operators/cfg.py new file mode 100644 index 00000000..b67ff3e6 --- /dev/null +++ b/pkg/command/operators/cfg.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import typing +import json + +from .. 
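With the corrected _execute above, each matched level consumes the first token of crt_params and recurses into the children whose parent_class points back at the operator that just matched; execution happens only once no further token matches. An illustrative trace (comments only, not program output) for the message "!plugin update all", using the plugin operators defined further below in this patch:

# crt_params == ["plugin", "update", "all"]  -> matches PluginOperator, consume "plugin"
# crt_params == ["update", "all"]            -> matches PluginUpdateOperator (parent: PluginOperator)
# crt_params == ["all"]                      -> matches PluginUpdateAllOperator (parent: PluginUpdateOperator)
# crt_params == []                           -> nothing left to match, PluginUpdateAllOperator.execute runs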
import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="cfg", + help="配置项管理", + usage='!cfg <配置项> [配置值]\n!cfg all', + privilege=2 +) +class CfgOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行 + """ + reply = '' + + params = context.crt_params + cfg_mgr = self.ap.cfg_mgr + + false = False + true = True + + reply_str = "" + if len(params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供配置项名称')) + else: + cfg_name = params[0] + if cfg_name == 'all': + reply_str = "[bot]所有配置项:\n\n" + for cfg in cfg_mgr.data.keys(): + if not cfg.startswith('__') and not cfg == 'logging': + # 根据配置项类型进行格式化,如果是字典则转换为json并格式化 + if isinstance(cfg_mgr.data[cfg], str): + reply_str += "{}: \"{}\"\n".format(cfg, cfg_mgr.data[cfg]) + elif isinstance(cfg_mgr.data[cfg], dict): + # 不进行unicode转义,并格式化 + reply_str += "{}: {}\n".format(cfg, + json.dumps(cfg_mgr.data[cfg], + ensure_ascii=False, indent=4)) + else: + reply_str += "{}: {}\n".format(cfg, cfg_mgr.data[cfg]) + yield entities.CommandReturn(text=reply_str) + else: + cfg_entry_path = cfg_name.split('.') + + try: + if len(params) == 1: # 未指定配置值,返回配置项值 + cfg_entry = cfg_mgr.data[cfg_entry_path[0]] + if len(cfg_entry_path) > 1: + for i in range(1, len(cfg_entry_path)): + cfg_entry = cfg_entry[cfg_entry_path[i]] + + if isinstance(cfg_entry, str): + reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, cfg_entry) + elif isinstance(cfg_entry, dict): + reply_str = "[bot]配置项{}: {}\n".format(cfg_name, + json.dumps(cfg_entry, + ensure_ascii=False, indent=4)) + else: + reply_str = "[bot]配置项{}: {}\n".format(cfg_name, cfg_entry) + + yield entities.CommandReturn(text=reply_str) + else: + cfg_value = " ".join(params[1:]) + + cfg_value = eval(cfg_value) + + cfg_entry = cfg_mgr.data[cfg_entry_path[0]] + if len(cfg_entry_path) > 1: + for i in range(1, len(cfg_entry_path) - 1): + cfg_entry = cfg_entry[cfg_entry_path[i]] + if isinstance(cfg_entry[cfg_entry_path[-1]], type(cfg_value)): + cfg_entry[cfg_entry_path[-1]] = cfg_value + yield entities.CommandReturn(text="配置项{}修改成功".format(cfg_name)) + else: + # reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)] + yield entities.CommandReturn(error=errors.CommandOperationError("配置项{}类型不匹配".format(cfg_name))) + else: + cfg_mgr.data[cfg_entry_path[0]] = cfg_value + # reply = ["[bot]配置项{}修改成功".format(cfg_name)] + yield entities.CommandReturn(text="配置项{}修改成功".format(cfg_name)) + except KeyError: + # reply = ["[bot]err:未找到配置项 {}".format(cfg_name)] + yield entities.CommandReturn(error=errors.CommandOperationError("未找到配置项 {}".format(cfg_name))) + except NameError: + # reply = ["[bot]err:值{}不合法(字符串需要使用双引号包裹)".format(cfg_value)] + yield entities.CommandReturn(error=errors.CommandOperationError("值{}不合法(字符串需要使用双引号包裹)".format(cfg_value))) + except ValueError: + # reply = ["[bot]err:未找到配置项 {}".format(cfg_name)] + yield entities.CommandReturn(error=errors.CommandOperationError("未找到配置项 {}".format(cfg_name))) diff --git a/pkg/command/operators/cmd.py b/pkg/command/operators/cmd.py new file mode 100644 index 00000000..17b5ed08 --- /dev/null +++ b/pkg/command/operators/cmd.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import typing + +from .. 
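The !cfg operator above addresses nested configuration entries with a dotted path and walks the configuration dict one key at a time. A standalone sketch of that lookup; the sample dict is made up for illustration:

def cfg_get(data: dict, path: str):
    entry = data
    for key in path.split('.'):
        entry = entry[key]            # raises KeyError if the path does not exist
    return entry

sample = {"openai_config": {"reverse_proxy": None, "api_key": {"default": "sk-..."}}}
print(cfg_get(sample, "openai_config.api_key"))   # {'default': 'sk-...'}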
import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="cmd", + help='显示命令列表', + usage='!cmd\n!cmd <命令名称>' +) +class CmdOperator(operator.CommandOperator): + """命令列表 + """ + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行 + """ + if len(context.crt_params) == 0: + reply_str = "当前所有命令: \n\n" + + for cmd in self.ap.cmd_mgr.cmd_list: + if cmd.parent_class is None: + reply_str += f"{cmd.name}: {cmd.help}\n" + + reply_str += "\n使用 !cmd <命令名称> 查看命令的详细帮助" + + yield entities.CommandReturn(text=reply_str.strip()) + + else: + cmd_name = context.crt_params[0] + + cmd = None + + for _cmd in self.ap.cmd_mgr.cmd_list: + if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None): + cmd = _cmd + break + + if cmd is None: + yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name)) + else: + reply_str = f"{cmd.name}: {cmd.help}\n\n" + reply_str += f"使用方法: \n{cmd.usage}" + + yield entities.CommandReturn(text=reply_str.strip()) diff --git a/pkg/command/operators/default.py b/pkg/command/operators/default.py new file mode 100644 index 00000000..ca7e404d --- /dev/null +++ b/pkg/command/operators/default.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import typing +import traceback + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="default", + help="操作情景预设", + usage='!default\n!default set <指定情景预设为默认>' +) +class DefaultOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + reply_str = "当前所有情景预设: \n\n" + + for prompt in self.ap.prompt_mgr.get_all_prompts(): + + content = "" + for msg in prompt.messages: + content += f" {msg.role}: {msg.content}" + + reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n" + + reply_str += f"当前会话使用的是: {context.session.use_prompt_name}" + + yield entities.CommandReturn(text=reply_str.strip()) + + +@operator.operator_class( + name="set", + help="设置当前会话默认情景预设", + parent_class=DefaultOperator +) +class DefaultSetOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称')) + else: + prompt_name = context.crt_params[0] + + try: + prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name) + if prompt is None: + yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name))) + else: + context.session.use_prompt_name = prompt.name + yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效") + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e))) diff --git a/pkg/command/operators/delc.py b/pkg/command/operators/delc.py new file mode 100644 index 00000000..db865ff7 --- /dev/null +++ b/pkg/command/operators/delc.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import typing +import datetime + +from .. 
import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="del", + help="删除当前会话的历史记录", + usage='!del <序号>\n!del all' +) +class DelOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if context.session.conversations: + delete_index = 0 + if len(context.crt_params) > 0: + try: + delete_index = int(context.crt_params[0]) + except: + yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数')) + return + + if delete_index < 0 or delete_index >= len(context.session.conversations): + yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围')) + return + + # 倒序 + to_delete_index = len(context.session.conversations)-1-delete_index + + if context.session.conversations[to_delete_index] == context.session.using_conversation: + context.session.using_conversation = None + + del context.session.conversations[to_delete_index] + + yield entities.CommandReturn(text=f"已删除对话: {delete_index}") + else: + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) + + +@operator.operator_class( + name="all", + help="删除此会话的所有历史记录", + parent_class=DelOperator +) +class DelAllOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + context.session.conversations = [] + context.session.using_conversation = None + + yield entities.CommandReturn(text="已删除所有对话") \ No newline at end of file diff --git a/pkg/command/operators/func.py b/pkg/command/operators/func.py index c888831f..a4e81c35 100644 --- a/pkg/command/operators/func.py +++ b/pkg/command/operators/func.py @@ -5,19 +5,21 @@ from ...plugin import host as plugin_host -@operator.operator_class(name="func", alias=[], help="查看所有以注册的内容函数") +@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func') class FuncOperator(operator.CommandOperator): async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> AsyncGenerator[entities.CommandReturn, None]: reply_str = "当前已加载的内容函数: \n\n" index = 1 - for func in plugin_host.__callable_functions__: - reply_str += "{}. {}{}:\n{}\n\n".format(index, ("(已禁用) " if not func['enabled'] else ""), func['name'], func['description']) + for func in self.ap.tool_mgr.all_functions: + reply_str += "{}. {}{}:\n{}\n\n".format( + index, + ("(已禁用) " if not func.enable else ""), + func.name, + func.description, + ) index += 1 - yield entities.CommandReturn( - text=reply_str - ) \ No newline at end of file + yield entities.CommandReturn(text=reply_str) diff --git a/pkg/command/operators/help.py b/pkg/command/operators/help.py new file mode 100644 index 00000000..c99c2948 --- /dev/null +++ b/pkg/command/operators/help.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import typing + +from .. 
import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name='help', + help='显示帮助', + usage='!help\n!help <命令名称>' +) +class HelpOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + help = self.ap.tips_mgr.data['help_message'] + + help += '\n发送命令 !cmd 可查看命令列表' + + yield entities.CommandReturn(text=help) diff --git a/pkg/command/operators/last.py b/pkg/command/operators/last.py new file mode 100644 index 00000000..8e3a5231 --- /dev/null +++ b/pkg/command/operators/last.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import typing +import datetime + + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="last", + help="切换到前一个对话", + usage='!last' +) +class LastOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if context.session.conversations: + # 找到当前会话的上一个会话 + for index in range(len(context.session.conversations)-1, -1, -1): + if context.session.conversations[index] == context.session.using_conversation: + if index == 0: + yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了')) + return + else: + context.session.using_conversation = context.session.conversations[index-1] + time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") + + yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}") + return + else: + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) \ No newline at end of file diff --git a/pkg/command/operators/list.py b/pkg/command/operators/list.py new file mode 100644 index 00000000..a91285e9 --- /dev/null +++ b/pkg/command/operators/list.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import typing +import datetime + +from .. 
import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="list", + help="列出此会话中的所有历史对话", + usage='!list\n!list <页码>' +) +class ListOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + page = 0 + + if len(context.crt_params) > 0: + try: + page = int(context.crt_params[0]-1) + except: + yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数')) + return + + record_per_page = 10 + + content = '' + + index = 0 + + using_conv_index = 0 + + for conv in context.session.conversations[::-1]: + time_str = conv.create_time.strftime("%Y-%m-%d %H:%M:%S") + + if conv == context.session.using_conversation: + using_conv_index = index + + if index >= page * record_per_page and index < (page + 1) * record_per_page: + content += f"{index} {time_str}: {conv.messages[0].content}\n" + index += 1 + + if content == '': + content = '无' + else: + if context.session.using_conversation is None: + content += "\n当前处于新会话" + else: + content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content}" + + yield entities.CommandReturn(text=f"第 {page + 1} 页 (时间倒序):\n{content}") diff --git a/pkg/command/operators/next.py b/pkg/command/operators/next.py new file mode 100644 index 00000000..8f4b5a5a --- /dev/null +++ b/pkg/command/operators/next.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import typing +import datetime + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="next", + help="切换到后一个对话", + usage='!next' +) +class NextOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if context.session.conversations: + # 找到当前会话的下一个会话 + for index in range(len(context.session.conversations)): + if context.session.conversations[index] == context.session.using_conversation: + if index == len(context.session.conversations)-1: + yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了')) + return + else: + context.session.using_conversation = context.session.conversations[index+1] + time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") + + yield entities.CommandReturn(text=f"已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}") + return + else: + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) \ No newline at end of file diff --git a/pkg/command/operators/plugin.py b/pkg/command/operators/plugin.py new file mode 100644 index 00000000..195852ae --- /dev/null +++ b/pkg/command/operators/plugin.py @@ -0,0 +1,239 @@ +from __future__ import annotations +import typing +import traceback + +from .. 
import operator, entities, cmdmgr, errors +from ...plugin import host as plugin_host +from ...utils import updater +from ...core import app + + +@operator.operator_class( + name="plugin", + help="插件操作", + usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>" +) +class PluginOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + plugin_list = plugin_host.__plugins__ + reply_str = "所有插件({}):\n".format(len(plugin_host.__plugins__)) + idx = 0 + for key in plugin_host.iter_plugins_name(): + plugin = plugin_list[key] + reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\ + .format((idx+1), plugin['name'], + "[已禁用]" if not plugin['enabled'] else "", + plugin['description'], + plugin['version'], plugin['author']) + + # TODO 从元数据调远程地址 + # if updater.is_repo("/".join(plugin['path'].split('/')[:-1])): + # remote_url = updater.get_remote_url("/".join(plugin['path'].split('/')[:-1])) + # if remote_url != "https://github.com/RockChinQ/QChatGPT" and remote_url != "https://gitee.com/RockChin/QChatGPT": + # reply_str += "源码: "+remote_url+"\n" + + idx += 1 + + yield entities.CommandReturn(text=reply_str) + + +@operator.operator_class( + name="get", + help="安装插件", + privilege=2, + parent_class=PluginOperator +) +class PluginGetOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址')) + else: + repo = context.crt_params[0] + + yield entities.CommandReturn(text="正在安装插件...") + + try: + plugin_host.install_plugin(repo) + yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件") + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件安装失败: "+str(e))) + + +@operator.operator_class( + name="update", + help="更新插件", + privilege=2, + parent_class=PluginOperator +) +class PluginUpdateOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + else: + plugin_name = context.crt_params[0] + + try: + plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name) + + if plugin_path_name is not None: + yield entities.CommandReturn(text="正在更新插件...") + plugin_host.update_plugin(plugin_name) + yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件") + else: + yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件")) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) + +@operator.operator_class( + name="all", + help="更新所有插件", + privilege=2, + parent_class=PluginUpdateOperator +) +class PluginUpdateAllOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + try: + plugins = [] + + for key in plugin_host.__plugins__: + plugins.append(key) + + if plugins: + yield entities.CommandReturn(text="正在更新插件...") + updated = [] + try: + for plugin_name in plugins: + plugin_host.update_plugin(plugin_name) + updated.append(plugin_name) + except Exception as e: + traceback.print_exc() + yield 
entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) + yield entities.CommandReturn(text="已更新插件: {}".format(", ".join(updated))) + else: + yield entities.CommandReturn(text="没有可更新的插件") + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) + + +@operator.operator_class( + name="del", + help="删除插件", + privilege=2, + parent_class=PluginOperator +) +class PluginDelOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + else: + plugin_name = context.crt_params[0] + + try: + plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name) + + if plugin_path_name is not None: + yield entities.CommandReturn(text="正在删除插件...") + plugin_host.uninstall_plugin(plugin_name) + yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件") + else: + yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件")) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e))) + + +def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application): + if plugin_name in plugin_host.__plugins__: + plugin_host.__plugins__[plugin_name]['enabled'] = new_status + + for func in ap.tool_mgr.all_functions: + if func.name.startswith(plugin_name+'-'): + func.enable = new_status + + return True + else: + return False + + +@operator.operator_class( + name="on", + help="启用插件", + privilege=2, + parent_class=PluginOperator +) +class PluginEnableOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + else: + plugin_name = context.crt_params[0] + + try: + if update_plugin_status(plugin_name, True, self.ap): + yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name)) + else: + yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e))) + + +@operator.operator_class( + name="off", + help="禁用插件", + privilege=2, + parent_class=PluginOperator +) +class PluginDisableOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + else: + plugin_name = context.crt_params[0] + + try: + if update_plugin_status(plugin_name, False, self.ap): + yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name)) + else: + yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e))) diff --git a/pkg/command/operators/prompt.py b/pkg/command/operators/prompt.py new file mode 100644 index 00000000..29d688a6 --- /dev/null +++ b/pkg/command/operators/prompt.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import typing + +from .. 
import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="prompt", + help="查看当前对话的前文", + usage='!prompt' +) +class PromptOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行 + """ + if context.session.using_conversation is None: + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) + else: + reply_str = '当前对话所有内容:\n\n' + + for msg in context.session.using_conversation.messages: + reply_str += f"{msg.role}: {msg.content}\n" + + yield entities.CommandReturn(text=reply_str) \ No newline at end of file diff --git a/pkg/command/operators/resend.py b/pkg/command/operators/resend.py new file mode 100644 index 00000000..6d930413 --- /dev/null +++ b/pkg/command/operators/resend.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import typing + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="resend", + help="重发当前会话的最后一条消息", + usage='!resend' +) +class ResendOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + # 回滚到最后一条用户message前 + if context.session.using_conversation is None: + yield entities.CommandReturn(error=errors.CommandError("当前没有对话")) + else: + conv_msg = context.session.using_conversation.messages + + # 倒序一直删到最后一条用户message + while len(conv_msg) > 0 and conv_msg[-1].role != 'user': + conv_msg.pop() + + if len(conv_msg) > 0: + # 删除最后一条用户message + conv_msg.pop() + + # 不重发了,提示用户已删除就行了 + yield entities.CommandReturn(text="已删除最后一次请求记录") diff --git a/pkg/command/operators/reset.py b/pkg/command/operators/reset.py new file mode 100644 index 00000000..5d1402ac --- /dev/null +++ b/pkg/command/operators/reset.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import typing + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="reset", + help="重置当前会话", + usage='!reset' +) +class ResetOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行 + """ + context.session.using_conversation = None + + yield entities.CommandReturn(text="已重置当前会话") diff --git a/pkg/command/operators/version.py b/pkg/command/operators/version.py new file mode 100644 index 00000000..c2235800 --- /dev/null +++ b/pkg/command/operators/version.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import typing + +from .. 
import operator, cmdmgr, entities, errors +from ...utils import updater + + +@operator.operator_class( + name="version", + help="显示版本信息", + usage='!version' +) +class VersionCommand(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + reply_str = f"当前版本: \n{updater.get_current_version_info()}" + + try: + if updater.is_new_version_available(): + reply_str += "\n\n有新版本可用, 使用 !update 更新" + except: + pass + + yield entities.CommandReturn(text=reply_str.strip()) \ No newline at end of file diff --git a/pkg/openai/requester/apis/chatcmpl.py b/pkg/openai/requester/apis/chatcmpl.py index 24ff2d7e..1e3da1ad 100644 --- a/pkg/openai/requester/apis/chatcmpl.py +++ b/pkg/openai/requester/apis/chatcmpl.py @@ -102,7 +102,7 @@ async def request( m.dict(exclude_none=True) for m in conversation.prompt.messages ] + [m.dict(exclude_none=True) for m in conversation.messages] - req_messages.append({"role": "user", "content": str(query.message_chain)}) + # req_messages.append({"role": "user", "content": str(query.message_chain)}) msg = await self._closure(req_messages, conversation) diff --git a/pkg/openai/sysprompt/sysprompt.py b/pkg/openai/sysprompt/sysprompt.py index 050f6639..5df28ee7 100644 --- a/pkg/openai/sysprompt/sysprompt.py +++ b/pkg/openai/sysprompt/sysprompt.py @@ -35,9 +35,16 @@ def get_all_prompts(self) -> list[loader.entities.Prompt]: """ return self.loader_inst.get_prompts() - async def get_prompt(self, name: str) -> loader.entities.Prompt: + async def get_prompt(self, name: str) -> loader.entities.Prompt: """获取Prompt """ for prompt in self.get_all_prompts(): if prompt.name == name: return prompt + + async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt: + """通过前缀获取Prompt + """ + for prompt in self.get_all_prompts(): + if prompt.name.startswith(prefix): + return prompt diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 629c2b11..889b3bb6 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -7,6 +7,7 @@ from .. import handler from ... 
import entities
 from ....core import entities as core_entities
+from ....openai import entities as llm_entities
 
 
 class ChatMessageHandler(handler.MessageHandler):
@@ -25,6 +26,13 @@ async def handle(
 
         conversation = await self.ap.sess_mgr.get_conversation(session)
 
+        conversation.messages.append(
+            llm_entities.Message(
+                role="user",
+                content=str(query.message_chain)
+            )
+        )
+
         async for result in conversation.use_model.requester.request(query, conversation):
             conversation.messages.append(result)
 
diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py
index cf3e0740..f836a2a1 100644
--- a/pkg/pipeline/process/handlers/command.py
+++ b/pkg/pipeline/process/handlers/command.py
@@ -30,17 +30,21 @@ async def handle(
                     mirai.Plain(str(ret.error))
                 ])
 
+                yield entities.StageProcessResult(
+                    result_type=entities.ResultType.CONTINUE,
+                    new_query=query
+                )
+            elif ret.text is not None:
+                query.resp_message_chain = mirai.MessageChain([
+                    mirai.Plain(ret.text)
+                ])
+
                 yield entities.StageProcessResult(
                     result_type=entities.ResultType.CONTINUE,
                     new_query=query
                 )
             else:
-                if ret.text is not None:
-                    query.resp_message_chain = mirai.MessageChain([
-                        mirai.Plain(ret.text)
-                    ])
-
-                    yield entities.StageProcessResult(
-                        result_type=entities.ResultType.CONTINUE,
-                        new_query=query
-                    )
+                yield entities.StageProcessResult(
+                    result_type=entities.ResultType.INTERRUPT,
+                    new_query=query
+                )

From b5924bb34f0e9d736cb94aa47446d080b4ff4ddc Mon Sep 17 00:00:00 2001
From: RockChinQ <1010553892@qq.com>
Date: Sun, 28 Jan 2024 18:27:48 +0800
Subject: [PATCH 08/10] =?UTF-8?q?refactor:=20=E6=B7=BB=E5=8A=A0=E6=9B=B4?=
 =?UTF-8?q?=E6=96=B0=E5=91=BD=E4=BB=A4?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 pkg/command/cmdmgr.py           |  2 +-
 pkg/command/operators/update.py | 31 +++++++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)
 create mode 100644 pkg/command/operators/update.py

diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py
index cff5969c..73e14584 100644
--- a/pkg/command/cmdmgr.py
+++ b/pkg/command/cmdmgr.py
@@ -7,7 +7,7 @@
 from ..openai.session import entities as session_entities
 
 from . import entities, operator, errors
-from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version
+from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version, update
 
 
 class CommandManager:
diff --git a/pkg/command/operators/update.py b/pkg/command/operators/update.py
new file mode 100644
index 00000000..db493b6a
--- /dev/null
+++ b/pkg/command/operators/update.py
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+import typing
+import traceback
+
+from ..
import operator, entities, cmdmgr, errors
+from ...utils import updater
+
+
+@operator.operator_class(
+    name="update",
+    help="更新程序",
+    usage='!update',
+    privilege=2
+)
+class UpdateCommand(operator.CommandOperator):
+
+    async def execute(
+        self,
+        context: entities.ExecuteContext
+    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
+
+        try:
+            yield entities.CommandReturn(text="正在进行更新...")
+            if updater.update_all():
+                yield entities.CommandReturn(text="更新完成,请重启程序以应用更新")
+            else:
+                yield entities.CommandReturn(text="当前已是最新版本")
+        except Exception as e:
+            traceback.print_exc()
+            yield entities.CommandReturn(error=errors.CommandError("更新失败: "+str(e)))
\ No newline at end of file

From 238c55a40efb28aa19d3e12d5a91d219aa12d92a Mon Sep 17 00:00:00 2001
From: RockChinQ <1010553892@qq.com>
Date: Sun, 28 Jan 2024 18:38:47 +0800
Subject: [PATCH 09/10] =?UTF-8?q?chore:=20=E5=88=A0=E9=99=A4=E5=B7=B2?=
 =?UTF-8?q?=E5=BC=83=E7=94=A8=E7=9A=84=E6=96=87=E4=BB=B6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 pkg/config/manager.py              |   1 -
 pkg/core/app.py                    |   3 -
 pkg/core/boot.py                   |   6 +-
 pkg/openai/dprompt.py              | 134 --------
 pkg/openai/funcmgr.py              |  46 ---
 pkg/openai/keymgr.py               | 103 ------
 pkg/openai/manager.py              | 108 -------
 pkg/openai/session.py              | 504 -----------------------------
 pkg/qqbot/cmds/__init__.py         |   0
 pkg/qqbot/cmds/aamgr.py            | 333 -------------------
 pkg/qqbot/cmds/funcs/__init__.py   |   0
 pkg/qqbot/cmds/funcs/draw.py       |  37 ---
 pkg/qqbot/cmds/funcs/func.py       |  32 --
 pkg/qqbot/cmds/plugin/__init__.py  |   0
 pkg/qqbot/cmds/plugin/plugin.py    | 198 ------------
 pkg/qqbot/cmds/session/__init__.py |   0
 pkg/qqbot/cmds/session/default.py  |  71 ----
 pkg/qqbot/cmds/session/del.py      |  51 ---
 pkg/qqbot/cmds/session/delhst.py   |  50 ---
 pkg/qqbot/cmds/session/last.py     |  29 --
 pkg/qqbot/cmds/session/list.py     |  65 ----
 pkg/qqbot/cmds/session/next.py     |  29 --
 pkg/qqbot/cmds/session/prompt.py   |  31 --
 pkg/qqbot/cmds/session/resend.py   |  33 --
 pkg/qqbot/cmds/session/reset.py    |  35 --
 pkg/qqbot/cmds/system/__init__.py  |   0
 pkg/qqbot/cmds/system/cconfig.py   |  93 ------
 pkg/qqbot/cmds/system/cmd.py       |  39 ---
 pkg/qqbot/cmds/system/help.py      |  24 --
 pkg/qqbot/cmds/system/reload.py    |  25 --
 pkg/qqbot/cmds/system/update.py    |  38 ---
 pkg/qqbot/cmds/system/usage.py     |  33 --
 pkg/qqbot/cmds/system/version.py   |  27 --
 pkg/qqbot/command.py               |  49 ---
 pkg/qqbot/manager.py               | 103 +-----
 pkg/qqbot/message.py               | 134 --------
 pkg/qqbot/process.py               | 180 -----------
 pkg/utils/__init__.py              |   1 -
 pkg/utils/context.py               |   2 -
 pkg/utils/reloader.py              |  71 ----
 pkg/utils/threadctl.py             |  93 ------
 41 files changed, 3 insertions(+), 2808 deletions(-)
 delete mode 100644 pkg/openai/dprompt.py
 delete mode 100644 pkg/openai/funcmgr.py
 delete mode 100644 pkg/openai/keymgr.py
 delete mode 100644 pkg/openai/manager.py
 delete mode 100644 pkg/openai/session.py
 delete mode 100644 pkg/qqbot/cmds/__init__.py
 delete mode 100644 pkg/qqbot/cmds/aamgr.py
 delete mode 100644 pkg/qqbot/cmds/funcs/__init__.py
 delete mode 100644 pkg/qqbot/cmds/funcs/draw.py
 delete mode 100644 pkg/qqbot/cmds/funcs/func.py
 delete mode 100644 pkg/qqbot/cmds/plugin/__init__.py
 delete mode 100644 pkg/qqbot/cmds/plugin/plugin.py
 delete mode 100644 pkg/qqbot/cmds/session/__init__.py
 delete mode 100644 pkg/qqbot/cmds/session/default.py
 delete mode 100644 pkg/qqbot/cmds/session/del.py
 delete mode 100644 pkg/qqbot/cmds/session/delhst.py
 delete mode 100644 pkg/qqbot/cmds/session/last.py
 delete mode 100644 pkg/qqbot/cmds/session/list.py
 delete mode 100644
pkg/qqbot/cmds/session/next.py delete mode 100644 pkg/qqbot/cmds/session/prompt.py delete mode 100644 pkg/qqbot/cmds/session/resend.py delete mode 100644 pkg/qqbot/cmds/session/reset.py delete mode 100644 pkg/qqbot/cmds/system/__init__.py delete mode 100644 pkg/qqbot/cmds/system/cconfig.py delete mode 100644 pkg/qqbot/cmds/system/cmd.py delete mode 100644 pkg/qqbot/cmds/system/help.py delete mode 100644 pkg/qqbot/cmds/system/reload.py delete mode 100644 pkg/qqbot/cmds/system/update.py delete mode 100644 pkg/qqbot/cmds/system/usage.py delete mode 100644 pkg/qqbot/cmds/system/version.py delete mode 100644 pkg/qqbot/command.py delete mode 100644 pkg/qqbot/message.py delete mode 100644 pkg/qqbot/process.py delete mode 100644 pkg/utils/reloader.py delete mode 100644 pkg/utils/threadctl.py diff --git a/pkg/config/manager.py b/pkg/config/manager.py index 7e52d7b0..8377a51a 100644 --- a/pkg/config/manager.py +++ b/pkg/config/manager.py @@ -1,7 +1,6 @@ from __future__ import annotations from . import model as file_model -from ..utils import context from .impls import pymodule, json as json_file diff --git a/pkg/core/app.py b/pkg/core/app.py index f0693957..77b7fa6d 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -4,7 +4,6 @@ import asyncio from ..qqbot import manager as qqbot_mgr -from ..openai import manager as openai_mgr from ..openai.session import sessionmgr as llm_session_mgr from ..openai.requester import modelmgr as llm_model_mgr from ..openai.sysprompt import sysprompt as llm_prompt_mgr @@ -21,8 +20,6 @@ class Application: im_mgr: qqbot_mgr.QQBotManager = None - llm_mgr: openai_mgr.OpenAIInteract = None - cmd_mgr: cmdmgr.CommandManager = None sess_mgr: llm_session_mgr.SessionManager = None diff --git a/pkg/core/boot.py b/pkg/core/boot.py index 2b03a153..8a07b130 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -14,7 +14,6 @@ from ..pipeline import stagemgr from ..audit import identifier from ..database import manager as db_mgr -from ..openai import manager as llm_mgr from ..openai.session import sessionmgr as llm_session_mgr from ..openai.requester import modelmgr as llm_model_mgr from ..openai.sysprompt import sysprompt as llm_prompt_mgr @@ -107,9 +106,6 @@ async def make_app() -> app.Application: db_mgr_inst.initialize_database() ap.db_mgr = db_mgr_inst - llm_mgr_inst = llm_mgr.OpenAIInteract(ap) - ap.llm_mgr = llm_mgr_inst - cmd_mgr_inst = cmdmgr.CommandManager(ap) await cmd_mgr_inst.initialize() ap.cmd_mgr = cmd_mgr_inst @@ -130,7 +126,7 @@ async def make_app() -> app.Application: await llm_tool_mgr_inst.initialize() ap.tool_mgr = llm_tool_mgr_inst - im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap) + im_mgr_inst = im_mgr.QQBotManager(ap=ap) await im_mgr_inst.initialize() ap.im_mgr = im_mgr_inst diff --git a/pkg/openai/dprompt.py b/pkg/openai/dprompt.py deleted file mode 100644 index 247fb158..00000000 --- a/pkg/openai/dprompt.py +++ /dev/null @@ -1,134 +0,0 @@ -# 多情景预设值管理 -import json -import logging -import os - -from ..utils import context - -# __current__ = "default" -# """当前默认使用的情景预设的名称 - -# 由管理员使用`!default <名称>`命令切换 -# """ - -# __prompts_from_files__ = {} -# """从文件中读取的情景预设值""" - -# __scenario_from_files__ = {} - - -class ScenarioMode: - """情景预设模式抽象类""" - - using_prompt_name = "default" - """新session创建时使用的prompt名称""" - - prompts: dict[str, list] = {} - - def __init__(self): - logging.debug("prompts: {}".format(self.prompts)) - - def list(self) -> dict[str, list]: - """获取所有情景预设的名称及内容""" - return self.prompts - - def get_prompt(self, name: str) -> 
tuple[list, str]: - """获取指定情景预设的名称及内容""" - for key in self.prompts: - if key.startswith(name): - return self.prompts[key], key - raise Exception("没有找到情景预设: {}".format(name)) - - def set_using_name(self, name: str) -> str: - """设置默认情景预设""" - for key in self.prompts: - if key.startswith(name): - self.using_prompt_name = key - return key - raise Exception("没有找到情景预设: {}".format(name)) - - def get_full_name(self, name: str) -> str: - """获取完整的情景预设名称""" - for key in self.prompts: - if key.startswith(name): - return key - raise Exception("没有找到情景预设: {}".format(name)) - - def get_using_name(self) -> str: - """获取默认情景预设""" - return self.using_prompt_name - - -class NormalScenarioMode(ScenarioMode): - """普通情景预设模式""" - - def __init__(self): - config = context.get_config_manager().data - - # 加载config中的default_prompt值 - if type(config['default_prompt']) == str: - self.using_prompt_name = "default" - self.prompts = {"default": [ - { - "role": "system", - "content": config['default_prompt'] - } - ]} - - elif type(config['default_prompt']) == dict: - for key in config['default_prompt']: - self.prompts[key] = [ - { - "role": "system", - "content": config['default_prompt'][key] - } - ] - - # 从prompts/目录下的文件中载入 - # 遍历文件 - for file in os.listdir("prompts"): - with open(os.path.join("prompts", file), encoding="utf-8") as f: - self.prompts[file] = [ - { - "role": "system", - "content": f.read() - } - ] - - -class FullScenarioMode(ScenarioMode): - """完整情景预设模式""" - - def __init__(self): - """从json读取所有""" - # 遍历scenario/目录下的所有文件,以文件名为键,文件内容中的prompt为值 - for file in os.listdir("scenario"): - if file == "default-template.json": - continue - with open(os.path.join("scenario", file), encoding="utf-8") as f: - self.prompts[file] = json.load(f)["prompt"] - - super().__init__() - - -scenario_mode_mapping = {} -"""情景预设模式名称与对象的映射""" - - -def register_all(): - """注册所有情景预设模式,不使用装饰器,因为装饰器的方式不支持热重载""" - global scenario_mode_mapping - scenario_mode_mapping = { - "normal": NormalScenarioMode(), - "full_scenario": FullScenarioMode() - } - - -def mode_inst() -> ScenarioMode: - """获取指定名称的情景预设模式对象""" - config = context.get_config_manager().data - - if config['preset_mode'] == "default": - config['preset_mode'] = "normal" - - return scenario_mode_mapping[config['preset_mode']] diff --git a/pkg/openai/funcmgr.py b/pkg/openai/funcmgr.py deleted file mode 100644 index 50932917..00000000 --- a/pkg/openai/funcmgr.py +++ /dev/null @@ -1,46 +0,0 @@ -# 封装了function calling的一些支持函数 -import logging - -from ..plugin import host - - -class ContentFunctionNotFoundError(Exception): - pass - - -def get_func_schema_list() -> list: - """从plugin包中的函数结构中获取并处理成受GPT支持的格式""" - if not host.__enable_content_functions__: - return [] - - schemas = [] - - for func in host.__callable_functions__: - if func['enabled']: - fun_cp = func.copy() - - del fun_cp['enabled'] - - schemas.append(fun_cp) - - return schemas - -def get_func(name: str) -> callable: - if name not in host.__function_inst_map__: - raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name)) - - return host.__function_inst_map__[name] - -def get_func_schema(name: str) -> dict: - for func in host.__callable_functions__: - if func['name'] == name: - return func - raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name)) - -def execute_function(name: str, kwargs: dict) -> any: - """执行函数调用""" - - logging.debug("executing function: name='{}', kwargs={}".format(name, kwargs)) - - func = get_func(name) - return func(**kwargs) diff --git a/pkg/openai/keymgr.py b/pkg/openai/keymgr.py deleted file mode 
100644 index af560b29..00000000 --- a/pkg/openai/keymgr.py +++ /dev/null @@ -1,103 +0,0 @@ -# 此模块提供了维护api-key的各种功能 -import hashlib -import logging - -from ..plugin import host as plugin_host -from ..plugin import models as plugin_models - - -class KeysManager: - api_key = {} - """所有api-key""" - - using_key = "" - """当前使用的api-key""" - - alerted = [] - """已提示过超额的key - - 记录在此以避免重复提示 - """ - - exceeded = [] - """已超额的key - - 供自动切换功能识别 - """ - - def get_using_key(self): - return self.using_key - - def get_using_key_md5(self): - return hashlib.md5(self.using_key.encode('utf-8')).hexdigest() - - def __init__(self, api_key): - - assert type(api_key) == dict - self.api_key = api_key - # 从usage中删除未加载的api-key的记录 - # 不删了,也许会运行时添加曾经有记录的api-key - - self.auto_switch() - - def auto_switch(self) -> tuple[bool, str]: - """尝试切换api-key - - Returns: - 是否切换成功, 切换后的api-key的别名 - """ - - index = 0 - - for key_name in self.api_key: - if self.api_key[key_name] == self.using_key: - break - - index += 1 - - # 从当前key开始向后轮询 - start_index = index - index += 1 - if index >= len(self.api_key): - index = 0 - - while index != start_index: - - key_name = list(self.api_key.keys())[index] - - if self.api_key[key_name] not in self.exceeded: - self.using_key = self.api_key[key_name] - - logging.debug("使用api-key:" + key_name) - - # 触发插件事件 - args = { - "key_name": key_name, - "key_list": self.api_key.keys() - } - _ = plugin_host.emit(plugin_models.KeySwitched, **args) - - return True, key_name - - index += 1 - if index >= len(self.api_key): - index = 0 - - self.using_key = list(self.api_key.values())[start_index] - logging.debug("使用api-key:" + list(self.api_key.keys())[start_index]) - - return False, list(self.api_key.keys())[start_index] - - def add(self, key_name, key): - self.api_key[key_name] = key - - def set_current_exceeded(self): - """设置当前使用的api-key使用量超限""" - self.exceeded.append(self.using_key) - - def get_key_name(self, api_key): - """根据api-key获取其别名""" - for key_name in self.api_key: - if self.api_key[key_name] == api_key: - return key_name - return "" diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py deleted file mode 100644 index e070a29f..00000000 --- a/pkg/openai/manager.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations - -import logging - -import openai -from openai.types import images_response - -from ..openai import keymgr -from ..utils import context -from ..audit import gatherer -from ..openai import modelmgr -from ..openai.api import model as api_model -from ..core import app - - -class OpenAIInteract: - """OpenAI 接口封装 - - 将文字接口和图片接口封装供调用方使用 - """ - - key_mgr: keymgr.KeysManager = None - - audit_mgr: gatherer.DataGatherer = None - - default_image_api_params = { - "size": "256x256", - } - - client: openai.Client = None - - def __init__(self, ap: app.Application): - - cfg= ap.cfg_mgr.data - api_key = cfg['openai_config']['api_key'] - - self.key_mgr = keymgr.KeysManager(api_key) - self.audit_mgr = gatherer.DataGatherer() - - # 配置OpenAI proxy - openai.proxies = None # 先重置,因为重载后可能需要清除proxy - if "http_proxy" in cfg['openai_config'] and cfg['openai_config']["http_proxy"] is not None: - openai.proxies = { - "http": cfg['openai_config']["http_proxy"], - "https": cfg['openai_config']["http_proxy"] - } - - # 配置openai api_base - if "reverse_proxy" in cfg['openai_config'] and cfg['openai_config']["reverse_proxy"] is not None: - logging.debug("设置反向代理: "+cfg['openai_config']['reverse_proxy']) - openai.base_url = cfg['openai_config']["reverse_proxy"] - - - self.client = openai.Client( - 
api_key=self.key_mgr.get_using_key(), - base_url=openai.base_url - ) - - context.set_openai_manager(self) - - def request_completion(self, messages: list): - """请求补全接口回复= - """ - # 选择接口请求类 - config = context.get_config_manager().data - - request: api_model.RequestBase - - model: str = config['completion_api_params']['model'] - - cp_parmas = config['completion_api_params'].copy() - del cp_parmas['model'] - - request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas) - - # 请求接口 - for resp in request: - - if resp['usage']['total_tokens'] > 0: - self.audit_mgr.report_text_model_usage( - model, - resp['usage']['total_tokens'] - ) - - yield resp - - def request_image(self, prompt) -> images_response.ImagesResponse: - """请求图片接口回复 - - Parameters: - prompt (str): 提示语 - - Returns: - dict: 响应 - """ - config = context.get_config_manager().data - params = config['image_api_params'] - - response = self.client.images.generate( - prompt=prompt, - n=1, - **params - ) - - self.audit_mgr.report_image_model_usage(params['size']) - - return response - diff --git a/pkg/openai/session.py b/pkg/openai/session.py deleted file mode 100644 index 19a69ea2..00000000 --- a/pkg/openai/session.py +++ /dev/null @@ -1,504 +0,0 @@ -"""主线使用的会话管理模块 - -每个人、每个群单独一个session,session内部保留了对话的上下文, -""" - -import logging -import threading -import time -import json - -from ..openai import manager as openai_manager -from ..openai import modelmgr as openai_modelmgr -from ..database import manager as database_manager -from ..utils import context as context - -from ..plugin import host as plugin_host -from ..plugin import models as plugin_models - -# 运行时保存的所有session -sessions = {} - - -class SessionOfflineStatus: - ON_GOING = 'on_going' - EXPLICITLY_CLOSED = 'explicitly_closed' - - -# 从数据加载session -def load_sessions(): - """从数据库加载sessions""" - - global sessions - - db_inst = context.get_database_manager() - - session_data = db_inst.load_valid_sessions() - - for session_name in session_data: - logging.debug('加载session: {}'.format(session_name)) - - temp_session = Session(session_name) - temp_session.name = session_name - temp_session.create_timestamp = session_data[session_name]['create_timestamp'] - temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp'] - - temp_session.prompt = json.loads(session_data[session_name]['prompt']) - temp_session.token_counts = json.loads(session_data[session_name]['token_counts']) - - temp_session.default_prompt = json.loads(session_data[session_name]['default_prompt']) if \ - session_data[session_name]['default_prompt'] else [] - - sessions[session_name] = temp_session - - -# 获取指定名称的session,如果不存在则创建一个新的 -def get_session(session_name: str) -> 'Session': - global sessions - if session_name not in sessions: - sessions[session_name] = Session(session_name) - return sessions[session_name] - - -def dump_session(session_name: str): - global sessions - if session_name in sessions: - assert isinstance(sessions[session_name], Session) - sessions[session_name].persistence() - del sessions[session_name] - - -# 通用的OpenAI API交互session -# session内部保留了对话的上下文, -# 收到用户消息后,将上下文提交给OpenAI API生成回复 -class Session: - name = '' - - prompt = [] - """使用list来保存会话中的回合""" - - default_prompt = [] - """本session的默认prompt""" - - create_timestamp = 0 - """会话创建时间""" - - last_interact_timestamp = 0 - """上次交互(产生回复)时间""" - - just_switched_to_exist_session = False - - response_lock = None - - # 加锁 - def acquire_response_lock(self): - logging.debug('{},lock acquire,{}'.format(self.name, 
self.response_lock)) - self.response_lock.acquire() - logging.debug('{},lock acquire successfully,{}'.format(self.name, self.response_lock)) - - # 释放锁 - def release_response_lock(self): - if self.response_lock.locked(): - logging.debug('{},lock release,{}'.format(self.name, self.response_lock)) - self.response_lock.release() - logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock)) - - # 从配置文件获取会话预设信息 - def get_default_prompt(self, use_default: str = None): - import pkg.openai.dprompt as dprompt - - if use_default is None: - use_default = dprompt.mode_inst().get_using_name() - - current_default_prompt, _ = dprompt.mode_inst().get_prompt(use_default) - return current_default_prompt - - def __init__(self, name: str): - self.name = name - self.create_timestamp = int(time.time()) - self.last_interact_timestamp = int(time.time()) - self.prompt = [] - self.token_counts = [] - self.schedule() - - self.response_lock = threading.Lock() - - self.default_prompt = self.get_default_prompt() - logging.debug("prompt is: {}".format(self.default_prompt)) - - # 设定检查session最后一次对话是否超过过期时间的计时器 - def schedule(self): - threading.Thread(target=self.expire_check_timer_loop, args=(self.create_timestamp,)).start() - - # 检查session是否已经过期 - def expire_check_timer_loop(self, create_timestamp: int): - global sessions - while True: - time.sleep(60) - - # 不是此session已更换,退出 - if self.create_timestamp != create_timestamp or self not in sessions.values(): - return - - config = context.get_config_manager().data - if int(time.time()) - self.last_interact_timestamp > config['session_expire_time']: - logging.info('session {} 已过期'.format(self.name)) - - # 触发插件事件 - args = { - 'session_name': self.name, - 'session': self, - 'session_expire_time': config['session_expire_time'] - } - event = plugin_host.emit(plugin_models.SessionExpired, **args) - if event.is_prevented_default(): - return - - self.reset(expired=True, schedule_new=False) - - # 删除此session - del sessions[self.name] - return - - # 请求回复 - # 这个函数是阻塞的 - def query(self, text: str=None) -> tuple[str, str, list[str]]: - """向session中添加一条消息,返回接口回复 - - Args: - text (str): 用户消息 - - Returns: - tuple[str, str]: (接口回复, finish_reason, 已调用的函数列表) - """ - - self.last_interact_timestamp = int(time.time()) - - # 触发插件事件 - if not self.prompt: - args = { - 'session_name': self.name, - 'session': self, - 'default_prompt': self.default_prompt, - } - - event = plugin_host.emit(plugin_models.SessionFirstMessageReceived, **args) - if event.is_prevented_default(): - return None, None, None - - config = context.get_config_manager().data - max_length = config['prompt_submit_length'] - - local_default_prompt = self.default_prompt.copy() - local_prompt = self.prompt.copy() - - # 触发PromptPreProcessing事件 - args = { - 'session_name': self.name, - 'default_prompt': self.default_prompt, - 'prompt': self.prompt, - 'text_message': text, - } - - event = plugin_host.emit(plugin_models.PromptPreProcessing, **args) - - if event.get_return_value('default_prompt') is not None: - local_default_prompt = event.get_return_value('default_prompt') - - if event.get_return_value('prompt') is not None: - local_prompt = event.get_return_value('prompt') - - if event.get_return_value('text_message') is not None: - text = event.get_return_value('text_message') - - # 裁剪messages到合适长度 - prompts, _ = self.cut_out(text, max_length, local_default_prompt, local_prompt) - - res_text = "" - - pending_msgs = [] - - total_tokens = 0 - - finish_reason: str = "" - - funcs = [] - - trace_func_calls = 
config['trace_function_calls'] - botmgr = context.get_qqbot_manager() - - session_name_spt: list[str] = self.name.split("_") - - pending_res_text = "" - - start_time = time.time() - - # TODO 对不起,我知道这样非常非常屎山,但我之后会重构的 - for resp in context.get_openai_manager().request_completion(prompts): - - if pending_res_text != "": - botmgr.adapter.send_message( - session_name_spt[0], - session_name_spt[1], - pending_res_text - ) - pending_res_text = "" - - finish_reason = resp['choices'][0]['finish_reason'] - - if resp['choices'][0]['message']['role'] == "assistant" and resp['choices'][0]['message']['content'] != None: # 包含纯文本响应 - - if not trace_func_calls: - res_text += resp['choices'][0]['message']['content'] - else: - res_text = resp['choices'][0]['message']['content'] - pending_res_text = resp['choices'][0]['message']['content'] - - total_tokens += resp['usage']['total_tokens'] - - msg = { - "role": "assistant", - "content": resp['choices'][0]['message']['content'] - } - - if 'function_call' in resp['choices'][0]['message']: - msg['function_call'] = json.dumps(resp['choices'][0]['message']['function_call']) - - pending_msgs.append(msg) - - if resp['choices'][0]['message']['type'] == 'function_call': - # self.prompt.append( - # { - # "role": "assistant", - # "content": "function call: "+json.dumps(resp['choices'][0]['message']['function_call']) - # } - # ) - if trace_func_calls: - botmgr.adapter.send_message( - session_name_spt[0], - session_name_spt[1], - "调用函数 "+resp['choices'][0]['message']['function_call']['name'] + "..." - ) - - total_tokens += resp['usage']['total_tokens'] - elif resp['choices'][0]['message']['type'] == 'function_return': - # self.prompt.append( - # { - # "role": "function", - # "name": resp['choices'][0]['message']['function_name'], - # "content": json.dumps(resp['choices'][0]['message']['content']) - # } - # ) - - # total_tokens += resp['usage']['total_tokens'] - funcs.append( - resp['choices'][0]['message']['function_name'] - ) - pass - - # 向API请求补全 - # message, total_token = pkg.utils.context.get_openai_manager().request_completion( - # prompts, - # ) - - # 成功获取,处理回复 - # res_test = message - res_ans = res_text.strip() - - # 将此次对话的双方内容加入到prompt中 - # self.prompt.append({'role': 'user', 'content': text}) - # self.prompt.append({'role': 'assistant', 'content': res_ans}) - if text: - self.prompt.append({'role': 'user', 'content': text}) - # 添加pending_msgs - self.prompt += pending_msgs - - # 向token_counts中添加本回合的token数量 - # self.token_counts.append(total_tokens-total_token_before_query) - # logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts)) - - if self.just_switched_to_exist_session: - self.just_switched_to_exist_session = False - self.set_ongoing() - - # 上报使用量数据 - session_type = session_name_spt[0] - session_id = session_name_spt[1] - - ability_provider = "QChatGPT.Text" - usage = total_tokens - model_name = context.get_config_manager().data['completion_api_params']['model'] - response_seconds = int(time.time() - start_time) - retry_times = -1 # 暂不记录 - - context.get_center_v2_api().usage.post_query_record( - session_type=session_type, - session_id=session_id, - query_ability_provider=ability_provider, - usage=usage, - model_name=model_name, - response_seconds=response_seconds, - retry_times=retry_times - ) - - return res_ans if res_ans[0] != '\n' else res_ans[1:], finish_reason, funcs - - # 删除上一回合并返回上一回合的问题 - def undo(self) -> str: - self.last_interact_timestamp = int(time.time()) - - # 删除最后两个消息 - if 
len(self.prompt) < 2: - raise Exception('之前无对话,无法撤销') - - question = self.prompt[-2]['content'] - self.prompt = self.prompt[:-2] - self.token_counts = self.token_counts[:-1] - - # 返回上一回合的问题 - return question - - # 构建对话体 - def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) -> tuple[list, list]: - """将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens - - :return: (新的prompt, 新的token_counts) - """ - - # 最终由三个部分组成 - # - default_prompt 情景预设固定值 - # - changable_prompts 可变部分, 此会话中的历史对话回合 - # - current_question 当前问题 - - # 包装目前的对话回合内容 - changable_prompts = [] - - use_model = context.get_config_manager().data['completion_api_params']['model'] - - ptr = len(prompt) - 1 - - # 直接从后向前扫描拼接,不管是否是整回合 - while ptr >= 0: - if openai_modelmgr.count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens: - break - - changable_prompts.insert(0, prompt[ptr]) - - ptr -= 1 - - # 将default_prompt和changable_prompts合并 - result_prompt = default_prompt + changable_prompts - - # 添加当前问题 - if msg: - result_prompt.append( - { - 'role': 'user', - 'content': msg - } - ) - - logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4))) - - return result_prompt, openai_modelmgr.count_tokens(changable_prompts, use_model) - - # 持久化session - def persistence(self): - if self.prompt == self.get_default_prompt(): - return - - db_inst = context.get_database_manager() - - name_spt = self.name.split('_') - - subject_type = name_spt[0] - subject_number = int(name_spt[1]) - - db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, - json.dumps(self.prompt), json.dumps(self.default_prompt), json.dumps(self.token_counts)) - - # 重置session - def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None, persist: bool = False): - if self.prompt: - self.persistence() - if explicit: - # 触发插件事件 - args = { - 'session_name': self.name, - 'session': self - } - - # 此事件不支持阻止默认行为 - _ = plugin_host.emit(plugin_models.SessionExplicitReset, **args) - - context.get_database_manager().explicit_close_session(self.name, self.create_timestamp) - - if expired: - context.get_database_manager().set_session_expired(self.name, self.create_timestamp) - - if not persist: # 不要求保持default prompt - self.default_prompt = self.get_default_prompt(use_prompt) - self.prompt = [] - self.token_counts = [] - self.create_timestamp = int(time.time()) - self.last_interact_timestamp = int(time.time()) - self.just_switched_to_exist_session = False - - # self.response_lock = threading.Lock() - - if schedule_new: - self.schedule() - - # 将本session的数据库状态设置为on_going - def set_ongoing(self): - context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp) - - # 切换到上一个session - def last_session(self): - last_one = context.get_database_manager().last_session(self.name, self.last_interact_timestamp) - if last_one is None: - return None - else: - self.persistence() - - self.create_timestamp = last_one['create_timestamp'] - self.last_interact_timestamp = last_one['last_interact_timestamp'] - - self.prompt = json.loads(last_one['prompt']) - self.token_counts = json.loads(last_one['token_counts']) - - self.default_prompt = json.loads(last_one['default_prompt']) if last_one['default_prompt'] else [] - - self.just_switched_to_exist_session = True - return self - - # 切换到下一个session - def next_session(self): - next_one = context.get_database_manager().next_session(self.name, self.last_interact_timestamp) - if 
next_one is None: - return None - else: - self.persistence() - - self.create_timestamp = next_one['create_timestamp'] - self.last_interact_timestamp = next_one['last_interact_timestamp'] - - self.prompt = json.loads(next_one['prompt']) - self.token_counts = json.loads(next_one['token_counts']) - - self.default_prompt = json.loads(next_one['default_prompt']) if next_one['default_prompt'] else [] - - self.just_switched_to_exist_session = True - return self - - def list_history(self, capacity: int = 10, page: int = 0): - return context.get_database_manager().list_history(self.name, capacity, page) - - def delete_history(self, index: int) -> bool: - return context.get_database_manager().delete_history(self.name, index) - - def delete_all_history(self) -> bool: - return context.get_database_manager().delete_all_history(self.name) - - def draw_image(self, prompt: str): - return context.get_openai_manager().request_image(prompt) diff --git a/pkg/qqbot/cmds/__init__.py b/pkg/qqbot/cmds/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/qqbot/cmds/aamgr.py b/pkg/qqbot/cmds/aamgr.py deleted file mode 100644 index 6bc5c2de..00000000 --- a/pkg/qqbot/cmds/aamgr.py +++ /dev/null @@ -1,333 +0,0 @@ -import logging -import copy -import pkgutil -import traceback -import json - -import tips as tips_custom - - -__command_list__ = {} -"""命令树 - -结构: -{ - 'cmd1': { - 'description': 'cmd1 description', - 'usage': 'cmd1 usage', - 'aliases': ['cmd1 alias1', 'cmd1 alias2'], - 'privilege': 0, - 'parent': None, - 'cls': , - 'sub': [ - 'cmd1-1' - ] - }, - 'cmd1.cmd1-1: { - 'description': 'cmd1-1 description', - 'usage': 'cmd1-1 usage', - 'aliases': ['cmd1-1 alias1', 'cmd1-1 alias2'], - 'privilege': 0, - 'parent': 'cmd1', - 'cls': , - 'sub': [] - }, - 'cmd2': { - 'description': 'cmd2 description', - 'usage': 'cmd2 usage', - 'aliases': ['cmd2 alias1', 'cmd2 alias2'], - 'privilege': 0, - 'parent': None, - 'cls': , - 'sub': [ - 'cmd2-1' - ] - }, - 'cmd2.cmd2-1': { - 'description': 'cmd2-1 description', - 'usage': 'cmd2-1 usage', - 'aliases': ['cmd2-1 alias1', 'cmd2-1 alias2'], - 'privilege': 0, - 'parent': 'cmd2', - 'cls': , - 'sub': [ - 'cmd2-1-1' - ] - }, - 'cmd2.cmd2-1.cmd2-1-1': { - 'description': 'cmd2-1-1 description', - 'usage': 'cmd2-1-1 usage', - 'aliases': ['cmd2-1-1 alias1', 'cmd2-1-1 alias2'], - 'privilege': 0, - 'parent': 'cmd2.cmd2-1', - 'cls': , - 'sub': [] - }, -} -""" - -__tree_index__: dict[str, list] = {} -"""命令树索引 - -结构: -{ - 'pkg.qqbot.cmds.cmd1.CommandCmd1': 'cmd1', # 顶级命令 - 'pkg.qqbot.cmds.cmd1.CommandCmd1_1': 'cmd1.cmd1-1', # 类名: 节点路径 - 'pkg.qqbot.cmds.cmd2.CommandCmd2': 'cmd2', - 'pkg.qqbot.cmds.cmd2.CommandCmd2_1': 'cmd2.cmd2-1', - 'pkg.qqbot.cmds.cmd2.CommandCmd2_1_1': 'cmd2.cmd2-1.cmd2-1-1', -} -""" - - -class Context: - """命令执行上下文""" - command: str - """顶级命令文本""" - - crt_command: str - """当前子命令文本""" - - params: list - """完整参数列表""" - - crt_params: list - """当前子命令参数列表""" - - session_name: str - """会话名""" - - text_message: str - """命令完整文本""" - - launcher_type: str - """命令发起者类型""" - - launcher_id: int - """命令发起者ID""" - - sender_id: int - """命令发送者ID""" - - is_admin: bool - """[过时]命令发送者是否为管理员""" - - privilege: int - """命令发送者权限等级""" - - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - - -class AbstractCommandNode: - """命令抽象类""" - - parent: type - """父命令类""" - - name: str - """命令名""" - - description: str - """命令描述""" - - usage: str - """命令用法""" - - aliases: list[str] - """命令别名""" - - privilege: int - """命令权限等级, 权限大于等于此值的用户才能执行命令""" - - @classmethod - def 
process(cls, ctx: Context) -> tuple[bool, list]: - """命令处理函数 - - :param ctx: 命令执行上下文 - - :return: (是否执行, 回复列表(若执行)) - - 若未执行,将自动以下一个参数查找并执行子命令 - """ - raise NotImplementedError - - @classmethod - def help(cls) -> str: - """获取命令帮助信息""" - return '命令: {}\n描述: {}\n用法: \n{}\n别名: {}\n权限: {}'.format( - cls.name, - cls.description, - cls.usage, - ', '.join(cls.aliases), - cls.privilege - ) - - @staticmethod - def register( - parent: type = None, - name: str = None, - description: str = None, - usage: str = None, - aliases: list[str] = None, - privilege: int = 0 - ): - """注册命令 - - :param cls: 命令类 - :param name: 命令名 - :param parent: 父命令类 - """ - global __command_list__, __tree_index__ - - def wrapper(cls): - cls.name = name - cls.parent = parent - cls.description = description - cls.usage = usage - cls.aliases = aliases - cls.privilege = privilege - - logging.debug("cls: {}, name: {}, parent: {}".format(cls, name, parent)) - - if parent is None: - # 顶级命令注册 - __command_list__[name] = { - 'description': cls.description, - 'usage': cls.usage, - 'aliases': cls.aliases, - 'privilege': cls.privilege, - 'parent': None, - 'cls': cls, - 'sub': [] - } - # 更新索引 - __tree_index__[cls.__module__ + '.' + cls.__name__] = name - else: - # 获取父节点名称 - path = __tree_index__[parent.__module__ + '.' + parent.__name__] - - parent_node = __command_list__[path] - # 链接父子命令 - __command_list__[path]['sub'].append(name) - # 注册子命令 - __command_list__[path + '.' + name] = { - 'description': cls.description, - 'usage': cls.usage, - 'aliases': cls.aliases, - 'privilege': cls.privilege, - 'parent': path, - 'cls': cls, - 'sub': [] - } - # 更新索引 - __tree_index__[cls.__module__ + '.' + cls.__name__] = path + '.' + name - - return cls - - return wrapper - - -class CommandPrivilegeError(Exception): - """命令权限不足或不存在异常""" - pass - - -# 传入Context对象,广搜命令树,返回执行结果 -# 若命令被处理,返回reply列表 -# 若命令未被处理,继续执行下一级命令 -# 若命令不存在,报异常 -def execute(context: Context) -> list: - """执行命令 - - :param ctx: 命令执行上下文 - - :return: 回复列表 - """ - global __command_list__ - - # 拷贝ctx - ctx: Context = copy.deepcopy(context) - - # 从树取出顶级命令 - node = __command_list__ - - path = ctx.command - - while True: - try: - node = __command_list__[path] - logging.debug('执行命令: {}'.format(path)) - - # 检查权限 - if ctx.privilege < node['privilege']: - raise CommandPrivilegeError(tips_custom.command_admin_message+"{}".format(path)) - - # 执行 - execed, reply = node['cls'].process(ctx) - if execed: - return reply - else: - # 删除crt_params第一个参数 - ctx.crt_command = ctx.crt_params.pop(0) - # 下一个path - path = path + '.' + ctx.crt_command - except KeyError: - traceback.print_exc() - raise CommandPrivilegeError(tips_custom.command_err_message+"{}".format(path)) - - -def register_all(): - """启动时调用此函数注册所有命令 - - 递归处理pkg.qqbot.cmds包下及其子包下所有模块的所有继承于AbstractCommand的类 - """ - # 模块:遍历其中的继承于AbstractCommand的类,进行注册 - # 包:递归处理包下的模块 - # 排除__开头的属性 - global __command_list__, __tree_index__ - - import pkg.qqbot.cmds - - def walk(module, prefix, path_prefix): - # 排除不处于pkg.qqbot.cmds中的包 - if not module.__name__.startswith('pkg.qqbot.cmds'): - return - - logging.debug('walk: {}, path: {}'.format(module.__name__, module.__path__)) - for item in pkgutil.iter_modules(module.__path__): - if item.name.startswith('__'): - continue - - if item.ispkg: - walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.', path_prefix + item.name + '/') - else: - m = __import__(module.__name__ + '.' 
+ item.name, fromlist=['']) - # for name, cls in inspect.getmembers(m, inspect.isclass): - # # 检查是否为命令类 - # if cls.__module__ == m.__name__ and issubclass(cls, AbstractCommandNode) and cls != AbstractCommandNode: - # cls.register(cls, cls.name, cls.parent) - - walk(pkg.qqbot.cmds, '', '') - logging.debug(__command_list__) - - -def apply_privileges(): - """读取cmdpriv.json并应用命令权限""" - # 读取内容 - json_str = "" - with open('cmdpriv.json', 'r', encoding="utf-8") as f: - json_str = f.read() - - data = json.loads(json_str) - for path, priv in data.items(): - if path == 'comment': - continue - - if path not in __command_list__: - continue - - if __command_list__[path]['privilege'] != priv: - logging.debug('应用权限: {} -> {}(default: {})'.format(path, priv, __command_list__[path]['privilege'])) - - __command_list__[path]['privilege'] = priv diff --git a/pkg/qqbot/cmds/funcs/__init__.py b/pkg/qqbot/cmds/funcs/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/qqbot/cmds/funcs/draw.py b/pkg/qqbot/cmds/funcs/draw.py deleted file mode 100644 index 5ce25ad5..00000000 --- a/pkg/qqbot/cmds/funcs/draw.py +++ /dev/null @@ -1,37 +0,0 @@ -import logging - -import mirai - -from .. import aamgr -from ....utils import context - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="draw", - description="使用DALL·E生成图片", - usage="!draw <图片提示语>", - aliases=[], - privilege=1 -) -class DrawCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - - reply = [] - - if len(ctx.params) == 0: - reply = ["[bot]err: 未提供图片描述文字"] - else: - session = pkg.openai.session.get_session(ctx.session_name) - - res = session.draw_image(" ".join(ctx.params)) - - logging.debug("draw_image result:{}".format(res)) - reply = [mirai.Image(url=res.data[0].url)] - config = context.get_config_manager().data - if config['include_image_description']: - reply.append(" ".join(ctx.params)) - - return True, reply diff --git a/pkg/qqbot/cmds/funcs/func.py b/pkg/qqbot/cmds/funcs/func.py deleted file mode 100644 index 61675931..00000000 --- a/pkg/qqbot/cmds/funcs/func.py +++ /dev/null @@ -1,32 +0,0 @@ -import logging -import json - -from .. import aamgr - -@aamgr.AbstractCommandNode.register( - parent=None, - name="func", - description="管理内容函数", - usage="!func", - aliases=[], - privilege=1 -) -class FuncCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - from pkg.plugin.models import host - - reply = [] - - reply_str = "当前已加载的内容函数:\n\n" - - logging.debug("host.__callable_functions__: {}".format(json.dumps(host.__callable_functions__, indent=4))) - - index = 1 - for func in host.__callable_functions__: - reply_str += "{}. {}{}:\n{}\n\n".format(index, ("(已禁用) " if not func['enabled'] else ""), func['name'], func['description']) - index += 1 - - reply = [reply_str] - - return True, reply diff --git a/pkg/qqbot/cmds/plugin/__init__.py b/pkg/qqbot/cmds/plugin/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/qqbot/cmds/plugin/plugin.py b/pkg/qqbot/cmds/plugin/plugin.py deleted file mode 100644 index 5e699bba..00000000 --- a/pkg/qqbot/cmds/plugin/plugin.py +++ /dev/null @@ -1,198 +0,0 @@ -from ....plugin import host as plugin_host -from ....utils import updater -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="plugin", - description="插件管理", - usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>", - aliases=[], - privilege=1 -) -class PluginCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - reply = [] - plugin_list = plugin_host.__plugins__ - if len(ctx.params) == 0: - # 列出所有插件 - - reply_str = "[bot]所有插件({}):\n".format(len(plugin_host.__plugins__)) - idx = 0 - for key in plugin_host.iter_plugins_name(): - plugin = plugin_list[key] - reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\ - .format((idx+1), plugin['name'], - "[已禁用]" if not plugin['enabled'] else "", - plugin['description'], - plugin['version'], plugin['author']) - - if updater.is_repo("/".join(plugin['path'].split('/')[:-1])): - remote_url = updater.get_remote_url("/".join(plugin['path'].split('/')[:-1])) - if remote_url != "https://github.com/RockChinQ/QChatGPT" and remote_url != "https://gitee.com/RockChin/QChatGPT": - reply_str += "源码: "+remote_url+"\n" - - idx += 1 - - reply = [reply_str] - return True, reply - else: - return False, [] - - -@aamgr.AbstractCommandNode.register( - parent=PluginCommand, - name="get", - description="安装插件", - usage="!plugin get <插件仓库地址>", - aliases=[], - privilege=2 -) -class PluginGetCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import threading - import logging - import pkg.utils.context - - if len(ctx.crt_params) == 0: - reply = ["[bot]err: 请提供插件仓库地址"] - return True, reply - - reply = [] - def closure(): - try: - plugin_host.install_plugin(ctx.crt_params[0]) - pkg.utils.context.get_qqbot_manager().notify_admin("插件安装成功,请发送 !reload 命令重载插件") - except Exception as e: - logging.error("插件安装失败:{}".format(e)) - pkg.utils.context.get_qqbot_manager().notify_admin("插件安装失败:{}".format(e)) - - threading.Thread(target=closure, args=()).start() - reply = ["[bot]正在安装插件..."] - return True, reply - - -@aamgr.AbstractCommandNode.register( - parent=PluginCommand, - name="update", - description="更新指定插件或全部插件", - usage="!plugin update", - aliases=[], - privilege=2 -) -class PluginUpdateCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import threading - import logging - plugin_list = plugin_host.__plugins__ - - reply = [] - - if len(ctx.crt_params) > 0: - def closure(): - try: - import pkg.utils.context - - updated = [] - - if ctx.crt_params[0] == 'all': - for key in plugin_list: - plugin_host.update_plugin(key) - updated.append(key) - else: - plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(ctx.crt_params[0]) - - if plugin_path_name is not None: - plugin_host.update_plugin(ctx.crt_params[0]) - updated.append(ctx.crt_params[0]) - else: - raise Exception("未找到插件: {}".format(ctx.crt_params[0])) - - pkg.utils.context.get_qqbot_manager().notify_admin("已更新插件: {}, 请发送 !reload 重载插件".format(", ".join(updated))) - except Exception as e: - logging.error("插件更新失败:{}".format(e)) - pkg.utils.context.get_qqbot_manager().notify_admin("插件更新失败:{} 请使用 !plugin 命令确认插件名称或尝试手动更新插件".format(e)) - - reply = ["[bot]正在更新插件,请勿重复发起..."] - threading.Thread(target=closure).start() - else: - reply = ["[bot]请指定要更新的插件, 或使用 !plugin update all 更新所有插件"] - return True, reply - - -@aamgr.AbstractCommandNode.register( - parent=PluginCommand, - name="del", - description="删除插件", - usage="!plugin del <插件名>", - aliases=[], - 
privilege=2 -) -class PluginDelCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - plugin_list = plugin_host.__plugins__ - reply = [] - - if len(ctx.crt_params) < 1: - reply = ["[bot]err: 未指定插件名"] - else: - plugin_name = ctx.crt_params[0] - if plugin_name in plugin_list: - unin_path = plugin_host.uninstall_plugin(plugin_name) - reply = ["[bot]已删除插件: {} ({}), 请发送 !reload 重载插件".format(plugin_name, unin_path)] - else: - reply = ["[bot]err:未找到插件: {}, 请使用!plugin命令查看插件列表".format(plugin_name)] - - return True, reply - - -@aamgr.AbstractCommandNode.register( - parent=PluginCommand, - name="on", - description="启用指定插件", - usage="!plugin on <插件名>", - aliases=[], - privilege=2 -) -@aamgr.AbstractCommandNode.register( - parent=PluginCommand, - name="off", - description="禁用指定插件", - usage="!plugin off <插件名>", - aliases=[], - privilege=2 -) -class PluginOnOffCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.plugin.switch as plugin_switch - - plugin_list = plugin_host.__plugins__ - reply = [] - - print(ctx.params) - new_status = ctx.params[0] == 'on' - - if len(ctx.crt_params) < 1: - reply = ["[bot]err: 未指定插件名"] - else: - plugin_name = ctx.crt_params[0] - if plugin_name in plugin_list: - plugin_list[plugin_name]['enabled'] = new_status - - for func in plugin_host.__callable_functions__: - if func['name'].startswith(plugin_name+"-"): - func['enabled'] = new_status - - plugin_switch.dump_switch() - reply = ["[bot]已{}插件: {}".format("启用" if new_status else "禁用", plugin_name)] - else: - reply = ["[bot]err:未找到插件: {}, 请使用!plugin命令查看插件列表".format(plugin_name)] - - return True, reply - diff --git a/pkg/qqbot/cmds/session/__init__.py b/pkg/qqbot/cmds/session/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/qqbot/cmds/session/default.py b/pkg/qqbot/cmds/session/default.py deleted file mode 100644 index 1a1ff756..00000000 --- a/pkg/qqbot/cmds/session/default.py +++ /dev/null @@ -1,71 +0,0 @@ -from .. 
import aamgr -from ....utils import context - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="default", - description="操作情景预设", - usage="!default\n!default set [指定情景预设为默认]", - aliases=[], - privilege=1 -) -class DefaultCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - params = ctx.params - reply = [] - - config = context.get_config_manager().data - - if len(params) == 0: - # 输出目前所有情景预设 - import pkg.openai.dprompt as dprompt - reply_str = "[bot]当前所有情景预设({}模式):\n\n".format(config['preset_mode']) - - prompts = dprompt.mode_inst().list() - - for key in prompts: - pro = prompts[key] - reply_str += "名称: {}".format(key) - - for r in pro: - reply_str += "\n - [{}]: {}".format(r['role'], r['content']) - - reply_str += "\n\n" - - reply_str += "\n当前默认情景预设:{}\n".format(dprompt.mode_inst().get_using_name()) - reply_str += "请使用 !default set <情景预设名称> 来设置默认情景预设" - reply = [reply_str] - else: - return False, [] - - return True, reply - - -@aamgr.AbstractCommandNode.register( - parent=DefaultCommand, - name="set", - description="设置默认情景预设", - usage="!default set <情景预设名称>", - aliases=[], - privilege=2 -) -class DefaultSetCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - reply = [] - - if len(ctx.crt_params) == 0: - reply = ["[bot]err: 请指定情景预设名称"] - elif len(ctx.crt_params) > 0: - import pkg.openai.dprompt as dprompt - try: - full_name = dprompt.mode_inst().set_using_name(ctx.crt_params[0]) - reply = ["[bot]已设置默认情景预设为:{}".format(full_name)] - except Exception as e: - reply = ["[bot]err: {}".format(e)] - - return True, reply diff --git a/pkg/qqbot/cmds/session/del.py b/pkg/qqbot/cmds/session/del.py deleted file mode 100644 index 45fdc4ee..00000000 --- a/pkg/qqbot/cmds/session/del.py +++ /dev/null @@ -1,51 +0,0 @@ -from .. import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="del", - description="删除当前会话的历史记录", - usage="!del <序号>\n!del all", - aliases=[], - privilege=1 -) -class DelCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - params = ctx.params - reply = [] - if len(params) == 0: - reply = ["[bot]参数不足, 格式: !del <序号>\n可以通过!list查看序号"] - else: - if params[0] == 'all': - return False, [] - elif params[0].isdigit(): - if pkg.openai.session.get_session(session_name).delete_history(int(params[0])): - reply = ["[bot]已删除历史会话 #{}".format(params[0])] - else: - reply = ["[bot]没有历史会话 #{}".format(params[0])] - else: - reply = ["[bot]参数错误, 格式: !del <序号>\n可以通过!list查看序号"] - - return True, reply - - -@aamgr.AbstractCommandNode.register( - parent=DelCommand, - name="all", - description="删除当前会话的全部历史记录", - usage="!del all", - aliases=[], - privilege=1 -) -class DelAllCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - reply = [] - pkg.openai.session.get_session(session_name).delete_all_history() - reply = ["[bot]已删除所有历史会话"] - return True, reply diff --git a/pkg/qqbot/cmds/session/delhst.py b/pkg/qqbot/cmds/session/delhst.py deleted file mode 100644 index 31791492..00000000 --- a/pkg/qqbot/cmds/session/delhst.py +++ /dev/null @@ -1,50 +0,0 @@ -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="delhst", - description="删除指定会话的所有历史记录", - usage="!delhst <会话名称>\n!delhst all", - aliases=[], - privilege=2 -) -class DelHistoryCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - import pkg.utils.context - params = ctx.params - reply = [] - if len(params) == 0: - reply = [ - "[bot]err:请输入要删除的会话名: group_<群号> 或者 person_, 或使用 !delhst all 删除所有会话的历史记录"] - else: - if params[0] == 'all': - return False, [] - else: - if pkg.utils.context.get_database_manager().delete_all_history(params[0]): - reply = ["[bot]已删除会话 {} 的所有历史记录".format(params[0])] - else: - reply = ["[bot]未找到会话 {} 的历史记录".format(params[0])] - - return True, reply - - -@aamgr.AbstractCommandNode.register( - parent=DelHistoryCommand, - name="all", - description="删除所有会话的全部历史记录", - usage="!delhst all", - aliases=[], - privilege=2 -) -class DelAllHistoryCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.utils.context - reply = [] - pkg.utils.context.get_database_manager().delete_all_session_history() - reply = ["[bot]已删除所有会话的历史记录"] - return True, reply - \ No newline at end of file diff --git a/pkg/qqbot/cmds/session/last.py b/pkg/qqbot/cmds/session/last.py deleted file mode 100644 index 93459c44..00000000 --- a/pkg/qqbot/cmds/session/last.py +++ /dev/null @@ -1,29 +0,0 @@ -import datetime - -from .. import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="last", - description="切换前一次对话", - usage="!last", - aliases=[], - privilege=1 -) -class LastCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - - reply = [] - result = pkg.openai.session.get_session(session_name).last_session() - if result is None: - reply = ["[bot]没有前一次的对话"] - else: - datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( - '%Y-%m-%d %H:%M:%S') - reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format(datetime_str)] - - return True, reply diff --git a/pkg/qqbot/cmds/session/list.py b/pkg/qqbot/cmds/session/list.py deleted file mode 100644 index fb00976d..00000000 --- a/pkg/qqbot/cmds/session/list.py +++ /dev/null @@ -1,65 +0,0 @@ -import datetime -import json - -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name='list', - description='列出当前会话的所有历史记录', - usage='!list\n!list [页数]', - aliases=[], - privilege=1 -) -class ListCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - params = ctx.params - reply = [] - - pkg.openai.session.get_session(session_name).persistence() - page = 0 - - if len(params) > 0: - try: - page = int(params[0]) - except ValueError: - pass - - results = pkg.openai.session.get_session(session_name).list_history(page=page) - if len(results) == 0: - reply_str = "[bot]第{}页没有历史会话".format(page) - else: - reply_str = "[bot]历史会话 第{}页:\n".format(page) - current = -1 - for i in range(len(results)): - # 时间(使用create_timestamp转换) 序号 部分内容 - datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp']) - msg = "" - - msg = json.loads(results[i]['prompt']) - - if len(msg) >= 2: - reply_str += "#{} 创建:{} {}\n".format(i + page * 10, - datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), - msg[0]['content']) - else: - reply_str += "#{} 创建:{} {}\n".format(i + page * 10, - datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), - "无内容") - if results[i]['create_timestamp'] == pkg.openai.session.get_session( - session_name).create_timestamp: - current = i + page * 10 - - reply_str += "\n以上信息倒序排列" - if current != -1: - reply_str += ",当前会话是 #{}\n".format(current) - else: - reply_str += ",当前处于全新会话或不在此页" - - reply = [reply_str] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/session/next.py b/pkg/qqbot/cmds/session/next.py deleted file mode 100644 index 7704acf6..00000000 --- a/pkg/qqbot/cmds/session/next.py +++ /dev/null @@ -1,29 +0,0 @@ -import datetime - -from .. import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="next", - description="切换后一次对话", - usage="!next", - aliases=[], - privilege=1 -) -class NextCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - reply = [] - - result = pkg.openai.session.get_session(session_name).next_session() - if result is None: - reply = ["[bot]没有后一次的对话"] - else: - datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( - '%Y-%m-%d %H:%M:%S') - reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format(datetime_str)] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/session/prompt.py b/pkg/qqbot/cmds/session/prompt.py deleted file mode 100644 index adb2e583..00000000 --- a/pkg/qqbot/cmds/session/prompt.py +++ /dev/null @@ -1,31 +0,0 @@ -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="prompt", - description="获取当前会话的前文", - usage="!prompt", - aliases=[], - privilege=1 -) -class PromptCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - params = ctx.params - reply = [] - - msgs = "" - session: list = pkg.openai.session.get_session(session_name).prompt - for msg in session: - if len(params) != 0 and params[0] in ['-all', '-a']: - msgs = msgs + "{}: {}\n\n".format(msg['role'], msg['content']) - elif len(msg['content']) > 30: - msgs = msgs + "[{}]: {}...\n\n".format(msg['role'], msg['content'][:30]) - else: - msgs = msgs + "[{}]: {}\n\n".format(msg['role'], msg['content']) - reply = ["[bot]当前对话所有内容:\n{}".format(msgs)] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/session/resend.py b/pkg/qqbot/cmds/session/resend.py deleted file mode 100644 index 941afb55..00000000 --- a/pkg/qqbot/cmds/session/resend.py +++ /dev/null @@ -1,33 +0,0 @@ -from .. import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="resend", - description="重新获取上一次问题的回复", - usage="!resend", - aliases=[], - privilege=1 -) -class ResendCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - from ....openai import session as openai_session - from ....utils import context - from ....qqbot import message - - session_name = ctx.session_name - reply = [] - - session = openai_session.get_session(session_name) - to_send = session.undo() - - mgr = context.get_qqbot_manager() - - config = context.get_config_manager().data - - reply = message.process_normal_message(to_send, mgr, config, - ctx.launcher_type, ctx.launcher_id, - ctx.sender_id) - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/session/reset.py b/pkg/qqbot/cmds/session/reset.py deleted file mode 100644 index a93f3415..00000000 --- a/pkg/qqbot/cmds/session/reset.py +++ /dev/null @@ -1,35 +0,0 @@ -import tips as tips_custom - -from .. import aamgr -from ....openai import session -from ....utils import context - - -@aamgr.AbstractCommandNode.register( - parent=None, - name='reset', - description='重置当前会话', - usage='!reset', - aliases=[], - privilege=1 -) -class ResetCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - params = ctx.params - session_name = ctx.session_name - - reply = "" - - if len(params) == 0: - session.get_session(session_name).reset(explicit=True) - reply = [tips_custom.command_reset_message] - else: - try: - import pkg.openai.dprompt as dprompt - session.get_session(session_name).reset(explicit=True, use_prompt=params[0]) - reply = [tips_custom.command_reset_name_message+"{}".format(dprompt.mode_inst().get_full_name(params[0]))] - except Exception as e: - reply = ["[bot]会话重置失败:{}".format(e)] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/__init__.py b/pkg/qqbot/cmds/system/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/qqbot/cmds/system/cconfig.py b/pkg/qqbot/cmds/system/cconfig.py deleted file mode 100644 index 321d68c2..00000000 --- a/pkg/qqbot/cmds/system/cconfig.py +++ /dev/null @@ -1,93 +0,0 @@ -import json - -from .. 
import aamgr - - -def config_operation(cmd, params): - reply = [] - import pkg.utils.context - # config = pkg.utils.context.get_config() - cfg_mgr = pkg.utils.context.get_config_manager() - - false = False - true = True - - reply_str = "" - if len(params) == 0: - reply = ["[bot]err:请输入!cmd cfg查看使用方法"] - else: - cfg_name = params[0] - if cfg_name == 'all': - reply_str = "[bot]所有配置项:\n\n" - for cfg in cfg_mgr.data.keys(): - if not cfg.startswith('__') and not cfg == 'logging': - # 根据配置项类型进行格式化,如果是字典则转换为json并格式化 - if isinstance(cfg_mgr.data[cfg], str): - reply_str += "{}: \"{}\"\n".format(cfg, cfg_mgr.data[cfg]) - elif isinstance(cfg_mgr.data[cfg], dict): - # 不进行unicode转义,并格式化 - reply_str += "{}: {}\n".format(cfg, - json.dumps(cfg_mgr.data[cfg], - ensure_ascii=False, indent=4)) - else: - reply_str += "{}: {}\n".format(cfg, cfg_mgr.data[cfg]) - reply = [reply_str] - else: - cfg_entry_path = cfg_name.split('.') - - try: - if len(params) == 1: # 未指定配置值,返回配置项值 - cfg_entry = cfg_mgr.data[cfg_entry_path[0]] - if len(cfg_entry_path) > 1: - for i in range(1, len(cfg_entry_path)): - cfg_entry = cfg_entry[cfg_entry_path[i]] - - if isinstance(cfg_entry, str): - reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, cfg_entry) - elif isinstance(cfg_entry, dict): - reply_str = "[bot]配置项{}: {}\n".format(cfg_name, - json.dumps(cfg_entry, - ensure_ascii=False, indent=4)) - else: - reply_str = "[bot]配置项{}: {}\n".format(cfg_name, cfg_entry) - reply = [reply_str] - else: - cfg_value = " ".join(params[1:]) - - cfg_value = eval(cfg_value) - - cfg_entry = cfg_mgr.data[cfg_entry_path[0]] - if len(cfg_entry_path) > 1: - for i in range(1, len(cfg_entry_path) - 1): - cfg_entry = cfg_entry[cfg_entry_path[i]] - if isinstance(cfg_entry[cfg_entry_path[-1]], type(cfg_value)): - cfg_entry[cfg_entry_path[-1]] = cfg_value - reply = ["[bot]配置项{}修改成功".format(cfg_name)] - else: - reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)] - else: - cfg_mgr.data[cfg_entry_path[0]] = cfg_value - reply = ["[bot]配置项{}修改成功".format(cfg_name)] - except KeyError: - reply = ["[bot]err:未找到配置项 {}".format(cfg_name)] - except NameError: - reply = ["[bot]err:值{}不合法(字符串需要使用双引号包裹)".format(cfg_value)] - except ValueError: - reply = ["[bot]err:未找到配置项 {}".format(cfg_name)] - - return reply - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="cfg", - description="配置项管理", - usage="!cfg <配置项> [配置值]\n!cfg all", - aliases=[], - privilege=2 -) -class CfgCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - return True, config_operation(ctx.command, ctx.params) - \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/cmd.py b/pkg/qqbot/cmds/system/cmd.py deleted file mode 100644 index f0a33648..00000000 --- a/pkg/qqbot/cmds/system/cmd.py +++ /dev/null @@ -1,39 +0,0 @@ -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="cmd", - description="显示命令列表", - usage="!cmd\n!cmd <命令名称>", - aliases=[], - privilege=1 -) -class CmdCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - command_list = aamgr.__command_list__ - - reply = [] - - if len(ctx.params) == 0: - reply_str = "[bot]当前所有命令:\n\n" - - # 遍历顶级命令 - for key in command_list: - command = command_list[key] - if command['parent'] is None: - reply_str += "!{} - {}\n".format(key, command['description']) - - reply_str += "\n请使用 !cmd <命令名称> 来查看命令的详细信息" - - reply = [reply_str] - else: - command_name = ctx.params[0] - if command_name in command_list: - reply = [command_list[command_name]['cls'].help()] - else: - reply = ["[bot]命令 {} 不存在".format(command_name)] - - return True, reply - \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/help.py b/pkg/qqbot/cmds/system/help.py deleted file mode 100644 index 14027b8b..00000000 --- a/pkg/qqbot/cmds/system/help.py +++ /dev/null @@ -1,24 +0,0 @@ -from .. import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="help", - description="显示自定义的帮助信息", - usage="!help", - aliases=[], - privilege=1 -) -class HelpCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import tips - reply = ["[bot] "+tips.help_message + "\n请输入 !cmd 查看命令列表"] - - # 警告config.help_message过时 - import config - if hasattr(config, "help_message"): - reply[0] += "\n\n警告:config.py中的help_message已过时,不再生效,请使用tips.py中的help_message替代" - - return True, reply - \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/reload.py b/pkg/qqbot/cmds/system/reload.py deleted file mode 100644 index 378dcef9..00000000 --- a/pkg/qqbot/cmds/system/reload.py +++ /dev/null @@ -1,25 +0,0 @@ -import threading - -from .. import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="reload", - description="执行热重载", - usage="!reload", - aliases=[], - privilege=2 -) -class ReloadCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - reply = [] - - import pkg.utils.reloader - def reload_task(): - pkg.utils.reloader.reload_all() - - threading.Thread(target=reload_task, daemon=True).start() - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/update.py b/pkg/qqbot/cmds/system/update.py deleted file mode 100644 index d4cca3f3..00000000 --- a/pkg/qqbot/cmds/system/update.py +++ /dev/null @@ -1,38 +0,0 @@ -import threading -import traceback - -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="update", - description="更新程序", - usage="!update", - aliases=[], - privilege=2 -) -class UpdateCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - reply = [] - import pkg.utils.updater - import pkg.utils.reloader - import pkg.utils.context - - def update_task(): - try: - if pkg.utils.updater.update_all(): - pkg.utils.context.get_qqbot_manager().notify_admin("更新完成, 请手动重启程序。") - else: - pkg.utils.context.get_qqbot_manager().notify_admin("无新版本") - except Exception as e0: - traceback.print_exc() - pkg.utils.context.get_qqbot_manager().notify_admin("更新失败:{}".format(e0)) - return - - threading.Thread(target=update_task, daemon=True).start() - - reply = ["[bot]正在更新,请耐心等待,请勿重复发起更新..."] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/usage.py b/pkg/qqbot/cmds/system/usage.py deleted file mode 100644 index 15f79b49..00000000 --- a/pkg/qqbot/cmds/system/usage.py +++ /dev/null @@ -1,33 +0,0 @@ -from .. import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="usage", - description="获取使用情况", - usage="!usage", - aliases=[], - privilege=1 -) -class UsageCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import config - import pkg.utils.context - - reply = [] - - reply_str = "[bot]各api-key使用情况:\n\n" - - api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key - for key_name in api_keys: - text_length = pkg.utils.context.get_openai_manager().audit_mgr \ - .get_text_length_of_key(api_keys[key_name]) - image_count = pkg.utils.context.get_openai_manager().audit_mgr \ - .get_image_count_of_key(api_keys[key_name]) - reply_str += "{}:\n - 文本长度:{}\n - 图片数量:{}\n".format(key_name, int(text_length), - int(image_count)) - - reply = [reply_str] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/version.py b/pkg/qqbot/cmds/system/version.py deleted file mode 100644 index 67bf3ef2..00000000 --- a/pkg/qqbot/cmds/system/version.py +++ /dev/null @@ -1,27 +0,0 @@ -from .. import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="version", - description="查看版本信息", - usage="!version", - aliases=[], - privilege=1 -) -class VersionCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - reply = [] - import pkg.utils.updater - - reply_str = "[bot]当前版本:\n{}\n".format(pkg.utils.updater.get_current_version_info()) - try: - if pkg.utils.updater.is_new_version_available(): - reply_str += "\n有新版本可用,请使用命令 !update 进行更新" - except: - pass - - reply = [reply_str] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/command.py b/pkg/qqbot/command.py deleted file mode 100644 index dba2d204..00000000 --- a/pkg/qqbot/command.py +++ /dev/null @@ -1,49 +0,0 @@ -# 命令处理模块 -import logging - -from ..qqbot.cmds import aamgr as cmdmgr - - -def process_command(session_name: str, text_message: str, mgr, config: dict, - launcher_type: str, launcher_id: int, sender_id: int, is_admin: bool) -> list: - reply = [] - try: - logging.info( - "[{}]发起命令:{}".format(session_name, text_message[:min(20, len(text_message))] + ( - "..." 
if len(text_message) > 20 else ""))) - - cmd = text_message[1:].strip().split(' ')[0] - - params = text_message[1:].strip().split(' ')[1:] - - # 把!~开头的转换成!cfg - if cmd.startswith('~'): - params = [cmd[1:]] + params - cmd = 'cfg' - - # 包装参数 - context = cmdmgr.Context( - command=cmd, - crt_command=cmd, - params=params, - crt_params=params[:], - session_name=session_name, - text_message=text_message, - launcher_type=launcher_type, - launcher_id=launcher_id, - sender_id=sender_id, - is_admin=is_admin, - privilege=2 if is_admin else 1, # 普通用户1,管理员2 - ) - try: - reply = cmdmgr.execute(context) - except cmdmgr.CommandPrivilegeError as e: - reply = ["{}".format(e)] - - return reply - except Exception as e: - mgr.notify_admin("{}命令执行失败:{}".format(session_name, e)) - logging.exception(e) - reply = ["[bot]err:{}".format(e)] - - return reply diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 7794663a..12868b94 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -12,10 +12,7 @@ from ..openai import session as openai_session -from ..qqbot import process as processor from ..utils import context -from ..plugin import host as plugin_host -from ..plugin import models as plugin_models import tips as tips_custom from ..qqbot import adapter as msadapter from .ratelim import ratelim @@ -25,28 +22,20 @@ # 控制QQ消息输入输出的类 class QQBotManager: - retry = 3 - + adapter: msadapter.MessageSourceAdapter = None bot_account_id: int = 0 - ban_person = [] - ban_group = [] - # modern ap: app.Application = None ratelimiter: ratelim.RateLimiter = None - def __init__(self, first_time_init=True, ap: app.Application = None): - config = context.get_config_manager().data + def __init__(self, ap: app.Application = None): self.ap = ap self.ratelimiter = ratelim.RateLimiter(ap) - - self.timeout = config['process_message_timeout'] - self.retry = config['retry_times'] async def initialize(self): await self.ratelimiter.initialize() @@ -69,10 +58,6 @@ async def initialize(self): from ..utils.center import apigroup apigroup.APIGroup._runtime_info['account_id'] = "{}".format(self.bot_account_id) - context.set_qqbot_manager(self) - - # 注册诸事件 - # Caution: 注册新的事件处理器之后,请务必在unsubscribe_all中编写相应的取消订阅代码 async def on_friend_message(event: FriendMessage): await self.ap.query_pool.add_query( @@ -144,90 +129,6 @@ async def send(self, event, msg, check_quote=True, check_at_sender=True): quote_origin=True if config['quote_origin'] and check_quote else False ) - async def common_process( - self, - launcher_type: str, - launcher_id: int, - text_message: str, - message_chain: MessageChain, - sender_id: int - ) -> mirai.MessageChain: - """ - 私聊群聊通用消息处理方法 - """ - # 检查bansess - if await self.bansess_mgr.is_banned(launcher_type, launcher_id, sender_id): - self.ap.logger.info("根据禁用列表忽略{}_{}的消息".format(launcher_type, launcher_id)) - return [] - - if mirai.Image in message_chain: - return [] - elif sender_id == self.bot_account_id: - return [] - else: - # 超时则重试,重试超过次数则放弃 - failed = 0 - for i in range(self.retry): - try: - reply = await processor.process_message(launcher_type, launcher_id, text_message, message_chain, - sender_id) - return reply - - # TODO openai 超时处理 - except func_timeout.FunctionTimedOut: - logging.warning("{}_{}: 超时,重试中({})".format(launcher_type, launcher_id, i)) - openai_session.get_session("{}_{}".format(launcher_type, launcher_id)).release_response_lock() - if "{}_{}".format(launcher_type, launcher_id) in processor.processing: - processor.processing.remove("{}_{}".format(launcher_type, launcher_id)) - failed += 1 - 
continue - - if failed == self.retry: - openai_session.get_session("{}_{}".format(launcher_type, launcher_id)).release_response_lock() - await self.notify_admin("{} 请求超时".format("{}_{}".format(launcher_type, launcher_id))) - reply = [tips_custom.reply_message] - - # 私聊消息处理 - async def on_person_message(self, event: MessageEvent): - reply = '' - - reply = await self.common_process( - launcher_type="person", - launcher_id=event.sender.id, - text_message=str(event.message_chain), - message_chain=event.message_chain, - sender_id=event.sender.id - ) - - if reply: - await self.send(event, reply, check_quote=False, check_at_sender=False) - - # 群消息处理 - async def on_group_message(self, event: GroupMessage): - reply = '' - - text = str(event.message_chain).strip() - - rule_check_res = await self.resprule_chkr.check( - text, - event.message_chain, - event.group.id, - event.sender.id - ) - - if rule_check_res.matching: - text = str(rule_check_res.replacement).strip() - reply = await self.common_process( - launcher_type="group", - launcher_id=event.group.id, - text_message=text, - message_chain=rule_check_res.replacement, - sender_id=event.sender.id - ) - - if reply: - await self.send(event, reply) - # 通知系统管理员 async def notify_admin(self, message: str): await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))])) diff --git a/pkg/qqbot/message.py b/pkg/qqbot/message.py deleted file mode 100644 index beff6645..00000000 --- a/pkg/qqbot/message.py +++ /dev/null @@ -1,134 +0,0 @@ -# 普通消息处理模块 -import logging - -import openai - -from ..utils import context -from ..openai import session as openai_session - -from ..plugin import host as plugin_host -from ..plugin import models as plugin_models -import tips as tips_custom - - -def handle_exception(notify_admin: str = "", set_reply: str = "") -> list: - """处理异常,当notify_admin不为空时,会通知管理员,返回通知用户的消息""" - config = context.get_config_manager().data - context.get_qqbot_manager().notify_admin(notify_admin) - if config['hide_exce_info_to_user']: - return [tips_custom.alter_tip_message] if tips_custom.alter_tip_message else [] - else: - return [set_reply] - - -def process_normal_message(text_message: str, mgr, config: dict, launcher_type: str, - launcher_id: int, sender_id: int) -> list: - session_name = f"{launcher_type}_{launcher_id}" - logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + ( - "..." 
if len(text_message) > 20 else ""))) - - session = openai_session.get_session(session_name) - - unexpected_exception_times = 0 - - max_unexpected_exception_times = 3 - - reply = [] - while True: - if unexpected_exception_times >= max_unexpected_exception_times: - reply = handle_exception(notify_admin=f"{session_name},多次尝试失败。", set_reply=f"[bot]多次尝试失败,请重试或联系管理员") - break - try: - prefix = "[GPT]" if config['show_prefix'] else "" - - text, finish_reason, funcs = session.query(text_message) - - # 触发插件事件 - args = { - "launcher_type": launcher_type, - "launcher_id": launcher_id, - "sender_id": sender_id, - "session": session, - "prefix": prefix, - "response_text": text, - "finish_reason": finish_reason, - "funcs_called": funcs, - } - - event = plugin_host.emit(plugin_models.NormalMessageResponded, **args) - - if event.get_return_value("prefix") is not None: - prefix = event.get_return_value("prefix") - - if event.get_return_value("reply") is not None: - reply = event.get_return_value("reply") - - if not event.is_prevented_default(): - reply = [prefix + text] - - except openai.APIConnectionError as e: - err_msg = str(e) - if err_msg.__contains__('Error communicating with OpenAI'): - reply = handle_exception("{}会话调用API失败:{}\n您的网络无法访问OpenAI接口或网络代理不正常".format(session_name, e), - "[bot]err:调用API失败,请重试或联系管理员,或等待修复") - else: - reply = handle_exception("{}会话调用API失败:{}".format(session_name, e), "[bot]err:调用API失败,请重试或联系管理员,或等待修复") - except openai.RateLimitError as e: - logging.debug(type(e)) - logging.debug(e.error['message']) - - if 'message' in e.error and e.error['message'].__contains__('You exceeded your current quota'): - # 尝试切换api-key - current_key_name = context.get_openai_manager().key_mgr.get_key_name( - context.get_openai_manager().key_mgr.using_key - ) - context.get_openai_manager().key_mgr.set_current_exceeded() - - # 触发插件事件 - args = { - 'key_name': current_key_name, - 'usage': context.get_openai_manager().audit_mgr - .get_usage(context.get_openai_manager().key_mgr.get_using_key_md5()), - 'exceeded_keys': context.get_openai_manager().key_mgr.exceeded, - } - event = plugin_host.emit(plugin_models.KeyExceeded, **args) - - if not event.is_prevented_default(): - switched, name = context.get_openai_manager().key_mgr.auto_switch() - - if not switched: - reply = handle_exception( - "api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key;如果你认为这是误判,请尝试重启程序。".format( - current_key_name), "[bot]err:API调用额度超额,请联系管理员,或等待修复") - else: - openai.api_key = context.get_openai_manager().key_mgr.get_using_key() - mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name)) - reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"] - continue - elif 'message' in e.error and e.error['message'].__contains__('You can retry your request'): - # 重试 - unexpected_exception_times += 1 - continue - elif 'message' in e.error and e.error['message']\ - .__contains__('The server had an error while processing your request'): - # 重试 - unexpected_exception_times += 1 - continue - else: - reply = handle_exception("{}会话调用API失败:{}".format(session_name, e), - "[bot]err:RateLimitError,请重试或联系作者,或等待修复") - except openai.BadRequestError as e: - if config['auto_reset'] and "This model's maximum context length is" in str(e): - session.reset(persist=True) - reply = [tips_custom.session_auto_reset_message] - else: - reply = handle_exception("{}API调用参数错误:{}\n".format( - session_name, e), "[bot]err:API调用参数错误,请联系管理员,或等待修复") - except openai.APIStatusError as e: - reply = handle_exception("{}API调用服务不可用:{}".format(session_name, e), 
"[bot]err:API调用服务不可用,请重试或联系管理员,或等待修复") - except Exception as e: - logging.exception(e) - reply = handle_exception("{}会话处理异常:{}".format(session_name, e), "[bot]err:{}".format(e)) - break - - return reply diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py deleted file mode 100644 index a8359be5..00000000 --- a/pkg/qqbot/process.py +++ /dev/null @@ -1,180 +0,0 @@ -# 此模块提供了消息处理的具体逻辑的接口 -from __future__ import annotations -import asyncio -import time -import traceback - -import mirai -import logging - -from ..qqbot import command, message -from ..openai import session as openai_session -from ..utils import context - -from ..plugin import host as plugin_host -from ..plugin import models as plugin_models -import tips as tips_custom -from ..core import app -# from .cntfilter import entities - -processing = [] - - -def is_admin(qq: int) -> bool: - """兼容list和int类型的管理员判断""" - config = context.get_config_manager().data - if type(config['admin_qq']) == list: - return qq in config['admin_qq'] - else: - return qq == config['admin_qq'] - - -async def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: mirai.MessageChain, - sender_id: int) -> list: - global processing - - mgr = context.get_qqbot_manager() - - reply = [] - session_name = "{}_{}".format(launcher_type, launcher_id) - - config = context.get_config_manager().data - - if not config['wait_last_done'] and session_name in processing: - return [mirai.Plain(tips_custom.message_drop_tip)] - - # 检查是否被禁言 - if launcher_type == 'group': - is_muted = await mgr.adapter.is_muted(launcher_id) - if is_muted: - logging.info("机器人被禁言,跳过消息处理(group_{})".format(launcher_id)) - return reply - - cntfilter_res = await mgr.cntfilter_mgr.pre_process(text_message) - if cntfilter_res.level == entities.ManagerResultLevel.INTERRUPT: - if cntfilter_res.console_notice: - mgr.ap.logger.info(cntfilter_res.console_notice) - if cntfilter_res.user_notice: - return [mirai.Plain(cntfilter_res.user_notice)] - else: - return [] - - openai_session.get_session(session_name).acquire_response_lock() - - text_message = text_message.strip() - - # 为强制消息延迟计时 - start_time = time.time() - - # 处理消息 - try: - - processing.append(session_name) - try: - msg_type = '' - if text_message.startswith('!') or text_message.startswith("!"): # 命令 - msg_type = 'command' - # 触发插件事件 - args = { - 'launcher_type': launcher_type, - 'launcher_id': launcher_id, - 'sender_id': sender_id, - 'command': text_message[1:].strip().split(' ')[0], - 'params': text_message[1:].strip().split(' ')[1:], - 'text_message': text_message, - 'is_admin': is_admin(sender_id), - } - event = plugin_host.emit(plugin_models.PersonCommandSent - if launcher_type == 'person' - else plugin_models.GroupCommandSent, **args) - - if event.get_return_value("alter") is not None: - text_message = event.get_return_value("alter") - - # 取出插件提交的返回值赋值给reply - if event.get_return_value("reply") is not None: - reply = event.get_return_value("reply") - - if not event.is_prevented_default(): - reply = command.process_command(session_name, text_message, - mgr, config, launcher_type, launcher_id, sender_id, is_admin(sender_id)) - - else: # 消息 - msg_type = 'message' - # 限速丢弃检查 - if not await mgr.ratelimiter.require(launcher_type, launcher_id): - logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message)) - - return mirai.MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else [] - - before = time.time() - # 触发插件事件 - args = { - "launcher_type": launcher_type, - 
"launcher_id": launcher_id, - "sender_id": sender_id, - "text_message": text_message, - } - event = plugin_host.emit(plugin_models.PersonNormalMessageReceived - if launcher_type == 'person' - else plugin_models.GroupNormalMessageReceived, **args) - - if event.get_return_value("alter") is not None: - text_message = event.get_return_value("alter") - - # 取出插件提交的返回值赋值给reply - if event.get_return_value("reply") is not None: - reply = event.get_return_value("reply") - - if not event.is_prevented_default(): - reply = message.process_normal_message(text_message, - mgr, config, launcher_type, launcher_id, sender_id) - - if reply is not None and len(reply) > 0 and (type(reply[0]) == str or type(reply[0]) == mirai.Plain): - if type(reply[0]) == mirai.Plain: - reply[0] = reply[0].text - logging.info( - "回复[{}]文字消息:{}".format(session_name, - reply[0][:min(100, len(reply[0]))] + ( - "..." if len(reply[0]) > 100 else ""))) - if msg_type == 'message': - cntfilter_res = await mgr.cntfilter_mgr.post_process(reply[0]) - if cntfilter_res.level == entities.ManagerResultLevel.INTERRUPT: - if cntfilter_res.console_notice: - mgr.ap.logger.info(cntfilter_res.console_notice) - if cntfilter_res.user_notice: - return [mirai.Plain(cntfilter_res.user_notice)] - else: - return [] - else: - reply = [cntfilter_res.replacement] - - reply = await mgr.longtext_pcs.check_and_process(reply[0]) - else: - logging.info("回复[{}]消息".format(session_name)) - - finally: - processing.remove(session_name) - finally: - openai_session.get_session(session_name).release_response_lock() - - # 检查延迟时间 - if config['force_delay_range'][1] == 0: - delay_time = 0 - else: - import random - - # 从延迟范围中随机取一个值(浮点) - rdm = random.uniform(config['force_delay_range'][0], config['force_delay_range'][1]) - - spent = time.time() - start_time - - # 如果花费时间小于延迟时间,则延迟 - delay_time = rdm - spent if rdm - spent > 0 else 0 - - # 延迟 - if delay_time > 0: - logging.info("[风控] 强制延迟{:.2f}秒(如需关闭,请到config.py修改force_delay_range字段)".format(delay_time)) - time.sleep(delay_time) - - return mirai.MessageChain(reply) diff --git a/pkg/utils/__init__.py b/pkg/utils/__init__.py index 5b1c9803..e69de29b 100644 --- a/pkg/utils/__init__.py +++ b/pkg/utils/__init__.py @@ -1 +0,0 @@ -from .threadctl import ThreadCtl \ No newline at end of file diff --git a/pkg/utils/context.py b/pkg/utils/context.py index e6a2734a..9f201b81 100644 --- a/pkg/utils/context.py +++ b/pkg/utils/context.py @@ -1,10 +1,8 @@ from __future__ import annotations import threading -from . import threadctl from ..database import manager as db_mgr -from ..openai import manager as openai_mgr from ..qqbot import manager as qqbot_mgr from ..config import manager as config_mgr from ..plugin import host as plugin_host diff --git a/pkg/utils/reloader.py b/pkg/utils/reloader.py deleted file mode 100644 index eefe33b0..00000000 --- a/pkg/utils/reloader.py +++ /dev/null @@ -1,71 +0,0 @@ -import logging -import importlib -import pkgutil -import asyncio - -from . import context -from ..plugin import host as plugin_host - - -def walk(module, prefix='', path_prefix=''): - """遍历并重载所有模块""" - for item in pkgutil.iter_modules(module.__path__): - if item.ispkg: - - walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.', path_prefix + item.name + '/') - else: - logging.info('reload module: {}, path: {}'.format(prefix + item.name, path_prefix + item.name + '.py')) - plugin_host.__current_module_path__ = "plugins/" + path_prefix + item.name + '.py' - importlib.reload(__import__(module.__name__ + '.' 
+ item.name, fromlist=[''])) - - -def reload_all(notify=True): - # 解除bot的事件注册 - import pkg - context.get_qqbot_manager().unsubscribe_all() - # 执行关闭流程 - logging.info("执行程序关闭流程") - import main - main.stop() - - # 删除所有已注册的命令 - import pkg.qqbot.cmds.aamgr as cmdsmgr - cmdsmgr.__command_list__ = {} - cmdsmgr.__tree_index__ = {} - - # 重载所有模块 - context.context['exceeded_keys'] = context.get_openai_manager().key_mgr.exceeded - this_context = context.context - walk(pkg) - importlib.reload(__import__("config-template")) - importlib.reload(__import__('config')) - importlib.reload(__import__('main')) - importlib.reload(__import__('banlist')) - importlib.reload(__import__('tips')) - context.context = this_context - - # 重载插件 - import plugins - walk(plugins) - - # 初始化相关文件 - main.check_file() - - # 执行启动流程 - logging.info("执行程序启动流程") - - context.get_thread_ctl().reload( - admin_pool_num=4, - user_pool_num=8 - ) - - def run_wrapper(): - asyncio.run(main.start_process(False)) - - context.get_thread_ctl().submit_sys_task( - run_wrapper - ) - - logging.info('程序启动完成') - if notify: - context.get_qqbot_manager().notify_admin("重载完成") diff --git a/pkg/utils/threadctl.py b/pkg/utils/threadctl.py deleted file mode 100644 index ab764cc3..00000000 --- a/pkg/utils/threadctl.py +++ /dev/null @@ -1,93 +0,0 @@ -import threading -import time -from concurrent.futures import ThreadPoolExecutor - - -class Pool: - """线程池结构""" - pool_num:int = None - ctl:ThreadPoolExecutor = None - task_list:list = None - task_list_lock:threading.Lock = None - monitor_type = True - - def __init__(self, pool_num): - self.pool_num = pool_num - self.ctl = ThreadPoolExecutor(max_workers = self.pool_num) - self.task_list = [] - self.task_list_lock = threading.Lock() - - def __thread_monitor__(self): - while self.monitor_type: - for t in self.task_list: - if not t.done(): - continue - try: - self.task_list.pop(self.task_list.index(t)) - except: - continue - time.sleep(1) - - -class ThreadCtl: - def __init__(self, sys_pool_num, admin_pool_num, user_pool_num): - """线程池控制类 - sys_pool_num:分配系统使用的线程池数量(>=8) - admin_pool_num:用于处理管理员消息的线程池数量(>=1) - user_pool_num:分配用于处理用户消息的线程池的数量(>=1) - """ - if sys_pool_num < 5: - raise Exception("Too few system threads(sys_pool_num needs >= 8, but received {})".format(sys_pool_num)) - if admin_pool_num < 1: - raise Exception("Too few admin threads(admin_pool_num needs >= 1, but received {})".format(admin_pool_num)) - if user_pool_num < 1: - raise Exception("Too few user threads(user_pool_num needs >= 1, but received {})".format(admin_pool_num)) - self.__sys_pool__ = Pool(sys_pool_num) - self.__admin_pool__ = Pool(admin_pool_num) - self.__user_pool__ = Pool(user_pool_num) - self.submit_sys_task(self.__sys_pool__.__thread_monitor__) - self.submit_sys_task(self.__admin_pool__.__thread_monitor__) - self.submit_sys_task(self.__user_pool__.__thread_monitor__) - - def __submit__(self, pool: Pool, fn, /, *args, **kwargs ): - t = pool.ctl.submit(fn, *args, **kwargs) - pool.task_list_lock.acquire() - pool.task_list.append(t) - pool.task_list_lock.release() - return t - - def submit_sys_task(self, fn, /, *args, **kwargs): - return self.__submit__( - self.__sys_pool__, - fn, *args, **kwargs - ) - - def submit_admin_task(self, fn, /, *args, **kwargs): - return self.__submit__( - self.__admin_pool__, - fn, *args, **kwargs - ) - - def submit_user_task(self, fn, /, *args, **kwargs): - return self.__submit__( - self.__user_pool__, - fn, *args, **kwargs - ) - - def shutdown(self): - self.__user_pool__.ctl.shutdown(cancel_futures=True) - 
self.__user_pool__.monitor_type = False - self.__admin_pool__.ctl.shutdown(cancel_futures=True) - self.__admin_pool__.monitor_type = False - self.__sys_pool__.monitor_type = False - self.__sys_pool__.ctl.shutdown(wait=True, cancel_futures=False) - - def reload(self, admin_pool_num, user_pool_num): - self.__user_pool__.ctl.shutdown(cancel_futures=True) - self.__user_pool__.monitor_type = False - self.__admin_pool__.ctl.shutdown(cancel_futures=True) - self.__admin_pool__.monitor_type = False - self.__admin_pool__ = Pool(admin_pool_num) - self.__user_pool__ = Pool(user_pool_num) - self.submit_sys_task(self.__admin_pool__.__thread_monitor__) - self.submit_sys_task(self.__user_pool__.__thread_monitor__) From d130c376f422748b353c4fdaa66604bddf4a5a4a Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 28 Jan 2024 18:40:10 +0800 Subject: [PATCH 10/10] =?UTF-8?q?chore:=20=E5=88=A0=E9=99=A4=E5=91=BD?= =?UTF-8?q?=E4=BB=A4=E6=9D=83=E9=99=90=E5=90=8C=E6=AD=A5=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/update-cmdpriv-template.yml | 58 ------------------- res/scripts/generate_cmdpriv_template.py | 17 ------ 2 files changed, 75 deletions(-) delete mode 100644 .github/workflows/update-cmdpriv-template.yml delete mode 100644 res/scripts/generate_cmdpriv_template.py diff --git a/.github/workflows/update-cmdpriv-template.yml b/.github/workflows/update-cmdpriv-template.yml deleted file mode 100644 index 7493f332..00000000 --- a/.github/workflows/update-cmdpriv-template.yml +++ /dev/null @@ -1,58 +0,0 @@ -name: Update cmdpriv-template - -on: - push: - paths: - - 'pkg/qqbot/cmds/**' - pull_request: - types: [closed] - paths: - - 'pkg/qqbot/cmds/**' - -jobs: - update-cmdpriv-template: - if: github.event.pull_request.merged == true || github.event_name == 'push' - runs-on: ubuntu-latest - - steps: - - name: Checkout repository - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.10.13 - - - name: Install dependencies - run: | - python -m pip install --upgrade yiri-mirai-rc openai>=1.0.0 colorlog func_timeout dulwich Pillow CallingGPT tiktoken - python -m pip install -U openai>=1.0.0 - - - name: Copy Scripts - run: | - cp res/scripts/generate_cmdpriv_template.py . 
- - - name: Generate Files - run: | - python main.py - - - name: Run generate_cmdpriv_template.py - run: python3 generate_cmdpriv_template.py - - - name: Check for changes in cmdpriv-template.json - id: check_changes - run: | - if git diff --name-only | grep -q "res/templates/cmdpriv-template.json"; then - echo "::set-output name=changes_detected::true" - else - echo "::set-output name=changes_detected::false" - fi - - - name: Commit changes to cmdpriv-template.json - if: steps.check_changes.outputs.changes_detected == 'true' - run: | - git config --global user.name "GitHub Actions Bot" - git config --global user.email "" - git add res/templates/cmdpriv-template.json - git commit -m "Update cmdpriv-template.json" - git push diff --git a/res/scripts/generate_cmdpriv_template.py b/res/scripts/generate_cmdpriv_template.py deleted file mode 100644 index f76f3c24..00000000 --- a/res/scripts/generate_cmdpriv_template.py +++ /dev/null @@ -1,17 +0,0 @@ -import pkg.qqbot.cmds.aamgr as cmdsmgr -import json - -# 执行命令模块的注册 -cmdsmgr.register_all() - -# 生成限权文件模板 -template: dict[str, int] = { - "comment": "以下为命令权限,请设置到cmdpriv.json中。关于此功能的说明,请查看:https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E5%91%BD%E4%BB%A4%E6%9D%83%E9%99%90%E6%8E%A7%E5%88%B6", -} - -for key in cmdsmgr.__command_list__: - template[key] = cmdsmgr.__command_list__[key]['privilege'] - -# 写入cmdpriv-template.json -with open('res/templates/cmdpriv-template.json', 'w') as f: - f.write(json.dumps(template, indent=4, ensure_ascii=False)) \ No newline at end of file