From 7708eaa82c7dee855f9d577a64274aecd86730a1 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 26 Nov 2023 17:33:13 +0800 Subject: [PATCH 01/10] =?UTF-8?q?perf:=20=E4=B8=BA=20context.py=20?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E6=96=B9=E6=B3=95=E6=B7=BB=E5=8A=A0=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E6=8F=90=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/utils/context.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/pkg/utils/context.py b/pkg/utils/context.py index 0da18228..b208dac8 100644 --- a/pkg/utils/context.py +++ b/pkg/utils/context.py @@ -1,6 +1,13 @@ +from __future__ import annotations + import threading from . import threadctl +from ..database import manager as db_mgr +from ..openai import manager as openai_mgr +from ..qqbot import manager as qqbot_mgr +from ..plugin import host as plugin_host + context = { 'inst': { @@ -29,59 +36,59 @@ def get_config(): return t -def set_database_manager(inst): +def set_database_manager(inst: db_mgr.DatabaseManager): context_lock.acquire() context['inst']['database.manager.DatabaseManager'] = inst context_lock.release() -def get_database_manager(): +def get_database_manager() -> db_mgr.DatabaseManager: context_lock.acquire() t = context['inst']['database.manager.DatabaseManager'] context_lock.release() return t -def set_openai_manager(inst): +def set_openai_manager(inst: openai_mgr.OpenAIInteract): context_lock.acquire() context['inst']['openai.manager.OpenAIInteract'] = inst context_lock.release() -def get_openai_manager(): +def get_openai_manager() -> openai_mgr.OpenAIInteract: context_lock.acquire() t = context['inst']['openai.manager.OpenAIInteract'] context_lock.release() return t -def set_qqbot_manager(inst): +def set_qqbot_manager(inst: qqbot_mgr.QQBotManager): context_lock.acquire() context['inst']['qqbot.manager.QQBotManager'] = inst context_lock.release() -def get_qqbot_manager(): +def get_qqbot_manager() -> qqbot_mgr.QQBotManager: context_lock.acquire() t = context['inst']['qqbot.manager.QQBotManager'] context_lock.release() return t -def set_plugin_host(inst): +def set_plugin_host(inst: plugin_host.PluginHost): context_lock.acquire() context['plugin_host'] = inst context_lock.release() -def get_plugin_host(): +def get_plugin_host() -> plugin_host.PluginHost: context_lock.acquire() t = context['plugin_host'] context_lock.release() return t -def set_thread_ctl(inst): +def set_thread_ctl(inst: threadctl.ThreadCtl): context_lock.acquire() context['pool_ctl'] = inst context_lock.release() From 419354cb07f3f71cdce5c7e9165ba2a9b4a5d416 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 26 Nov 2023 17:42:25 +0800 Subject: [PATCH 02/10] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E7=94=A8?= =?UTF-8?q?=E4=BA=8E=E8=A6=86=E7=9B=96=E7=8E=87=E6=B5=8B=E8=AF=95=E7=9A=84?= =?UTF-8?q?=E9=80=80=E5=87=BA=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 1a96f1a4..36865e13 100644 --- a/main.py +++ b/main.py @@ -463,12 +463,19 @@ def main(): except: stop() pkg.utils.context.get_thread_ctl().shutdown() - import platform - if platform.system() == 'Windows': - cmd = "taskkill /F /PID {}".format(os.getpid()) - elif platform.system() in ['Linux', 'Darwin']: - cmd = "kill -9 {}".format(os.getpid()) - os.system(cmd) + + launch_args = sys.argv.copy() + + if "--cov-report" not in launch_args: + import platform + if platform.system() == 'Windows': + cmd = "taskkill /F /PID {}".format(os.getpid()) + elif platform.system() in ['Linux', 'Darwin']: + cmd = "kill -9 {}".format(os.getpid()) + os.system(cmd) + else: + print("正常退出以生成覆盖率报告") + sys.exit(0) if __name__ == '__main__': From d1dff6dedd8279ab9dfb18b44cf18b8a13d321e7 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 26 Nov 2023 21:53:35 +0800 Subject: [PATCH 03/10] =?UTF-8?q?feat(main.py):=20=E5=B0=86=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E5=8A=A0=E8=BD=BD=E6=B5=81=E7=A8=8B=E6=94=BE=E5=88=B0?= =?UTF-8?q?start=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 +- config-template.py | 13 ------------- main.py | 19 +++++++++---------- pkg/utils/reloader.py | 7 +++---- 4 files changed, 13 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index 88cabd2a..9d4595de 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -config.py +/config.py .idea/ __pycache__/ database.db diff --git a/config-template.py b/config-template.py index 1389c6ae..4ddc4135 100644 --- a/config-template.py +++ b/config-template.py @@ -322,19 +322,6 @@ # 设置为False时,向用户及管理员发送错误详细信息 hide_exce_info_to_user = False -# 线程池相关配置 -# 该参数决定机器人可以同时处理几个人的消息,超出线程池数量的请求会被阻塞,不会被丢弃 -# 如果你不清楚该参数的意义,请不要更改 -# 程序运行本身线程池,无代码层面修改请勿更改 -sys_pool_num = 8 - -# 执行管理员请求和指令的线程池并行线程数量,一般和管理员数量相等 -admin_pool_num = 4 - -# 执行用户请求和指令的线程池并行线程数量 -# 如需要更高的并发,可以增大该值 -user_pool_num = 8 - # 每个会话的过期时间,单位为秒 # 默认值20分钟 session_expire_time = 1200 diff --git a/main.py b/main.py index 36865e13..6f83ba24 100644 --- a/main.py +++ b/main.py @@ -171,6 +171,12 @@ def start(first_time_init=False): global known_exception_caught import pkg.utils.context + # 加载配置 + load_config() + + # 检查tips模块 + complete_tips() + config = pkg.utils.context.get_config() # 更新openai库到最新版本 if not hasattr(config, 'upgrade_dependencies') or config.upgrade_dependencies: @@ -420,19 +426,12 @@ def main(): init_runtime_log_file() pkg.utils.context.context['logger_handler'] = reset_logging() - # 加载配置 - load_config() - config = pkg.utils.context.get_config() - - # 检查tips模块 - complete_tips() - # 配置线程池 from pkg.utils import ThreadCtl thread_ctl = ThreadCtl( - sys_pool_num=config.sys_pool_num, - admin_pool_num=config.admin_pool_num, - user_pool_num=config.user_pool_num + sys_pool_num=8, + admin_pool_num=4, + user_pool_num=8 ) # 存进上下文 pkg.utils.context.set_thread_ctl(thread_ctl) diff --git a/pkg/utils/reloader.py b/pkg/utils/reloader.py index a9f7445b..f08e87d7 100644 --- a/pkg/utils/reloader.py +++ b/pkg/utils/reloader.py @@ -52,11 +52,10 @@ def reload_all(notify=True): # 执行启动流程 logging.info("执行程序启动流程") - main.load_config() - main.complete_tips() + context.get_thread_ctl().reload( - admin_pool_num=context.get_config().admin_pool_num, - user_pool_num=context.get_config().user_pool_num + admin_pool_num=4, + user_pool_num=8 ) context.get_thread_ctl().submit_sys_task( main.start, From e396ba46495573a32f678222bca836cbbd6bd1ee Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Sun, 26 Nov 2023 13:54:00 +0000 Subject: [PATCH 04/10] Update override-all.json --- override-all.json | 3 --- 1 file changed, 3 deletions(-) diff --git a/override-all.json b/override-all.json index d2907595..75f65320 100644 --- a/override-all.json +++ b/override-all.json @@ -78,9 +78,6 @@ "font_path": "", "retry_times": 3, "hide_exce_info_to_user": false, - "sys_pool_num": 8, - "admin_pool_num": 4, - "user_pool_num": 8, "session_expire_time": 1200, "rate_limitation": { "default": 60 From 5f07ff8145716e358405818b1de95eeacfb5c379 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 26 Nov 2023 22:19:36 +0800 Subject: [PATCH 05/10] =?UTF-8?q?refactor:=20=E5=90=AF=E5=8A=A8=E6=B5=81?= =?UTF-8?q?=E7=A8=8B=E7=8E=B0=E5=9C=A8=E5=BC=82=E6=AD=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 19 ++++++++++++++++--- pkg/utils/reloader.py | 8 ++++++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 6f83ba24..02855e3f 100644 --- a/main.py +++ b/main.py @@ -8,10 +8,13 @@ import logging import sys import traceback +import asyncio sys.path.append(".") from pkg.utils.log import init_runtime_log_file, reset_logging +from pkg.config import manager as config_mgr +from pkg.config.impls import pymodule as pymodule_cfg def check_file(): @@ -165,7 +168,7 @@ def complete_tips(): time.sleep(3) -def start(first_time_init=False): +async def start_process(first_time_init=False): """启动流程,reload之后会被执行""" global known_exception_caught @@ -174,6 +177,14 @@ def start(first_time_init=False): # 加载配置 load_config() + cfg_inst: pymodule_cfg.PythonModuleConfigFile = pymodule_cfg.PythonModuleConfigFile( + 'config.py', + 'config-template.py' + ) + await config_mgr.ConfigManager(cfg_inst).load_config() + + # TODO: override config + # 检查tips模块 complete_tips() @@ -450,9 +461,11 @@ def main(): # 关闭urllib的http警告 requests.packages.urllib3.disable_warnings(InsecureRequestWarning) + def run_wrapper(): + asyncio.run(start_process(True)) + pkg.utils.context.get_thread_ctl().submit_sys_task( - start, - True + run_wrapper ) # 主线程循环 diff --git a/pkg/utils/reloader.py b/pkg/utils/reloader.py index f08e87d7..e0029af6 100644 --- a/pkg/utils/reloader.py +++ b/pkg/utils/reloader.py @@ -1,6 +1,7 @@ import logging import importlib import pkgutil +import asyncio from . import context from ..plugin import host as plugin_host @@ -57,9 +58,12 @@ def reload_all(notify=True): admin_pool_num=4, user_pool_num=8 ) + + def run_wrapper(): + asyncio.run(main.start_process(False)) + context.get_thread_ctl().submit_sys_task( - main.start, - False + run_wrapper ) logging.info('程序启动完成') From 26e4215054a4f1a7f10f711311d9214e907219f7 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 26 Nov 2023 22:25:54 +0800 Subject: [PATCH 06/10] =?UTF-8?q?feat:=20=E6=96=B0=E7=9A=84override?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 02855e3f..c855bfbd 100644 --- a/main.py +++ b/main.py @@ -116,6 +116,23 @@ def override_config(): logging.info("已根据override.json覆写配置项: {}".format(", ".join(overrided))) +def override_config_manager(): + config = pkg.utils.context.get_config_manager().data + + if os.path.exists("override.json") and use_override: + override_json = json.load(open("override.json", "r", encoding="utf-8")) + overrided = [] + for key in override_json: + if key in config: + config[key] = override_json[key] + # logging.info("覆写配置[{}]为[{}]".format(key, override_json[key])) + overrided.append(key) + else: + logging.error("无法覆写配置[{}]为[{}],该配置不存在,请检查override.json是否正确".format(key, override_json[key])) + if len(overrided) > 0: + logging.info("已根据override.json覆写配置项: {}".format(", ".join(overrided))) + + # 临时函数,用于加载config和上下文,未来统一放在config类 def load_config(): logging.info("检查config模块完整性.") @@ -183,7 +200,7 @@ async def start_process(first_time_init=False): ) await config_mgr.ConfigManager(cfg_inst).load_config() - # TODO: override config + override_config_manager() # 检查tips模块 complete_tips() From db2e3660145eeb0da9c3697b46f4d0650502255b Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 26 Nov 2023 22:46:27 +0800 Subject: [PATCH 07/10] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E6=96=87=E4=BB=B6=E7=AE=A1=E7=90=86=E5=99=A8=E5=B9=B6?= =?UTF-8?q?=E9=80=82=E9=85=8Dmain.py=E4=B8=AD=E7=9A=84=E5=BC=95=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 37 +++++++++++---------- pkg/config/__init__.py | 0 pkg/config/impls/pymodule.py | 62 ++++++++++++++++++++++++++++++++++++ pkg/config/manager.py | 23 +++++++++++++ pkg/config/model.py | 27 ++++++++++++++++ pkg/utils/context.py | 15 +++++++++ 6 files changed, 145 insertions(+), 19 deletions(-) create mode 100644 pkg/config/__init__.py create mode 100644 pkg/config/impls/pymodule.py create mode 100644 pkg/config/manager.py create mode 100644 pkg/config/model.py diff --git a/main.py b/main.py index c855bfbd..55fea88b 100644 --- a/main.py +++ b/main.py @@ -205,11 +205,11 @@ async def start_process(first_time_init=False): # 检查tips模块 complete_tips() - config = pkg.utils.context.get_config() + cfg = pkg.utils.context.get_config_manager().data # 更新openai库到最新版本 - if not hasattr(config, 'upgrade_dependencies') or config.upgrade_dependencies: + if 'upgrade_dependencies' not in cfg or cfg['upgrade_dependencies']: print("正在更新依赖库,请等待...") - if not hasattr(config, 'upgrade_dependencies'): + if 'upgrade_dependencies' not in cfg: print("这个操作不是必须的,如果不想更新,请在config.py中添加upgrade_dependencies=False") else: print("这个操作不是必须的,如果不想更新,请在config.py中将upgrade_dependencies设置为False") @@ -226,11 +226,11 @@ async def start_process(first_time_init=False): pkg.utils.context.context['logger_handler'] = sh # 检查是否设置了管理员 - if not (hasattr(config, 'admin_qq') and config.admin_qq != 0): + if cfg['admin_qq'] == 0: # logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段") while True: try: - config.admin_qq = int(input("未设置管理员QQ,管理员权限指令及运行告警将无法使用,请输入管理员QQ号: ")) + cfg['admin_qq'] = int(input("未设置管理员QQ,管理员权限指令及运行告警将无法使用,请输入管理员QQ号: ")) # 写入到文件 # 读取文件 @@ -238,7 +238,7 @@ async def start_process(first_time_init=False): with open("config.py", "r", encoding="utf-8") as f: config_file_str = f.read() # 替换 - config_file_str = config_file_str.replace("admin_qq = 0", "admin_qq = " + str(config.admin_qq)) + config_file_str = config_file_str.replace("admin_qq = 0", "admin_qq = " + str(cfg['admin_qq'])) # 写入 with open("config.py", "w", encoding="utf-8") as f: f.write(config_file_str) @@ -267,23 +267,23 @@ async def start_process(first_time_init=False): # 配置OpenAI proxy import openai openai.proxies = None # 先重置,因为重载后可能需要清除proxy - if "http_proxy" in config.openai_config and config.openai_config["http_proxy"] is not None: + if "http_proxy" in cfg['openai_config'] and cfg['openai_config']["http_proxy"] is not None: openai.proxies = { - "http": config.openai_config["http_proxy"], - "https": config.openai_config["http_proxy"] + "http": cfg['openai_config']["http_proxy"], + "https": cfg['openai_config']["http_proxy"] } # 配置openai api_base - if "reverse_proxy" in config.openai_config and config.openai_config["reverse_proxy"] is not None: - logging.debug("设置反向代理: "+config.openai_config['reverse_proxy']) - openai.base_url = config.openai_config["reverse_proxy"] + if "reverse_proxy" in cfg['openai_config'] and cfg['openai_config']["reverse_proxy"] is not None: + logging.debug("设置反向代理: "+cfg['openai_config']['reverse_proxy']) + openai.base_url = cfg['openai_config']["reverse_proxy"] # 主启动流程 database = pkg.database.manager.DatabaseManager() database.initialize_database() - openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key']) + openai_interact = pkg.openai.manager.OpenAIInteract(cfg['openai_config']['api_key']) # 加载所有未超时的session pkg.openai.session.load_sessions() @@ -372,13 +372,12 @@ def run_bot_wrapper(): if first_time_init: if not known_exception_caught: - import config - if config.msg_source_adapter == "yirimirai": - logging.info("QQ: {}, MAH: {}".format(config.mirai_http_api_config['qq'], config.mirai_http_api_config['host']+":"+str(config.mirai_http_api_config['port']))) + if cfg['msg_source_adapter'] == "yirimirai": + logging.info("QQ: {}, MAH: {}".format(cfg['mirai_http_api_config']['qq'], cfg['mirai_http_api_config']['host']+":"+str(cfg['mirai_http_api_config']['port']))) logging.critical('程序启动完成,如长时间未显示 "成功登录到账号xxxxx" ,并且不回复消息,解决办法(请勿到群里问): ' 'https://github.com/RockChinQ/QChatGPT/issues/37') - elif config.msg_source_adapter == 'nakuru': - logging.info("host: {}, port: {}, http_port: {}".format(config.nakuru_config['host'], config.nakuru_config['port'], config.nakuru_config['http_port'])) + elif cfg['msg_source_adapter'] == 'nakuru': + logging.info("host: {}, port: {}, http_port: {}".format(cfg['nakuru_config']['host'], cfg['nakuru_config']['port'], cfg['nakuru_config']['http_port'])) logging.critical('程序启动完成,如长时间未显示 "Protocol: connected" ,并且不回复消息,请检查config.py中的nakuru_config是否正确') else: sys.exit(1) @@ -386,7 +385,7 @@ def run_bot_wrapper(): logging.info('热重载完成') # 发送赞赏码 - if config.encourage_sponsor_at_start \ + if cfg['encourage_sponsor_at_start'] \ and pkg.utils.context.get_openai_manager().audit_mgr.get_total_text_length() >= 2048: logging.info("发送赞赏码") diff --git a/pkg/config/__init__.py b/pkg/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/config/impls/pymodule.py b/pkg/config/impls/pymodule.py new file mode 100644 index 00000000..691082de --- /dev/null +++ b/pkg/config/impls/pymodule.py @@ -0,0 +1,62 @@ +import os +import shutil +import importlib +import logging + +from .. import model as file_model + + +class PythonModuleConfigFile(file_model.ConfigFile): + """Python模块配置文件""" + + config_file_name: str = None + """配置文件名""" + + template_file_name: str = None + """模板文件名""" + + def __init__(self, config_file_name: str, template_file_name: str) -> None: + self.config_file_name = config_file_name + self.template_file_name = template_file_name + + def exists(self) -> bool: + return os.path.exists(self.config_file_name) + + async def create(self): + shutil.copyfile(self.template_file_name, self.config_file_name) + + async def load(self) -> dict: + module_name = os.path.splitext(os.path.basename(self.config_file_name))[0] + module = importlib.import_module(module_name) + + cfg = {} + + allowed_types = (int, float, str, bool, list, dict) + + for key in dir(module): + if key.startswith('__'): + continue + + if not isinstance(getattr(module, key), allowed_types): + continue + + cfg[key] = getattr(module, key) + + # 从模板模块文件中进行补全 + module_name = os.path.splitext(os.path.basename(self.template_file_name))[0] + module = importlib.import_module(module_name) + + for key in dir(module): + if key.startswith('__'): + continue + + if not isinstance(getattr(module, key), allowed_types): + continue + + if key not in cfg: + cfg[key] = getattr(module, key) + + return cfg + + async def save(self, data: dict): + logging.warning('Python模块配置文件不支持保存') diff --git a/pkg/config/manager.py b/pkg/config/manager.py new file mode 100644 index 00000000..53a6b099 --- /dev/null +++ b/pkg/config/manager.py @@ -0,0 +1,23 @@ +from . import model as file_model +from ..utils import context + + +class ConfigManager: + """配置文件管理器""" + + file: file_model.ConfigFile = None + """配置文件实例""" + + data: dict = None + """配置数据""" + + def __init__(self, cfg_file: file_model.ConfigFile) -> None: + self.file = cfg_file + self.data = {} + context.set_config_manager(self) + + async def load_config(self): + self.data = await self.file.load() + + async def dump_config(self): + await self.file.save(self.data) diff --git a/pkg/config/model.py b/pkg/config/model.py new file mode 100644 index 00000000..e72371ff --- /dev/null +++ b/pkg/config/model.py @@ -0,0 +1,27 @@ +import abc + + +class ConfigFile(metaclass=abc.ABCMeta): + """配置文件抽象类""" + + config_file_name: str = None + """配置文件名""" + + template_file_name: str = None + """模板文件名""" + + @abc.abstractmethod + def exists(self) -> bool: + pass + + @abc.abstractmethod + async def create(self): + pass + + @abc.abstractmethod + async def load(self) -> dict: + pass + + @abc.abstractmethod + async def save(self, data: dict): + pass diff --git a/pkg/utils/context.py b/pkg/utils/context.py index b208dac8..e26c702b 100644 --- a/pkg/utils/context.py +++ b/pkg/utils/context.py @@ -6,6 +6,7 @@ from ..database import manager as db_mgr from ..openai import manager as openai_mgr from ..qqbot import manager as qqbot_mgr +from ..config import manager as config_mgr from ..plugin import host as plugin_host @@ -14,6 +15,7 @@ 'database.manager.DatabaseManager': None, 'openai.manager.OpenAIInteract': None, 'qqbot.manager.QQBotManager': None, + 'config.manager.ConfigManager': None, }, 'pool_ctl': None, 'logger_handler': None, @@ -75,6 +77,19 @@ def get_qqbot_manager() -> qqbot_mgr.QQBotManager: return t +def set_config_manager(inst: config_mgr.ConfigManager): + context_lock.acquire() + context['inst']['config.manager.ConfigManager'] = inst + context_lock.release() + + +def get_config_manager() -> config_mgr.ConfigManager: + context_lock.acquire() + t = context['inst']['config.manager.ConfigManager'] + context_lock.release() + return t + + def set_plugin_host(inst: plugin_host.PluginHost): context_lock.acquire() context['plugin_host'] = inst From 549a7eff7fdbbbe5ad7ccba0864df20d30a88dd3 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 26 Nov 2023 23:04:14 +0800 Subject: [PATCH 08/10] =?UTF-8?q?refactor(qqbot):=20=E9=80=82=E9=85=8D?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E7=AE=A1=E7=90=86=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/qqbot/adapter.py | 1 + pkg/qqbot/manager.py | 106 ++++++++++++++++++++++--------------------- 2 files changed, 55 insertions(+), 52 deletions(-) diff --git a/pkg/qqbot/adapter.py b/pkg/qqbot/adapter.py index 0d00915d..784d8ae3 100644 --- a/pkg/qqbot/adapter.py +++ b/pkg/qqbot/adapter.py @@ -5,6 +5,7 @@ class MessageSourceAdapter: + bot_account_id: int def __init__(self, config: dict): pass diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 922de441..588d6621 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -2,7 +2,7 @@ import os import logging -from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \ +from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \ FriendMessage, Image, MessageChain, Plain import func_timeout @@ -19,16 +19,16 @@ # 检查消息是否符合泛响应匹配机制 def check_response_rule(group_id:int, text: str): - config = context.get_config() + config = context.get_config_manager().data - rules = config.response_rules + rules = config['response_rules'] # 检查是否有特定规则 - if 'prefix' not in config.response_rules: - if str(group_id) in config.response_rules: - rules = config.response_rules[str(group_id)] + if 'prefix' not in config['response_rules']: + if str(group_id) in config['response_rules']: + rules = config['response_rules'][str(group_id)] else: - rules = config.response_rules['default'] + rules = config['response_rules']['default'] # 检查前缀匹配 if 'prefix' in rules: @@ -48,16 +48,16 @@ def check_response_rule(group_id:int, text: str): def response_at(group_id: int): - config = context.get_config() + config = context.get_config_manager().data - use_response_rule = config.response_rules + use_response_rule = config['response_rules'] # 检查是否有特定规则 - if 'prefix' not in config.response_rules: - if str(group_id) in config.response_rules: - use_response_rule = config.response_rules[str(group_id)] + if 'prefix' not in config['response_rules']: + if str(group_id) in config['response_rules']: + use_response_rule = config['response_rules'][str(group_id)] else: - use_response_rule = config.response_rules['default'] + use_response_rule = config['response_rules']['default'] if 'at' not in use_response_rule: return True @@ -66,16 +66,16 @@ def response_at(group_id: int): def random_responding(group_id): - config = context.get_config() + config = context.get_config_manager().data - use_response_rule = config.response_rules + use_response_rule = config['response_rules'] # 检查是否有特定规则 - if 'prefix' not in config.response_rules: - if str(group_id) in config.response_rules: - use_response_rule = config.response_rules[str(group_id)] + if 'prefix' not in config['response_rules']: + if str(group_id) in config['response_rules']: + use_response_rule = config['response_rules'][str(group_id)] else: - use_response_rule = config.response_rules['default'] + use_response_rule = config['response_rules']['default'] if 'random_rate' in use_response_rule: import random @@ -102,25 +102,25 @@ class QQBotManager: ban_group = [] def __init__(self, first_time_init=True): - import config + config = context.get_config_manager().data - self.timeout = config.process_message_timeout - self.retry = config.retry_times + self.timeout = config['process_message_timeout'] + self.retry = config['retry_times'] # 由于YiriMirai的bot对象是单例的,且shutdown方法暂时无法使用 # 故只在第一次初始化时创建bot对象,重载之后使用原bot对象 # 因此,bot的配置不支持热重载 if first_time_init: - logging.debug("Use adapter:" + config.msg_source_adapter) - if config.msg_source_adapter == 'yirimirai': + logging.debug("Use adapter:" + config['msg_source_adapter']) + if config['msg_source_adapter'] == 'yirimirai': from pkg.qqbot.sources.yirimirai import YiriMiraiAdapter - mirai_http_api_config = config.mirai_http_api_config - self.bot_account_id = config.mirai_http_api_config['qq'] + mirai_http_api_config = config['mirai_http_api_config'] + self.bot_account_id = config['mirai_http_api_config']['qq'] self.adapter = YiriMiraiAdapter(mirai_http_api_config) - elif config.msg_source_adapter == 'nakuru': + elif config['msg_source_adapter'] == 'nakuru': from pkg.qqbot.sources.nakuru import NakuruProjectAdapter - self.adapter = NakuruProjectAdapter(config.nakuru_config) + self.adapter = NakuruProjectAdapter(config['nakuru_config']) self.bot_account_id = self.adapter.bot_account_id else: self.adapter = context.get_qqbot_manager().adapter @@ -176,7 +176,7 @@ def stranger_message_handler(): stranger_message_handler, ) # nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件 - if config.msg_source_adapter == 'yirimirai': + if config['msg_source_adapter'] == 'yirimirai': self.adapter.register_listener( StrangerMessage, on_stranger_message @@ -213,12 +213,11 @@ def unsubscribe_all(): 用于在热重载流程中卸载所有事件处理器 """ - import config self.adapter.unregister_listener( FriendMessage, on_friend_message ) - if config.msg_source_adapter == 'yirimirai': + if config['msg_source_adapter'] == 'yirimirai': self.adapter.unregister_listener( StrangerMessage, on_stranger_message @@ -243,10 +242,10 @@ def unsubscribe_all(): if hasattr(banlist, "enable_group"): self.enable_group = banlist.enable_group - config = context.get_config() + config = context.get_config_manager().data if os.path.exists("sensitive.json") \ - and config.sensitive_word_filter is not None \ - and config.sensitive_word_filter: + and config['sensitive_word_filter'] is not None \ + and config['sensitive_word_filter']: with open("sensitive.json", "r", encoding="utf-8") as f: sensitive_json = json.load(f) self.reply_filter = qqbot_filter.ReplyFilter( @@ -258,16 +257,16 @@ def unsubscribe_all(): self.reply_filter = qqbot_filter.ReplyFilter([]) def send(self, event, msg, check_quote=True, check_at_sender=True): - config = context.get_config() + config = context.get_config_manager().data - if check_at_sender and config.at_sender: + if check_at_sender and config['at_sender']: msg.insert( 0, Plain(" \n") ) # 当回复的正文中包含换行时,quote可能会自带at,此时就不再单独添加at,只添加换行 - if "\n" not in str(msg[1]) or config.msg_source_adapter == 'nakuru': + if "\n" not in str(msg[1]) or config['msg_source_adapter'] == 'nakuru': msg.insert( 0, At( @@ -278,14 +277,15 @@ def send(self, event, msg, check_quote=True, check_at_sender=True): self.adapter.reply_message( event, msg, - quote_origin=True if config.quote_origin and check_quote else False + quote_origin=True if config['quote_origin'] and check_quote else False ) # 私聊消息处理 def on_person_message(self, event: MessageEvent): - import config reply = '' + config = context.get_config_manager().data + if not self.enable_private: logging.debug("已在banlist.py中禁用所有私聊") elif event.sender.id == self.bot_account_id: @@ -299,7 +299,7 @@ def on_person_message(self, event: MessageEvent): for i in range(self.retry): try: - @func_timeout.func_set_timeout(config.process_message_timeout) + @func_timeout.func_set_timeout(config['process_message_timeout']) def time_ctrl_wrapper(): reply = processor.process_message('person', event.sender.id, str(event.message_chain), event.message_chain, @@ -326,8 +326,10 @@ def time_ctrl_wrapper(): # 群消息处理 def on_group_message(self, event: GroupMessage): - import config reply = '' + + config = context.get_config_manager().data + def process(text=None) -> str: replys = "" if At(self.bot_account_id) in event.message_chain: @@ -337,7 +339,7 @@ def process(text=None) -> str: failed = 0 for i in range(self.retry): try: - @func_timeout.func_set_timeout(config.process_message_timeout) + @func_timeout.func_set_timeout(config['process_message_timeout']) def time_ctrl_wrapper(): replys = processor.process_message('group', event.group.id, str(event.message_chain).strip() if text is None else text, @@ -385,17 +387,17 @@ def time_ctrl_wrapper(): # 通知系统管理员 def notify_admin(self, message: str): - config = context.get_config() - if config.admin_qq != 0 and config.admin_qq != []: + config = context.get_config_manager().data + if config['admin_qq'] != 0 and config['admin_qq'] != []: logging.info("通知管理员:{}".format(message)) - if type(config.admin_qq) == int: + if type(config['admin_qq']) == int: self.adapter.send_message( "person", - config.admin_qq, + config['admin_qq'], MessageChain([Plain("[bot]{}".format(message))]) ) else: - for adm in config.admin_qq: + for adm in config['admin_qq']: self.adapter.send_message( "person", adm, @@ -403,17 +405,17 @@ def notify_admin(self, message: str): ) def notify_admin_message_chain(self, message): - config = context.get_config() - if config.admin_qq != 0 and config.admin_qq != []: + config = context.get_config_manager().data + if config['admin_qq'] != 0 and config['admin_qq'] != []: logging.info("通知管理员:{}".format(message)) - if type(config.admin_qq) == int: + if type(config['admin_qq']) == int: self.adapter.send_message( "person", - config.admin_qq, + config['admin_qq'], message ) else: - for adm in config.admin_qq: + for adm in config['admin_qq']: self.adapter.send_message( "person", adm, From 3e17bbb90faddd9360ad49166b999b6c5fb02d43 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 26 Nov 2023 23:58:06 +0800 Subject: [PATCH 09/10] =?UTF-8?q?refactor:=20=E9=80=82=E9=85=8D=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E7=AE=A1=E7=90=86=E5=99=A8=E8=AF=BB=E5=8F=96=E6=96=B9?= =?UTF-8?q?=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 4 +++ pkg/audit/gatherer.py | 6 ++-- pkg/database/manager.py | 4 +-- pkg/openai/api/model.py | 7 +++-- pkg/openai/dprompt.py | 23 ++++++++------ pkg/openai/manager.py | 10 +++--- pkg/openai/session.py | 18 +++++------ pkg/qqbot/blob.py | 13 +++++--- pkg/qqbot/cmds/funcs/draw.py | 6 ++-- pkg/qqbot/cmds/session/default.py | 7 +++-- pkg/qqbot/command.py | 2 +- pkg/qqbot/filter.py | 13 +++++--- pkg/qqbot/ignore.py | 12 +++++--- pkg/qqbot/message.py | 10 +++--- pkg/qqbot/process.py | 26 ++++++++-------- pkg/qqbot/ratelimit.py | 17 ++++++----- pkg/qqbot/sources/nakuru.py | 11 ++++--- pkg/utils/log.py | 11 +++++-- pkg/utils/network.py | 10 +++--- pkg/utils/text2img.py | 51 +++++++++++++++++-------------- 20 files changed, 148 insertions(+), 113 deletions(-) diff --git a/main.py b/main.py index 55fea88b..402a76b6 100644 --- a/main.py +++ b/main.py @@ -218,6 +218,10 @@ async def start_process(first_time_init=False): except Exception as e: print("更新openai库失败:{}, 请忽略或自行更新".format(e)) + # 初始化文字转图片 + from pkg.utils import text2img + text2img.initialize() + known_exception_caught = False try: try: diff --git a/pkg/audit/gatherer.py b/pkg/audit/gatherer.py index acf6b368..2dc2c560 100644 --- a/pkg/audit/gatherer.py +++ b/pkg/audit/gatherer.py @@ -47,10 +47,10 @@ def report_to_server(self, subservice_name: str, count: int): def thread_func(): try: - config = context.get_config() - if not config.report_usage: + config = context.get_config_manager().data + if not config['report_usage']: return - res = requests.get("http://reports.rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}&msg_source={}".format(subservice_name, self.version_str, count, config.msg_source_adapter)) + res = requests.get("http://reports.rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}&msg_source={}".format(subservice_name, self.version_str, count, config['msg_source_adapter'])) if res.status_code != 200 or res.text != "ok": logging.warning("report to server failed, status_code: {}, text: {}".format(res.status_code, res.text)) except: diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 15a00ae9..f410b418 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -144,11 +144,11 @@ def set_session_expired(self, session_name: str, create_timestamp: int): # 从数据库加载还没过期的session数据 def load_valid_sessions(self) -> dict: # 从数据库中加载所有还没过期的session - config = context.get_config() + config = context.get_config_manager().data self.__execute__(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` from `sessions` where `last_interact_timestamp` > {} - """.format(int(time.time()) - config.session_expire_time)) + """.format(int(time.time()) - config['session_expire_time'])) results = self.cursor.fetchall() sessions = {} for result in results: diff --git a/pkg/openai/api/model.py b/pkg/openai/api/model.py index f2c04ccc..0a1f6a3a 100644 --- a/pkg/openai/api/model.py +++ b/pkg/openai/api/model.py @@ -3,6 +3,8 @@ import openai +from ...utils import context + class RequestBase: @@ -14,7 +16,6 @@ def __init__(self, *args, **kwargs): raise NotImplementedError def _next_key(self): - import pkg.utils.context as context switched, name = context.get_openai_manager().key_mgr.auto_switch() logging.debug("切换api-key: switched={}, name={}".format(switched, name)) self.client.api_key = context.get_openai_manager().key_mgr.get_using_key() @@ -22,12 +23,12 @@ def _next_key(self): def _req(self, **kwargs): """处理代理问题""" logging.debug("请求接口参数: %s", str(kwargs)) - import config + config = context.get_config_manager().data ret = self.req_func(**kwargs) logging.debug("接口请求返回:%s", str(ret)) - if config.switch_strategy == 'active': + if config['switch_strategy'] == 'active': self._next_key() return ret diff --git a/pkg/openai/dprompt.py b/pkg/openai/dprompt.py index ec1dbf2f..f6b03801 100644 --- a/pkg/openai/dprompt.py +++ b/pkg/openai/dprompt.py @@ -1,9 +1,10 @@ # 多情景预设值管理 import json import logging -import config import os +from ..utils import context + # __current__ = "default" # """当前默认使用的情景预设的名称 @@ -62,22 +63,24 @@ class NormalScenarioMode(ScenarioMode): """普通情景预设模式""" def __init__(self): + config = context.get_config_manager().data + # 加载config中的default_prompt值 - if type(config.default_prompt) == str: + if type(config['default_prompt']) == str: self.using_prompt_name = "default" self.prompts = {"default": [ { "role": "system", - "content": config.default_prompt + "content": config['default_prompt'] } ]} - elif type(config.default_prompt) == dict: - for key in config.default_prompt: + elif type(config['default_prompt']) == dict: + for key in config['default_prompt']: self.prompts[key] = [ { "role": "system", - "content": config.default_prompt[key] + "content": config['default_prompt'][key] } ] @@ -123,9 +126,9 @@ def register_all(): def mode_inst() -> ScenarioMode: """获取指定名称的情景预设模式对象""" - import config + config = context.get_config_manager().data - if config.preset_mode == "default": - config.preset_mode = "normal" + if config['preset_mode'] == "default": + config['preset_mode'] = "normal" - return scenario_mode_mapping[config.preset_mode] + return scenario_mode_mapping[config['preset_mode']] diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index 385f4346..0fd008b9 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -43,13 +43,13 @@ def request_completion(self, messages: list): """请求补全接口回复= """ # 选择接口请求类 - config = context.get_config() + config = context.get_config_manager().data request: api_model.RequestBase - model: str = config.completion_api_params['model'] + model: str = config['completion_api_params']['model'] - cp_parmas = config.completion_api_params.copy() + cp_parmas = config['completion_api_params'].copy() del cp_parmas['model'] request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas) @@ -74,8 +74,8 @@ def request_image(self, prompt) -> dict: Returns: dict: 响应 """ - config = context.get_config() - params = config.image_api_params + config = context.get_config_manager().data + params = config['image_api_params'] response = openai.Image.create( prompt=prompt, diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 498b9082..197ccfd0 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -36,11 +36,11 @@ def reset_session_prompt(session_name, prompt): f.write(prompt) f.close() # 生成新数据 - config = context.get_config() + config = context.get_config_manager().data prompt = [ { 'role': 'system', - 'content': config.default_prompt['default'] if type(config.default_prompt) == dict else config.default_prompt + 'content': config['default_prompt']['default'] if type(config['default_prompt']) == dict else config['default_prompt'] } ] # 警告 @@ -170,15 +170,15 @@ def expire_check_timer_loop(self, create_timestamp: int): if self.create_timestamp != create_timestamp or self not in sessions.values(): return - config = context.get_config() - if int(time.time()) - self.last_interact_timestamp > config.session_expire_time: + config = context.get_config_manager().data + if int(time.time()) - self.last_interact_timestamp > config['session_expire_time']: logging.info('session {} 已过期'.format(self.name)) # 触发插件事件 args = { 'session_name': self.name, 'session': self, - 'session_expire_time': config.session_expire_time + 'session_expire_time': config['session_expire_time'] } event = plugin_host.emit(plugin_models.SessionExpired, **args) if event.is_prevented_default(): @@ -216,8 +216,8 @@ def query(self, text: str=None) -> tuple[str, str, list[str]]: if event.is_prevented_default(): return None, None, None - config = context.get_config() - max_length = config.prompt_submit_length + config = context.get_config_manager().data + max_length = config['prompt_submit_length'] local_default_prompt = self.default_prompt.copy() local_prompt = self.prompt.copy() @@ -254,7 +254,7 @@ def query(self, text: str=None) -> tuple[str, str, list[str]]: funcs = [] - trace_func_calls = config.trace_function_calls + trace_func_calls = config['trace_function_calls'] botmgr = context.get_qqbot_manager() session_name_spt: list[str] = self.name.split("_") @@ -381,7 +381,7 @@ def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) # 包装目前的对话回合内容 changable_prompts = [] - use_model = context.get_config().completion_api_params['model'] + use_model = context.get_config_manager().data['completion_api_params']['model'] ptr = len(prompt) - 1 diff --git a/pkg/qqbot/blob.py b/pkg/qqbot/blob.py index fd66a4bc..d8373cd8 100644 --- a/pkg/qqbot/blob.py +++ b/pkg/qqbot/blob.py @@ -9,7 +9,7 @@ from mirai.models.base import MiraiBaseModel from ..utils import text2img -import config +from ..utils import context class ForwardMessageDiaplay(MiraiBaseModel): @@ -64,13 +64,16 @@ def text_to_image(text: str) -> MessageComponent: def check_text(text: str) -> list: """检查文本是否为长消息,并转换成该使用的消息链组件""" - if len(text) > config.blob_message_threshold: + + config = context.get_config_manager().data + + if len(text) > config['blob_message_threshold']: # logging.info("长消息: {}".format(text)) - if config.blob_message_strategy == 'image': + if config['blob_message_strategy'] == 'image': # 转换成图片 return [text_to_image(text)] - elif config.blob_message_strategy == 'forward': + elif config['blob_message_strategy'] == 'forward': # 包装转发消息 display = ForwardMessageDiaplay( @@ -82,7 +85,7 @@ def check_text(text: str) -> list: ) node = ForwardMessageNode( - sender_id=config.mirai_http_api_config['qq'], + sender_id=config['mirai_http_api_config']['qq'], sender_name='bot', message_chain=MessageChain([text]) ) diff --git a/pkg/qqbot/cmds/funcs/draw.py b/pkg/qqbot/cmds/funcs/draw.py index b9af92e9..315c89e3 100644 --- a/pkg/qqbot/cmds/funcs/draw.py +++ b/pkg/qqbot/cmds/funcs/draw.py @@ -3,7 +3,7 @@ import mirai from .. import aamgr -import config +from ....utils import context @aamgr.AbstractCommandNode.register( @@ -30,8 +30,8 @@ def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: logging.debug("draw_image result:{}".format(res)) reply = [mirai.Image(url=res['data'][0]['url'])] - if not (hasattr(config, 'include_image_description') - and not config.include_image_description): + config = context.get_config_manager().data + if config['include_image_description']: reply.append(" ".join(ctx.params)) return True, reply diff --git a/pkg/qqbot/cmds/session/default.py b/pkg/qqbot/cmds/session/default.py index bb187123..6103d210 100644 --- a/pkg/qqbot/cmds/session/default.py +++ b/pkg/qqbot/cmds/session/default.py @@ -1,4 +1,6 @@ from .. import aamgr +from ....utils import context + @aamgr.AbstractCommandNode.register( parent=None, @@ -15,12 +17,13 @@ def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: session_name = ctx.session_name params = ctx.params reply = [] - import config + + config = context.get_config_manager().data if len(params) == 0: # 输出目前所有情景预设 import pkg.openai.dprompt as dprompt - reply_str = "[bot]当前所有情景预设({}模式):\n\n".format(config.preset_mode) + reply_str = "[bot]当前所有情景预设({}模式):\n\n".format(config['preset_mode']) prompts = dprompt.mode_inst().list() diff --git a/pkg/qqbot/command.py b/pkg/qqbot/command.py index 414ffee4..8e2fe40b 100644 --- a/pkg/qqbot/command.py +++ b/pkg/qqbot/command.py @@ -4,7 +4,7 @@ from ..qqbot.cmds import aamgr as cmdmgr -def process_command(session_name: str, text_message: str, mgr, config, +def process_command(session_name: str, text_message: str, mgr, config: dict, launcher_type: str, launcher_id: int, sender_id: int, is_admin: bool) -> list: reply = [] try: diff --git a/pkg/qqbot/filter.py b/pkg/qqbot/filter.py index d4bf579c..c3a58093 100644 --- a/pkg/qqbot/filter.py +++ b/pkg/qqbot/filter.py @@ -4,6 +4,8 @@ import json import logging +from ..utils import context + class ReplyFilter: sensitive_words = [] @@ -20,12 +22,13 @@ def __init__(self, sensitive_words: list, mask: str = "*", mask_word: str = ""): self.sensitive_words = sensitive_words self.mask = mask self.mask_word = mask_word - import config - self.baidu_check = config.baidu_check - self.baidu_api_key = config.baidu_api_key - self.baidu_secret_key = config.baidu_secret_key - self.inappropriate_message_tips = config.inappropriate_message_tips + config = context.get_config_manager().data + + self.baidu_check = config['baidu_check'] + self.baidu_api_key = config['baidu_api_key'] + self.baidu_secret_key = config['baidu_secret_key'] + self.inappropriate_message_tips = config['inappropriate_message_tips'] def is_illegal(self, message: str) -> bool: processed = self.process(message) diff --git a/pkg/qqbot/ignore.py b/pkg/qqbot/ignore.py index 5269147d..e1adc777 100644 --- a/pkg/qqbot/ignore.py +++ b/pkg/qqbot/ignore.py @@ -1,16 +1,18 @@ import re +from ..utils import context + def ignore(msg: str) -> bool: """检查消息是否应该被忽略""" - import config + config = context.get_config_manager().data - if 'prefix' in config.ignore_rules: - for rule in config.ignore_rules['prefix']: + if 'prefix' in config['ignore_rules']: + for rule in config['ignore_rules']['prefix']: if msg.startswith(rule): return True - if 'regexp' in config.ignore_rules: - for rule in config.ignore_rules['regexp']: + if 'regexp' in config['ignore_rules']: + for rule in config['ignore_rules']['regexp']: if re.search(rule, msg): return True diff --git a/pkg/qqbot/message.py b/pkg/qqbot/message.py index c6058abd..beff6645 100644 --- a/pkg/qqbot/message.py +++ b/pkg/qqbot/message.py @@ -13,15 +13,15 @@ def handle_exception(notify_admin: str = "", set_reply: str = "") -> list: """处理异常,当notify_admin不为空时,会通知管理员,返回通知用户的消息""" - import config + config = context.get_config_manager().data context.get_qqbot_manager().notify_admin(notify_admin) - if config.hide_exce_info_to_user: + if config['hide_exce_info_to_user']: return [tips_custom.alter_tip_message] if tips_custom.alter_tip_message else [] else: return [set_reply] -def process_normal_message(text_message: str, mgr, config, launcher_type: str, +def process_normal_message(text_message: str, mgr, config: dict, launcher_type: str, launcher_id: int, sender_id: int) -> list: session_name = f"{launcher_type}_{launcher_id}" logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + ( @@ -39,7 +39,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str, reply = handle_exception(notify_admin=f"{session_name},多次尝试失败。", set_reply=f"[bot]多次尝试失败,请重试或联系管理员") break try: - prefix = "[GPT]" if config.show_prefix else "" + prefix = "[GPT]" if config['show_prefix'] else "" text, finish_reason, funcs = session.query(text_message) @@ -118,7 +118,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str, reply = handle_exception("{}会话调用API失败:{}".format(session_name, e), "[bot]err:RateLimitError,请重试或联系作者,或等待修复") except openai.BadRequestError as e: - if config.auto_reset and "This model's maximum context length is" in str(e): + if config['auto_reset'] and "This model's maximum context length is" in str(e): session.reset(persist=True) reply = [tips_custom.session_auto_reset_message] else: diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index ed45b37f..5062c0fd 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -1,6 +1,7 @@ # 此模块提供了消息处理的具体逻辑的接口 import asyncio import time +import traceback import mirai import logging @@ -28,11 +29,11 @@ def is_admin(qq: int) -> bool: """兼容list和int类型的管理员判断""" - import config - if type(config.admin_qq) == list: - return qq in config.admin_qq + config = context.get_config_manager().data + if type(config['admin_qq']) == list: + return qq in config['admin_qq'] else: - return qq == config.admin_qq + return qq == config['admin_qq'] def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: mirai.MessageChain, @@ -53,9 +54,9 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes logging.info("根据忽略规则忽略消息: {}".format(text_message)) return [] - import config + config = context.get_config_manager().data - if not config.wait_last_done and session_name in processing: + if not config['wait_last_done'] and session_name in processing: return mirai.MessageChain([mirai.Plain(tips_custom.message_drop_tip)]) # 检查是否被禁言 @@ -65,8 +66,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes logging.info("机器人被禁言,跳过消息处理(group_{})".format(launcher_id)) return reply - import config - if config.income_msg_check: + if config['income_msg_check']: if mgr.reply_filter.is_illegal(text_message): return mirai.MessageChain(mirai.Plain("[bot] 消息中存在不合适的内容, 请更换措辞")) @@ -81,8 +81,6 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes # 处理消息 try: - config = context.get_config() - processing.append(session_name) try: if text_message.startswith('!') or text_message.startswith("!"): # 指令 @@ -114,7 +112,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes else: # 消息 # 限速丢弃检查 # print(ratelimit.__crt_minute_usage__[session_name]) - if config.rate_limit_strategy == "drop": + if config['rate_limit_strategy'] == "drop": if ratelimit.is_reach_limit(session_name): logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message)) @@ -144,7 +142,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes mgr, config, launcher_type, launcher_id, sender_id) # 限速等待时间 - if config.rate_limit_strategy == "wait": + if config['rate_limit_strategy'] == "wait": time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before)) ratelimit.add_usage(session_name) @@ -167,13 +165,13 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes openai_session.get_session(session_name).release_response_lock() # 检查延迟时间 - if config.force_delay_range[1] == 0: + if config['force_delay_range'][1] == 0: delay_time = 0 else: import random # 从延迟范围中随机取一个值(浮点) - rdm = random.uniform(config.force_delay_range[0], config.force_delay_range[1]) + rdm = random.uniform(config['force_delay_range'][0], config['force_delay_range'][1]) spent = time.time() - start_time diff --git a/pkg/qqbot/ratelimit.py b/pkg/qqbot/ratelimit.py index 1d718a74..e8dd2913 100644 --- a/pkg/qqbot/ratelimit.py +++ b/pkg/qqbot/ratelimit.py @@ -3,6 +3,9 @@ import logging import threading +from ..utils import context + + __crt_minute_usage__ = {} """当前分钟每个会话的对话次数""" @@ -12,16 +15,16 @@ def get_limitation(session_name: str) -> int: """获取会话的限制次数""" - import config + config = context.get_config_manager().data - if type(config.rate_limitation) == dict: + if type(config['rate_limitation']) == dict: # 如果被指定了 - if session_name in config.rate_limitation: - return config.rate_limitation[session_name] + if session_name in config['rate_limitation']: + return config['rate_limitation'][session_name] else: - return config.rate_limitation["default"] - elif type(config.rate_limitation) == int: - return config.rate_limitation + return config['rate_limitation']["default"] + elif type(config['rate_limitation']) == int: + return config['rate_limitation'] def add_usage(session_name: str): diff --git a/pkg/qqbot/sources/nakuru.py b/pkg/qqbot/sources/nakuru.py index 51e5e41b..d60baf5c 100644 --- a/pkg/qqbot/sources/nakuru.py +++ b/pkg/qqbot/sources/nakuru.py @@ -10,6 +10,7 @@ from .. import adapter as adapter_model from ...qqbot import blob +from ...utils import context class NakuruProjectMessageConverter(adapter_model.MessageConverter): @@ -172,12 +173,14 @@ def __init__(self, cfg: dict): self.listener_list = [] # nakuru库有bug,这个接口没法带access_token,会失败 # 所以目前自行发请求 - import config + + config = context.get_config_manager().data + import requests resp = requests.get( - url="http://{}:{}/get_login_info".format(config.nakuru_config['host'], config.nakuru_config['http_port']), + url="http://{}:{}/get_login_info".format(config['nakuru_config']['host'], config['nakuru_config']['http_port']), headers={ - 'Authorization': "Bearer " + config.nakuru_config['token'] if 'token' in config.nakuru_config else "" + 'Authorization': "Bearer " + config['nakuru_config']['token'] if 'token' in config['nakuru_config']else "" }, timeout=5 ) @@ -270,7 +273,7 @@ def register_listener( logging.debug("注册监听器: " + str(event_type) + " -> " + str(callback)) # 包装函数 - async def listener_wrapper(app: nakuru.CQHTTP, source: self.event_converter.yiri2target(event_type)): + async def listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.event_converter.yiri2target(event_type)): callback(self.event_converter.target2yiri(source)) # 将包装函数和原函数的对应关系存入列表 diff --git a/pkg/utils/log.py b/pkg/utils/log.py index b976583b..e45f97e6 100644 --- a/pkg/utils/log.py +++ b/pkg/utils/log.py @@ -3,6 +3,8 @@ import logging import shutil +from . import context + log_file_name = "qchatgpt.log" @@ -36,7 +38,6 @@ def init_runtime_log_file(): def reset_logging(): global log_file_name - import config import pkg.utils.context import colorlog @@ -46,7 +47,11 @@ def reset_logging(): for handler in logging.getLogger().handlers: logging.getLogger().removeHandler(handler) - logging.basicConfig(level=config.logging_level, # 设置日志输出格式 + config_mgr = context.get_config_manager() + + logging_level = logging.INFO if config_mgr is None else config_mgr.data['logging_level'] + + logging.basicConfig(level=logging_level, # 设置日志输出格式 filename=log_file_name, # log日志输出的文件位置和文件名 format="[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s", # 日志输出的格式 @@ -54,7 +59,7 @@ def reset_logging(): datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式 ) sh = logging.StreamHandler() - sh.setLevel(config.logging_level) + sh.setLevel(logging_level) sh.setFormatter(colorlog.ColoredFormatter( fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : " "%(message)s", diff --git a/pkg/utils/network.py b/pkg/utils/network.py index 72950658..a4498854 100644 --- a/pkg/utils/network.py +++ b/pkg/utils/network.py @@ -1,9 +1,11 @@ +from . import context + def wrapper_proxies() -> dict: """获取代理""" - import config + config = context.get_config_manager().data return { - "http": config.openai_config['proxy'], - "https": config.openai_config['proxy'] - } if 'proxy' in config.openai_config and (config.openai_config['proxy'] is not None) else None + "http": config['openai_config']['proxy'], + "https": config['openai_config']['proxy'] + } if 'proxy' in config['openai_config'] and (config['openai_config']['proxy'] is not None) else None diff --git a/pkg/utils/text2img.py b/pkg/utils/text2img.py index 1da2afca..4477ef2e 100644 --- a/pkg/utils/text2img.py +++ b/pkg/utils/text2img.py @@ -1,37 +1,42 @@ import logging import re import os -import config import traceback from PIL import Image, ImageDraw, ImageFont -text_render_font: ImageFont = None +from ..utils import context -if config.blob_message_strategy == "image": # 仅在启用了image时才加载字体 - use_font = config.font_path - try: - # 检查是否存在 - if not os.path.exists(use_font): - # 若是windows系统,使用微软雅黑 - if os.name == "nt": - use_font = "C:/Windows/Fonts/msyh.ttc" - if not os.path.exists(use_font): - logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") - config.blob_message_strategy = "forward" +text_render_font: ImageFont = None + +def initialize(): + config = context.get_config_manager().data + + if config['blob_message_strategy'] == "image": # 仅在启用了image时才加载字体 + use_font = config['font_path'] + try: + + # 检查是否存在 + if not os.path.exists(use_font): + # 若是windows系统,使用微软雅黑 + if os.name == "nt": + use_font = "C:/Windows/Fonts/msyh.ttc" + if not os.path.exists(use_font): + logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") + config['blob_message_strategy'] = "forward" + else: + logging.info("使用Windows自带字体:" + use_font) + text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") else: - logging.info("使用Windows自带字体:" + use_font) - text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") + logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") + config['blob_message_strategy'] = "forward" else: - logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") - config.blob_message_strategy = "forward" - else: - text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") - except: - traceback.print_exc() - logging.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) - config.blob_message_strategy = "forward" + text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") + except: + traceback.print_exc() + logging.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) + config['blob_message_strategy'] = "forward" def indexNumber(path=''): From f9d461d9a196984454e4acdefa297e157a6fcbc1 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Mon, 27 Nov 2023 00:00:22 +0800 Subject: [PATCH 10/10] =?UTF-8?q?feat:=20=E7=A7=BB=E9=99=A4=E8=BF=87?= =?UTF-8?q?=E6=97=B6=E7=9A=84=E9=85=8D=E7=BD=AE=E6=A8=A1=E5=9D=97=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 49 ------------------------------------------------- 1 file changed, 49 deletions(-) diff --git a/main.py b/main.py index 402a76b6..cc2cf5a4 100644 --- a/main.py +++ b/main.py @@ -99,23 +99,6 @@ def ensure_dependencies(): known_exception_caught = False -def override_config(): - import config - # 检查override.json覆盖 - if os.path.exists("override.json") and use_override: - override_json = json.load(open("override.json", "r", encoding="utf-8")) - overrided = [] - for key in override_json: - if hasattr(config, key): - setattr(config, key, override_json[key]) - # logging.info("覆写配置[{}]为[{}]".format(key, override_json[key])) - overrided.append(key) - else: - logging.error("无法覆写配置[{}]为[{}],该配置不存在,请检查override.json是否正确".format(key, override_json[key])) - if len(overrided) > 0: - logging.info("已根据override.json覆写配置项: {}".format(", ".join(overrided))) - - def override_config_manager(): config = pkg.utils.context.get_config_manager().data @@ -133,36 +116,6 @@ def override_config_manager(): logging.info("已根据override.json覆写配置项: {}".format(", ".join(overrided))) -# 临时函数,用于加载config和上下文,未来统一放在config类 -def load_config(): - logging.info("检查config模块完整性.") - # 完整性校验 - non_exist_keys = [] - - is_integrity = True - config_template = importlib.import_module('config-template') - config = importlib.import_module('config') - for key in dir(config_template): - if not key.startswith("__") and not hasattr(config, key): - setattr(config, key, getattr(config_template, key)) - # logging.warning("[{}]不存在".format(key)) - non_exist_keys.append(key) - is_integrity = False - - if not is_integrity: - logging.warning("以下配置字段不存在: {}".format(", ".join(non_exist_keys))) - - # 检查override.json覆盖 - override_config() - - if not is_integrity: - logging.warning("以上不存在的配置已被设为默认值,您可以依据config-template.py检查config.py,将在3秒后继续启动... ") - time.sleep(3) - - # 存进上下文 - pkg.utils.context.set_config(config) - - def complete_tips(): """根据tips-custom-template模块补全tips模块的属性""" non_exist_keys = [] @@ -192,8 +145,6 @@ async def start_process(first_time_init=False): import pkg.utils.context # 加载配置 - load_config() - cfg_inst: pymodule_cfg.PythonModuleConfigFile = pymodule_cfg.PythonModuleConfigFile( 'config.py', 'config-template.py'