Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: 使用 配置管理器 统一管理配置文件 #618

Merged
merged 10 commits into from
Nov 26, 2023
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
config.py
/config.py
.idea/
__pycache__/
database.db
Expand Down
13 changes: 0 additions & 13 deletions config-template.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,19 +322,6 @@
# 设置为False时,向用户及管理员发送错误详细信息
hide_exce_info_to_user = False

# 线程池相关配置
# 该参数决定机器人可以同时处理几个人的消息,超出线程池数量的请求会被阻塞,不会被丢弃
# 如果你不清楚该参数的意义,请不要更改
# 程序运行本身线程池,无代码层面修改请勿更改
sys_pool_num = 8

# 执行管理员请求和指令的线程池并行线程数量,一般和管理员数量相等
admin_pool_num = 4

# 执行用户请求和指令的线程池并行线程数量
# 如需要更高的并发,可以增大该值
user_pool_num = 8

# 每个会话的过期时间,单位为秒
# 默认值20分钟
session_expire_time = 1200
Expand Down
136 changes: 63 additions & 73 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
import logging
import sys
import traceback
import asyncio

sys.path.append(".")

from pkg.utils.log import init_runtime_log_file, reset_logging
from pkg.config import manager as config_mgr
from pkg.config.impls import pymodule as pymodule_cfg


def check_file():
Expand Down Expand Up @@ -96,15 +99,15 @@ def ensure_dependencies():
known_exception_caught = False


def override_config():
import config
# 检查override.json覆盖
def override_config_manager():
config = pkg.utils.context.get_config_manager().data

if os.path.exists("override.json") and use_override:
override_json = json.load(open("override.json", "r", encoding="utf-8"))
overrided = []
for key in override_json:
if hasattr(config, key):
setattr(config, key, override_json[key])
if key in config:
config[key] = override_json[key]
# logging.info("覆写配置[{}]为[{}]".format(key, override_json[key]))
overrided.append(key)
else:
Expand All @@ -113,36 +116,6 @@ def override_config():
logging.info("已根据override.json覆写配置项: {}".format(", ".join(overrided)))


# 临时函数,用于加载config和上下文,未来统一放在config类
def load_config():
logging.info("检查config模块完整性.")
# 完整性校验
non_exist_keys = []

is_integrity = True
config_template = importlib.import_module('config-template')
config = importlib.import_module('config')
for key in dir(config_template):
if not key.startswith("__") and not hasattr(config, key):
setattr(config, key, getattr(config_template, key))
# logging.warning("[{}]不存在".format(key))
non_exist_keys.append(key)
is_integrity = False

if not is_integrity:
logging.warning("以下配置字段不存在: {}".format(", ".join(non_exist_keys)))

# 检查override.json覆盖
override_config()

if not is_integrity:
logging.warning("以上不存在的配置已被设为默认值,您可以依据config-template.py检查config.py,将在3秒后继续启动... ")
time.sleep(3)

# 存进上下文
pkg.utils.context.set_config(config)


def complete_tips():
"""根据tips-custom-template模块补全tips模块的属性"""
non_exist_keys = []
Expand All @@ -165,17 +138,29 @@ def complete_tips():
time.sleep(3)


def start(first_time_init=False):
async def start_process(first_time_init=False):
"""启动流程,reload之后会被执行"""

global known_exception_caught
import pkg.utils.context

config = pkg.utils.context.get_config()
# 加载配置
cfg_inst: pymodule_cfg.PythonModuleConfigFile = pymodule_cfg.PythonModuleConfigFile(
'config.py',
'config-template.py'
)
await config_mgr.ConfigManager(cfg_inst).load_config()

override_config_manager()

# 检查tips模块
complete_tips()

cfg = pkg.utils.context.get_config_manager().data
# 更新openai库到最新版本
if not hasattr(config, 'upgrade_dependencies') or config.upgrade_dependencies:
if 'upgrade_dependencies' not in cfg or cfg['upgrade_dependencies']:
print("正在更新依赖库,请等待...")
if not hasattr(config, 'upgrade_dependencies'):
if 'upgrade_dependencies' not in cfg:
print("这个操作不是必须的,如果不想更新,请在config.py中添加upgrade_dependencies=False")
else:
print("这个操作不是必须的,如果不想更新,请在config.py中将upgrade_dependencies设置为False")
Expand All @@ -184,6 +169,10 @@ def start(first_time_init=False):
except Exception as e:
print("更新openai库失败:{}, 请忽略或自行更新".format(e))

# 初始化文字转图片
from pkg.utils import text2img
text2img.initialize()

known_exception_caught = False
try:
try:
Expand All @@ -192,19 +181,19 @@ def start(first_time_init=False):
pkg.utils.context.context['logger_handler'] = sh

# 检查是否设置了管理员
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
if cfg['admin_qq'] == 0:
# logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
while True:
try:
config.admin_qq = int(input("未设置管理员QQ,管理员权限指令及运行告警将无法使用,请输入管理员QQ号: "))
cfg['admin_qq'] = int(input("未设置管理员QQ,管理员权限指令及运行告警将无法使用,请输入管理员QQ号: "))
# 写入到文件

# 读取文件
config_file_str = ""
with open("config.py", "r", encoding="utf-8") as f:
config_file_str = f.read()
# 替换
config_file_str = config_file_str.replace("admin_qq = 0", "admin_qq = " + str(config.admin_qq))
config_file_str = config_file_str.replace("admin_qq = 0", "admin_qq = " + str(cfg['admin_qq']))
# 写入
with open("config.py", "w", encoding="utf-8") as f:
f.write(config_file_str)
Expand Down Expand Up @@ -233,23 +222,23 @@ def start(first_time_init=False):
# 配置OpenAI proxy
import openai
openai.proxies = None # 先重置,因为重载后可能需要清除proxy
if "http_proxy" in config.openai_config and config.openai_config["http_proxy"] is not None:
if "http_proxy" in cfg['openai_config'] and cfg['openai_config']["http_proxy"] is not None:
openai.proxies = {
"http": config.openai_config["http_proxy"],
"https": config.openai_config["http_proxy"]
"http": cfg['openai_config']["http_proxy"],
"https": cfg['openai_config']["http_proxy"]
}

# 配置openai api_base
if "reverse_proxy" in config.openai_config and config.openai_config["reverse_proxy"] is not None:
logging.debug("设置反向代理: "+config.openai_config['reverse_proxy'])
openai.base_url = config.openai_config["reverse_proxy"]
if "reverse_proxy" in cfg['openai_config'] and cfg['openai_config']["reverse_proxy"] is not None:
logging.debug("设置反向代理: "+cfg['openai_config']['reverse_proxy'])
openai.base_url = cfg['openai_config']["reverse_proxy"]

# 主启动流程
database = pkg.database.manager.DatabaseManager()

database.initialize_database()

openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'])
openai_interact = pkg.openai.manager.OpenAIInteract(cfg['openai_config']['api_key'])

# 加载所有未超时的session
pkg.openai.session.load_sessions()
Expand Down Expand Up @@ -338,21 +327,20 @@ def run_bot_wrapper():

if first_time_init:
if not known_exception_caught:
import config
if config.msg_source_adapter == "yirimirai":
logging.info("QQ: {}, MAH: {}".format(config.mirai_http_api_config['qq'], config.mirai_http_api_config['host']+":"+str(config.mirai_http_api_config['port'])))
if cfg['msg_source_adapter'] == "yirimirai":
logging.info("QQ: {}, MAH: {}".format(cfg['mirai_http_api_config']['qq'], cfg['mirai_http_api_config']['host']+":"+str(cfg['mirai_http_api_config']['port'])))
logging.critical('程序启动完成,如长时间未显示 "成功登录到账号xxxxx" ,并且不回复消息,解决办法(请勿到群里问): '
'https://github.com/RockChinQ/QChatGPT/issues/37')
elif config.msg_source_adapter == 'nakuru':
logging.info("host: {}, port: {}, http_port: {}".format(config.nakuru_config['host'], config.nakuru_config['port'], config.nakuru_config['http_port']))
elif cfg['msg_source_adapter'] == 'nakuru':
logging.info("host: {}, port: {}, http_port: {}".format(cfg['nakuru_config']['host'], cfg['nakuru_config']['port'], cfg['nakuru_config']['http_port']))
logging.critical('程序启动完成,如长时间未显示 "Protocol: connected" ,并且不回复消息,请检查config.py中的nakuru_config是否正确')
else:
sys.exit(1)
else:
logging.info('热重载完成')

# 发送赞赏码
if config.encourage_sponsor_at_start \
if cfg['encourage_sponsor_at_start'] \
and pkg.utils.context.get_openai_manager().audit_mgr.get_total_text_length() >= 2048:

logging.info("发送赞赏码")
Expand Down Expand Up @@ -420,19 +408,12 @@ def main():
init_runtime_log_file()
pkg.utils.context.context['logger_handler'] = reset_logging()

# 加载配置
load_config()
config = pkg.utils.context.get_config()

# 检查tips模块
complete_tips()

# 配置线程池
from pkg.utils import ThreadCtl
thread_ctl = ThreadCtl(
sys_pool_num=config.sys_pool_num,
admin_pool_num=config.admin_pool_num,
user_pool_num=config.user_pool_num
sys_pool_num=8,
admin_pool_num=4,
user_pool_num=8
)
# 存进上下文
pkg.utils.context.set_thread_ctl(thread_ctl)
Expand All @@ -451,9 +432,11 @@ def main():
# 关闭urllib的http警告
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)

def run_wrapper():
asyncio.run(start_process(True))

pkg.utils.context.get_thread_ctl().submit_sys_task(
start,
True
run_wrapper
)

# 主线程循环
Expand All @@ -463,12 +446,19 @@ def main():
except:
stop()
pkg.utils.context.get_thread_ctl().shutdown()
import platform
if platform.system() == 'Windows':
cmd = "taskkill /F /PID {}".format(os.getpid())
elif platform.system() in ['Linux', 'Darwin']:
cmd = "kill -9 {}".format(os.getpid())
os.system(cmd)

launch_args = sys.argv.copy()

if "--cov-report" not in launch_args:
import platform
if platform.system() == 'Windows':
cmd = "taskkill /F /PID {}".format(os.getpid())
elif platform.system() in ['Linux', 'Darwin']:
cmd = "kill -9 {}".format(os.getpid())
os.system(cmd)
else:
print("正常退出以生成覆盖率报告")
sys.exit(0)


if __name__ == '__main__':
Expand Down
3 changes: 0 additions & 3 deletions override-all.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,6 @@
"font_path": "",
"retry_times": 3,
"hide_exce_info_to_user": false,
"sys_pool_num": 8,
"admin_pool_num": 4,
"user_pool_num": 8,
"session_expire_time": 1200,
"rate_limitation": {
"default": 60
Expand Down
6 changes: 3 additions & 3 deletions pkg/audit/gatherer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def report_to_server(self, subservice_name: str, count: int):
def thread_func():

try:
config = context.get_config()
if not config.report_usage:
config = context.get_config_manager().data
if not config['report_usage']:
return
res = requests.get("http://reports.rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}&msg_source={}".format(subservice_name, self.version_str, count, config.msg_source_adapter))
res = requests.get("http://reports.rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}&msg_source={}".format(subservice_name, self.version_str, count, config['msg_source_adapter']))
if res.status_code != 200 or res.text != "ok":
logging.warning("report to server failed, status_code: {}, text: {}".format(res.status_code, res.text))
except:
Expand Down
Empty file added pkg/config/__init__.py
Empty file.
62 changes: 62 additions & 0 deletions pkg/config/impls/pymodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
import shutil
import importlib
import logging

from .. import model as file_model


class PythonModuleConfigFile(file_model.ConfigFile):
"""Python模块配置文件"""

config_file_name: str = None
"""配置文件名"""

template_file_name: str = None
"""模板文件名"""

def __init__(self, config_file_name: str, template_file_name: str) -> None:
self.config_file_name = config_file_name
self.template_file_name = template_file_name

def exists(self) -> bool:
return os.path.exists(self.config_file_name)

async def create(self):
shutil.copyfile(self.template_file_name, self.config_file_name)

async def load(self) -> dict:
module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
module = importlib.import_module(module_name)

cfg = {}

allowed_types = (int, float, str, bool, list, dict)

for key in dir(module):
if key.startswith('__'):
continue

if not isinstance(getattr(module, key), allowed_types):
continue

cfg[key] = getattr(module, key)

# 从模板模块文件中进行补全
module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
module = importlib.import_module(module_name)

for key in dir(module):
if key.startswith('__'):
continue

if not isinstance(getattr(module, key), allowed_types):
continue

if key not in cfg:
cfg[key] = getattr(module, key)

return cfg

async def save(self, data: dict):
logging.warning('Python模块配置文件不支持保存')
23 changes: 23 additions & 0 deletions pkg/config/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from . import model as file_model
from ..utils import context


class ConfigManager:
"""配置文件管理器"""

file: file_model.ConfigFile = None
"""配置文件实例"""

data: dict = None
"""配置数据"""

def __init__(self, cfg_file: file_model.ConfigFile) -> None:
self.file = cfg_file
self.data = {}
context.set_config_manager(self)

async def load_config(self):
self.data = await self.file.load()

async def dump_config(self):
await self.file.save(self.data)
Loading
Loading