From 3e17bbb90faddd9360ad49166b999b6c5fb02d43 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 26 Nov 2023 23:58:06 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=80=82=E9=85=8D=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E7=AE=A1=E7=90=86=E5=99=A8=E8=AF=BB=E5=8F=96=E6=96=B9?= =?UTF-8?q?=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 4 +++ pkg/audit/gatherer.py | 6 ++-- pkg/database/manager.py | 4 +-- pkg/openai/api/model.py | 7 +++-- pkg/openai/dprompt.py | 23 ++++++++------ pkg/openai/manager.py | 10 +++--- pkg/openai/session.py | 18 +++++------ pkg/qqbot/blob.py | 13 +++++--- pkg/qqbot/cmds/funcs/draw.py | 6 ++-- pkg/qqbot/cmds/session/default.py | 7 +++-- pkg/qqbot/command.py | 2 +- pkg/qqbot/filter.py | 13 +++++--- pkg/qqbot/ignore.py | 12 +++++--- pkg/qqbot/message.py | 10 +++--- pkg/qqbot/process.py | 26 ++++++++-------- pkg/qqbot/ratelimit.py | 17 ++++++----- pkg/qqbot/sources/nakuru.py | 11 ++++--- pkg/utils/log.py | 11 +++++-- pkg/utils/network.py | 10 +++--- pkg/utils/text2img.py | 51 +++++++++++++++++-------------- 20 files changed, 148 insertions(+), 113 deletions(-) diff --git a/main.py b/main.py index 55fea88b..402a76b6 100644 --- a/main.py +++ b/main.py @@ -218,6 +218,10 @@ async def start_process(first_time_init=False): except Exception as e: print("更新openai库失败:{}, 请忽略或自行更新".format(e)) + # 初始化文字转图片 + from pkg.utils import text2img + text2img.initialize() + known_exception_caught = False try: try: diff --git a/pkg/audit/gatherer.py b/pkg/audit/gatherer.py index acf6b368..2dc2c560 100644 --- a/pkg/audit/gatherer.py +++ b/pkg/audit/gatherer.py @@ -47,10 +47,10 @@ def report_to_server(self, subservice_name: str, count: int): def thread_func(): try: - config = context.get_config() - if not config.report_usage: + config = context.get_config_manager().data + if not config['report_usage']: return - res = requests.get("http://reports.rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}&msg_source={}".format(subservice_name, self.version_str, count, config.msg_source_adapter)) + res = requests.get("http://reports.rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}&msg_source={}".format(subservice_name, self.version_str, count, config['msg_source_adapter'])) if res.status_code != 200 or res.text != "ok": logging.warning("report to server failed, status_code: {}, text: {}".format(res.status_code, res.text)) except: diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 15a00ae9..f410b418 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -144,11 +144,11 @@ def set_session_expired(self, session_name: str, create_timestamp: int): # 从数据库加载还没过期的session数据 def load_valid_sessions(self) -> dict: # 从数据库中加载所有还没过期的session - config = context.get_config() + config = context.get_config_manager().data self.__execute__(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` from `sessions` where `last_interact_timestamp` > {} - """.format(int(time.time()) - config.session_expire_time)) + """.format(int(time.time()) - config['session_expire_time'])) results = self.cursor.fetchall() sessions = {} for result in results: diff --git a/pkg/openai/api/model.py b/pkg/openai/api/model.py index f2c04ccc..0a1f6a3a 100644 --- a/pkg/openai/api/model.py +++ b/pkg/openai/api/model.py @@ -3,6 +3,8 @@ import openai +from ...utils import context + class RequestBase: @@ -14,7 +16,6 @@ def 
__init__(self, *args, **kwargs): raise NotImplementedError def _next_key(self): - import pkg.utils.context as context switched, name = context.get_openai_manager().key_mgr.auto_switch() logging.debug("切换api-key: switched={}, name={}".format(switched, name)) self.client.api_key = context.get_openai_manager().key_mgr.get_using_key() @@ -22,12 +23,12 @@ def _next_key(self): def _req(self, **kwargs): """处理代理问题""" logging.debug("请求接口参数: %s", str(kwargs)) - import config + config = context.get_config_manager().data ret = self.req_func(**kwargs) logging.debug("接口请求返回:%s", str(ret)) - if config.switch_strategy == 'active': + if config['switch_strategy'] == 'active': self._next_key() return ret diff --git a/pkg/openai/dprompt.py b/pkg/openai/dprompt.py index ec1dbf2f..f6b03801 100644 --- a/pkg/openai/dprompt.py +++ b/pkg/openai/dprompt.py @@ -1,9 +1,10 @@ # 多情景预设值管理 import json import logging -import config import os +from ..utils import context + # __current__ = "default" # """当前默认使用的情景预设的名称 @@ -62,22 +63,24 @@ class NormalScenarioMode(ScenarioMode): """普通情景预设模式""" def __init__(self): + config = context.get_config_manager().data + # 加载config中的default_prompt值 - if type(config.default_prompt) == str: + if type(config['default_prompt']) == str: self.using_prompt_name = "default" self.prompts = {"default": [ { "role": "system", - "content": config.default_prompt + "content": config['default_prompt'] } ]} - elif type(config.default_prompt) == dict: - for key in config.default_prompt: + elif type(config['default_prompt']) == dict: + for key in config['default_prompt']: self.prompts[key] = [ { "role": "system", - "content": config.default_prompt[key] + "content": config['default_prompt'][key] } ] @@ -123,9 +126,9 @@ def register_all(): def mode_inst() -> ScenarioMode: """获取指定名称的情景预设模式对象""" - import config + config = context.get_config_manager().data - if config.preset_mode == "default": - config.preset_mode = "normal" + if config['preset_mode'] == "default": + config['preset_mode'] = "normal" - return scenario_mode_mapping[config.preset_mode] + return scenario_mode_mapping[config['preset_mode']] diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index 385f4346..0fd008b9 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -43,13 +43,13 @@ def request_completion(self, messages: list): """请求补全接口回复= """ # 选择接口请求类 - config = context.get_config() + config = context.get_config_manager().data request: api_model.RequestBase - model: str = config.completion_api_params['model'] + model: str = config['completion_api_params']['model'] - cp_parmas = config.completion_api_params.copy() + cp_parmas = config['completion_api_params'].copy() del cp_parmas['model'] request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas) @@ -74,8 +74,8 @@ def request_image(self, prompt) -> dict: Returns: dict: 响应 """ - config = context.get_config() - params = config.image_api_params + config = context.get_config_manager().data + params = config['image_api_params'] response = openai.Image.create( prompt=prompt, diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 498b9082..197ccfd0 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -36,11 +36,11 @@ def reset_session_prompt(session_name, prompt): f.write(prompt) f.close() # 生成新数据 - config = context.get_config() + config = context.get_config_manager().data prompt = [ { 'role': 'system', - 'content': config.default_prompt['default'] if type(config.default_prompt) == dict else config.default_prompt + 'content': 
config['default_prompt']['default'] if type(config['default_prompt']) == dict else config['default_prompt'] } ] # 警告 @@ -170,15 +170,15 @@ def expire_check_timer_loop(self, create_timestamp: int): if self.create_timestamp != create_timestamp or self not in sessions.values(): return - config = context.get_config() - if int(time.time()) - self.last_interact_timestamp > config.session_expire_time: + config = context.get_config_manager().data + if int(time.time()) - self.last_interact_timestamp > config['session_expire_time']: logging.info('session {} 已过期'.format(self.name)) # 触发插件事件 args = { 'session_name': self.name, 'session': self, - 'session_expire_time': config.session_expire_time + 'session_expire_time': config['session_expire_time'] } event = plugin_host.emit(plugin_models.SessionExpired, **args) if event.is_prevented_default(): @@ -216,8 +216,8 @@ def query(self, text: str=None) -> tuple[str, str, list[str]]: if event.is_prevented_default(): return None, None, None - config = context.get_config() - max_length = config.prompt_submit_length + config = context.get_config_manager().data + max_length = config['prompt_submit_length'] local_default_prompt = self.default_prompt.copy() local_prompt = self.prompt.copy() @@ -254,7 +254,7 @@ def query(self, text: str=None) -> tuple[str, str, list[str]]: funcs = [] - trace_func_calls = config.trace_function_calls + trace_func_calls = config['trace_function_calls'] botmgr = context.get_qqbot_manager() session_name_spt: list[str] = self.name.split("_") @@ -381,7 +381,7 @@ def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) # 包装目前的对话回合内容 changable_prompts = [] - use_model = context.get_config().completion_api_params['model'] + use_model = context.get_config_manager().data['completion_api_params']['model'] ptr = len(prompt) - 1 diff --git a/pkg/qqbot/blob.py b/pkg/qqbot/blob.py index fd66a4bc..d8373cd8 100644 --- a/pkg/qqbot/blob.py +++ b/pkg/qqbot/blob.py @@ -9,7 +9,7 @@ from mirai.models.base import MiraiBaseModel from ..utils import text2img -import config +from ..utils import context class ForwardMessageDiaplay(MiraiBaseModel): @@ -64,13 +64,16 @@ def text_to_image(text: str) -> MessageComponent: def check_text(text: str) -> list: """检查文本是否为长消息,并转换成该使用的消息链组件""" - if len(text) > config.blob_message_threshold: + + config = context.get_config_manager().data + + if len(text) > config['blob_message_threshold']: # logging.info("长消息: {}".format(text)) - if config.blob_message_strategy == 'image': + if config['blob_message_strategy'] == 'image': # 转换成图片 return [text_to_image(text)] - elif config.blob_message_strategy == 'forward': + elif config['blob_message_strategy'] == 'forward': # 包装转发消息 display = ForwardMessageDiaplay( @@ -82,7 +85,7 @@ def check_text(text: str) -> list: ) node = ForwardMessageNode( - sender_id=config.mirai_http_api_config['qq'], + sender_id=config['mirai_http_api_config']['qq'], sender_name='bot', message_chain=MessageChain([text]) ) diff --git a/pkg/qqbot/cmds/funcs/draw.py b/pkg/qqbot/cmds/funcs/draw.py index b9af92e9..315c89e3 100644 --- a/pkg/qqbot/cmds/funcs/draw.py +++ b/pkg/qqbot/cmds/funcs/draw.py @@ -3,7 +3,7 @@ import mirai from .. 
import aamgr -import config +from ....utils import context @aamgr.AbstractCommandNode.register( @@ -30,8 +30,8 @@ def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: logging.debug("draw_image result:{}".format(res)) reply = [mirai.Image(url=res['data'][0]['url'])] - if not (hasattr(config, 'include_image_description') - and not config.include_image_description): + config = context.get_config_manager().data + if config['include_image_description']: reply.append(" ".join(ctx.params)) return True, reply diff --git a/pkg/qqbot/cmds/session/default.py b/pkg/qqbot/cmds/session/default.py index bb187123..6103d210 100644 --- a/pkg/qqbot/cmds/session/default.py +++ b/pkg/qqbot/cmds/session/default.py @@ -1,4 +1,6 @@ from .. import aamgr +from ....utils import context + @aamgr.AbstractCommandNode.register( parent=None, @@ -15,12 +17,13 @@ def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: session_name = ctx.session_name params = ctx.params reply = [] - import config + + config = context.get_config_manager().data if len(params) == 0: # 输出目前所有情景预设 import pkg.openai.dprompt as dprompt - reply_str = "[bot]当前所有情景预设({}模式):\n\n".format(config.preset_mode) + reply_str = "[bot]当前所有情景预设({}模式):\n\n".format(config['preset_mode']) prompts = dprompt.mode_inst().list() diff --git a/pkg/qqbot/command.py b/pkg/qqbot/command.py index 414ffee4..8e2fe40b 100644 --- a/pkg/qqbot/command.py +++ b/pkg/qqbot/command.py @@ -4,7 +4,7 @@ from ..qqbot.cmds import aamgr as cmdmgr -def process_command(session_name: str, text_message: str, mgr, config, +def process_command(session_name: str, text_message: str, mgr, config: dict, launcher_type: str, launcher_id: int, sender_id: int, is_admin: bool) -> list: reply = [] try: diff --git a/pkg/qqbot/filter.py b/pkg/qqbot/filter.py index d4bf579c..c3a58093 100644 --- a/pkg/qqbot/filter.py +++ b/pkg/qqbot/filter.py @@ -4,6 +4,8 @@ import json import logging +from ..utils import context + class ReplyFilter: sensitive_words = [] @@ -20,12 +22,13 @@ def __init__(self, sensitive_words: list, mask: str = "*", mask_word: str = ""): self.sensitive_words = sensitive_words self.mask = mask self.mask_word = mask_word - import config - self.baidu_check = config.baidu_check - self.baidu_api_key = config.baidu_api_key - self.baidu_secret_key = config.baidu_secret_key - self.inappropriate_message_tips = config.inappropriate_message_tips + config = context.get_config_manager().data + + self.baidu_check = config['baidu_check'] + self.baidu_api_key = config['baidu_api_key'] + self.baidu_secret_key = config['baidu_secret_key'] + self.inappropriate_message_tips = config['inappropriate_message_tips'] def is_illegal(self, message: str) -> bool: processed = self.process(message) diff --git a/pkg/qqbot/ignore.py b/pkg/qqbot/ignore.py index 5269147d..e1adc777 100644 --- a/pkg/qqbot/ignore.py +++ b/pkg/qqbot/ignore.py @@ -1,16 +1,18 @@ import re +from ..utils import context + def ignore(msg: str) -> bool: """检查消息是否应该被忽略""" - import config + config = context.get_config_manager().data - if 'prefix' in config.ignore_rules: - for rule in config.ignore_rules['prefix']: + if 'prefix' in config['ignore_rules']: + for rule in config['ignore_rules']['prefix']: if msg.startswith(rule): return True - if 'regexp' in config.ignore_rules: - for rule in config.ignore_rules['regexp']: + if 'regexp' in config['ignore_rules']: + for rule in config['ignore_rules']['regexp']: if re.search(rule, msg): return True diff --git a/pkg/qqbot/message.py b/pkg/qqbot/message.py index c6058abd..beff6645 100644 --- 
a/pkg/qqbot/message.py +++ b/pkg/qqbot/message.py @@ -13,15 +13,15 @@ def handle_exception(notify_admin: str = "", set_reply: str = "") -> list: """处理异常,当notify_admin不为空时,会通知管理员,返回通知用户的消息""" - import config + config = context.get_config_manager().data context.get_qqbot_manager().notify_admin(notify_admin) - if config.hide_exce_info_to_user: + if config['hide_exce_info_to_user']: return [tips_custom.alter_tip_message] if tips_custom.alter_tip_message else [] else: return [set_reply] -def process_normal_message(text_message: str, mgr, config, launcher_type: str, +def process_normal_message(text_message: str, mgr, config: dict, launcher_type: str, launcher_id: int, sender_id: int) -> list: session_name = f"{launcher_type}_{launcher_id}" logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + ( @@ -39,7 +39,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str, reply = handle_exception(notify_admin=f"{session_name},多次尝试失败。", set_reply=f"[bot]多次尝试失败,请重试或联系管理员") break try: - prefix = "[GPT]" if config.show_prefix else "" + prefix = "[GPT]" if config['show_prefix'] else "" text, finish_reason, funcs = session.query(text_message) @@ -118,7 +118,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str, reply = handle_exception("{}会话调用API失败:{}".format(session_name, e), "[bot]err:RateLimitError,请重试或联系作者,或等待修复") except openai.BadRequestError as e: - if config.auto_reset and "This model's maximum context length is" in str(e): + if config['auto_reset'] and "This model's maximum context length is" in str(e): session.reset(persist=True) reply = [tips_custom.session_auto_reset_message] else: diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index ed45b37f..5062c0fd 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -1,6 +1,7 @@ # 此模块提供了消息处理的具体逻辑的接口 import asyncio import time +import traceback import mirai import logging @@ -28,11 +29,11 @@ def is_admin(qq: int) -> bool: """兼容list和int类型的管理员判断""" - import config - if type(config.admin_qq) == list: - return qq in config.admin_qq + config = context.get_config_manager().data + if type(config['admin_qq']) == list: + return qq in config['admin_qq'] else: - return qq == config.admin_qq + return qq == config['admin_qq'] def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: mirai.MessageChain, @@ -53,9 +54,9 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes logging.info("根据忽略规则忽略消息: {}".format(text_message)) return [] - import config + config = context.get_config_manager().data - if not config.wait_last_done and session_name in processing: + if not config['wait_last_done'] and session_name in processing: return mirai.MessageChain([mirai.Plain(tips_custom.message_drop_tip)]) # 检查是否被禁言 @@ -65,8 +66,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes logging.info("机器人被禁言,跳过消息处理(group_{})".format(launcher_id)) return reply - import config - if config.income_msg_check: + if config['income_msg_check']: if mgr.reply_filter.is_illegal(text_message): return mirai.MessageChain(mirai.Plain("[bot] 消息中存在不合适的内容, 请更换措辞")) @@ -81,8 +81,6 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes # 处理消息 try: - config = context.get_config() - processing.append(session_name) try: if text_message.startswith('!') or text_message.startswith("!"): # 指令 @@ -114,7 +112,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes 
else: # 消息 # 限速丢弃检查 # print(ratelimit.__crt_minute_usage__[session_name]) - if config.rate_limit_strategy == "drop": + if config['rate_limit_strategy'] == "drop": if ratelimit.is_reach_limit(session_name): logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message)) @@ -144,7 +142,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes mgr, config, launcher_type, launcher_id, sender_id) # 限速等待时间 - if config.rate_limit_strategy == "wait": + if config['rate_limit_strategy'] == "wait": time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before)) ratelimit.add_usage(session_name) @@ -167,13 +165,13 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes openai_session.get_session(session_name).release_response_lock() # 检查延迟时间 - if config.force_delay_range[1] == 0: + if config['force_delay_range'][1] == 0: delay_time = 0 else: import random # 从延迟范围中随机取一个值(浮点) - rdm = random.uniform(config.force_delay_range[0], config.force_delay_range[1]) + rdm = random.uniform(config['force_delay_range'][0], config['force_delay_range'][1]) spent = time.time() - start_time diff --git a/pkg/qqbot/ratelimit.py b/pkg/qqbot/ratelimit.py index 1d718a74..e8dd2913 100644 --- a/pkg/qqbot/ratelimit.py +++ b/pkg/qqbot/ratelimit.py @@ -3,6 +3,9 @@ import logging import threading +from ..utils import context + + __crt_minute_usage__ = {} """当前分钟每个会话的对话次数""" @@ -12,16 +15,16 @@ def get_limitation(session_name: str) -> int: """获取会话的限制次数""" - import config + config = context.get_config_manager().data - if type(config.rate_limitation) == dict: + if type(config['rate_limitation']) == dict: # 如果被指定了 - if session_name in config.rate_limitation: - return config.rate_limitation[session_name] + if session_name in config['rate_limitation']: + return config['rate_limitation'][session_name] else: - return config.rate_limitation["default"] - elif type(config.rate_limitation) == int: - return config.rate_limitation + return config['rate_limitation']["default"] + elif type(config['rate_limitation']) == int: + return config['rate_limitation'] def add_usage(session_name: str): diff --git a/pkg/qqbot/sources/nakuru.py b/pkg/qqbot/sources/nakuru.py index 51e5e41b..d60baf5c 100644 --- a/pkg/qqbot/sources/nakuru.py +++ b/pkg/qqbot/sources/nakuru.py @@ -10,6 +10,7 @@ from .. 
import adapter as adapter_model from ...qqbot import blob +from ...utils import context class NakuruProjectMessageConverter(adapter_model.MessageConverter): @@ -172,12 +173,14 @@ def __init__(self, cfg: dict): self.listener_list = [] # nakuru库有bug,这个接口没法带access_token,会失败 # 所以目前自行发请求 - import config + + config = context.get_config_manager().data + import requests resp = requests.get( - url="http://{}:{}/get_login_info".format(config.nakuru_config['host'], config.nakuru_config['http_port']), + url="http://{}:{}/get_login_info".format(config['nakuru_config']['host'], config['nakuru_config']['http_port']), headers={ - 'Authorization': "Bearer " + config.nakuru_config['token'] if 'token' in config.nakuru_config else "" + 'Authorization': "Bearer " + config['nakuru_config']['token'] if 'token' in config['nakuru_config']else "" }, timeout=5 ) @@ -270,7 +273,7 @@ def register_listener( logging.debug("注册监听器: " + str(event_type) + " -> " + str(callback)) # 包装函数 - async def listener_wrapper(app: nakuru.CQHTTP, source: self.event_converter.yiri2target(event_type)): + async def listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.event_converter.yiri2target(event_type)): callback(self.event_converter.target2yiri(source)) # 将包装函数和原函数的对应关系存入列表 diff --git a/pkg/utils/log.py b/pkg/utils/log.py index b976583b..e45f97e6 100644 --- a/pkg/utils/log.py +++ b/pkg/utils/log.py @@ -3,6 +3,8 @@ import logging import shutil +from . import context + log_file_name = "qchatgpt.log" @@ -36,7 +38,6 @@ def init_runtime_log_file(): def reset_logging(): global log_file_name - import config import pkg.utils.context import colorlog @@ -46,7 +47,11 @@ def reset_logging(): for handler in logging.getLogger().handlers: logging.getLogger().removeHandler(handler) - logging.basicConfig(level=config.logging_level, # 设置日志输出格式 + config_mgr = context.get_config_manager() + + logging_level = logging.INFO if config_mgr is None else config_mgr.data['logging_level'] + + logging.basicConfig(level=logging_level, # 设置日志输出格式 filename=log_file_name, # log日志输出的文件位置和文件名 format="[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s", # 日志输出的格式 @@ -54,7 +59,7 @@ def reset_logging(): datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式 ) sh = logging.StreamHandler() - sh.setLevel(config.logging_level) + sh.setLevel(logging_level) sh.setFormatter(colorlog.ColoredFormatter( fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : " "%(message)s", diff --git a/pkg/utils/network.py b/pkg/utils/network.py index 72950658..a4498854 100644 --- a/pkg/utils/network.py +++ b/pkg/utils/network.py @@ -1,9 +1,11 @@ +from . 
import context + def wrapper_proxies() -> dict: """获取代理""" - import config + config = context.get_config_manager().data return { - "http": config.openai_config['proxy'], - "https": config.openai_config['proxy'] - } if 'proxy' in config.openai_config and (config.openai_config['proxy'] is not None) else None + "http": config['openai_config']['proxy'], + "https": config['openai_config']['proxy'] + } if 'proxy' in config['openai_config'] and (config['openai_config']['proxy'] is not None) else None diff --git a/pkg/utils/text2img.py b/pkg/utils/text2img.py index 1da2afca..4477ef2e 100644 --- a/pkg/utils/text2img.py +++ b/pkg/utils/text2img.py @@ -1,37 +1,42 @@ import logging import re import os -import config import traceback from PIL import Image, ImageDraw, ImageFont -text_render_font: ImageFont = None +from ..utils import context -if config.blob_message_strategy == "image": # 仅在启用了image时才加载字体 - use_font = config.font_path - try: - # 检查是否存在 - if not os.path.exists(use_font): - # 若是windows系统,使用微软雅黑 - if os.name == "nt": - use_font = "C:/Windows/Fonts/msyh.ttc" - if not os.path.exists(use_font): - logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") - config.blob_message_strategy = "forward" +text_render_font: ImageFont = None + +def initialize(): + config = context.get_config_manager().data + + if config['blob_message_strategy'] == "image": # 仅在启用了image时才加载字体 + use_font = config['font_path'] + try: + + # 检查是否存在 + if not os.path.exists(use_font): + # 若是windows系统,使用微软雅黑 + if os.name == "nt": + use_font = "C:/Windows/Fonts/msyh.ttc" + if not os.path.exists(use_font): + logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") + config['blob_message_strategy'] = "forward" + else: + logging.info("使用Windows自带字体:" + use_font) + text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") else: - logging.info("使用Windows自带字体:" + use_font) - text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") + logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") + config['blob_message_strategy'] = "forward" else: - logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") - config.blob_message_strategy = "forward" - else: - text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") - except: - traceback.print_exc() - logging.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) - config.blob_message_strategy = "forward" + text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") + except: + traceback.print_exc() + logging.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) + config['blob_message_strategy'] = "forward" def indexNumber(path=''):
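
Every call site touched by this patch follows the same read pattern: fetch the shared config manager from pkg.utils.context and index its data dict with the old config.py attribute name as the key. A minimal sketch of that pattern, assuming the manager has been registered during startup before any handler runs (as main.py does); the helper name report_enabled is only for illustration and does not appear in the diff:

    # old style:            import config; config.report_usage
    # new style (this PR):  context.get_config_manager().data['report_usage']
    from pkg.utils import context

    def report_enabled() -> bool:
        # data is a plain dict whose keys mirror the old config.py attribute names
        config = context.get_config_manager().data
        return bool(config['report_usage'])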