From a82bfa8a56b9e7d5317b183c65ff9cc1a8552765 Mon Sep 17 00:00:00 2001 From: Junyan Qin <1010553892@qq.com> Date: Fri, 8 Mar 2024 11:38:26 +0000 Subject: [PATCH 1/9] =?UTF-8?q?perf:=20=E4=B8=BA=E5=91=BD=E4=BB=A4?= =?UTF-8?q?=E8=A3=85=E9=A5=B0=E5=99=A8=E6=B7=BB=E5=8A=A0=E6=96=AD=E8=A8=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/operator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkg/command/operator.py b/pkg/command/operator.py index a666f2c3..c5b0615b 100644 --- a/pkg/command/operator.py +++ b/pkg/command/operator.py @@ -20,6 +20,8 @@ def operator_class( parent_class: typing.Type[CommandOperator] = None ) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]: + assert issubclass(cls, CommandOperator) + cls.name = name cls.alias = alias cls.help = help From 7f554fd8625fc75037c8203d955c30d1c4f6948e Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Fri, 8 Mar 2024 19:56:57 +0800 Subject: [PATCH 2/9] =?UTF-8?q?feat:=20command=E6=94=AF=E6=8C=81=E6=89=A9?= =?UTF-8?q?=E5=B1=95=E5=91=BD=E4=BB=A4=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/operator.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/pkg/command/operator.py b/pkg/command/operator.py index c5b0615b..641a8cf5 100644 --- a/pkg/command/operator.py +++ b/pkg/command/operator.py @@ -8,17 +8,31 @@ preregistered_operators: list[typing.Type[CommandOperator]] = [] -"""预注册算子列表。在初始化时,所有算子类会被注册到此列表中。""" +"""预注册命令算子列表。在初始化时,所有算子类会被注册到此列表中。""" def operator_class( name: str, - help: str, + help: str = "", usage: str = None, alias: list[str] = [], privilege: int=1, # 1为普通用户,2为管理员 parent_class: typing.Type[CommandOperator] = None ) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: + """命令类装饰器 + + Args: + name (str): 名称 + help (str, optional): 帮助信息. Defaults to "". + usage (str, optional): 使用说明. Defaults to None. + alias (list[str], optional): 别名. Defaults to []. + privilege (int, optional): 权限,1为普通用户可用,2为仅管理员可用. Defaults to 1. + parent_class (typing.Type[CommandOperator], optional): 父节点,若为None则为顶级命令. Defaults to None. + + Returns: + typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 注册后的命令类 + """ + def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]: assert issubclass(cls, CommandOperator) From 22cb8a6a0678bf441eba9cdb6dd8f59a6b097a4c Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Fri, 8 Mar 2024 20:22:06 +0800 Subject: [PATCH 3/9] =?UTF-8?q?feat:=20=E5=86=85=E5=AE=B9=E8=BF=87?= =?UTF-8?q?=E6=BB=A4=E5=99=A8=E7=9A=84=E5=8F=AF=E6=89=A9=E5=B1=95=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/operator.py | 2 +- pkg/pipeline/cntfilter/cntfilter.py | 21 +++++++++---- pkg/pipeline/cntfilter/filter.py | 30 +++++++++++++++++++ .../cntfilter/filters/baiduexamine.py | 1 + pkg/pipeline/cntfilter/filters/banwords.py | 1 + pkg/pipeline/cntfilter/filters/cntignore.py | 1 + pkg/platform/manager.py | 19 ------------ pkg/platform/sources/nakuru.py | 2 ++ pkg/platform/sources/qqbotpy.py | 2 ++ 9 files changed, 53 insertions(+), 26 deletions(-) diff --git a/pkg/command/operator.py b/pkg/command/operator.py index 641a8cf5..307e9fbe 100644 --- a/pkg/command/operator.py +++ b/pkg/command/operator.py @@ -30,7 +30,7 @@ def operator_class( parent_class (typing.Type[CommandOperator], optional): 父节点,若为None则为顶级命令. Defaults to None. Returns: - typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 注册后的命令类 + typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 装饰器 """ def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]: diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 92157bdd..5e2aa4d2 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -7,7 +7,7 @@ from .. import stage, entities, stagemgr from ...core import entities as core_entities from ...config import manager as cfg_mgr -from . import filter, entities as filter_entities +from . import filter as filter_model, entities as filter_entities from .filters import cntignore, banwords, baiduexamine @@ -16,20 +16,29 @@ class ContentFilterStage(stage.PipelineStage): """内容过滤阶段""" - filter_chain: list[filter.ContentFilter] + filter_chain: list[filter_model.ContentFilter] def __init__(self, ap: app.Application): self.filter_chain = [] super().__init__(ap) async def initialize(self): - self.filter_chain.append(cntignore.ContentIgnore(self.ap)) + + filters_required = [ + "ContentIgnore" + ] if self.ap.pipeline_cfg.data['check-sensitive-words']: - self.filter_chain.append(banwords.BanWordFilter(self.ap)) - + filters_required.append("BanWordFilter") + if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']: - self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap)) + filters_required.append("BaiduCloudExamine") + + for filter in filter_model.preregistered_filters: + if filter.name in filters_required: + self.filter_chain.append( + filter(self.ap) + ) for filter in self.filter_chain: await filter.initialize() diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index 57792145..23471392 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -1,12 +1,42 @@ # 内容过滤器的抽象类 from __future__ import annotations import abc +import typing from ...core import app from . import entities +preregistered_filters: list[typing.Type[ContentFilter]] = [] + + +def filter_class( + name: str +) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: + """内容过滤器类装饰器 + + Args: + name (str): 过滤器名称 + + Returns: + typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器 + """ + def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]: + assert issubclass(cls, ContentFilter) + + cls.name = name + + preregistered_filters.append(cls) + + return cls + + return decorator + + class ContentFilter(metaclass=abc.ABCMeta): + """内容过滤器抽象类""" + + name: str ap: app.Application diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py index f72fe960..faa4bb6b 100644 --- a/pkg/pipeline/cntfilter/filters/baiduexamine.py +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -10,6 +10,7 @@ BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" +@filter_model.filter_class("BaiduCloudExamine") class BaiduCloudExamine(filter_model.ContentFilter): """百度云内容审核""" diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index 587f81c3..c94374c8 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -6,6 +6,7 @@ from ....config import manager as cfg_mgr +@filter_model.filter_class("BanWordFilter") class BanWordFilter(filter_model.ContentFilter): """根据内容禁言""" diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py index 92fe94e8..baafeef0 100644 --- a/pkg/pipeline/cntfilter/filters/cntignore.py +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -5,6 +5,7 @@ from .. import filter as filter_model +@filter_model.filter_class("ContentIgnore") class ContentIgnore(filter_model.ContentFilter): """根据内容忽略消息""" diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 3d73c198..7b40f2ab 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -163,25 +163,6 @@ async def send(self, event, msg, adapter: msadapter.MessageSourceAdapter, check_ quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False ) - # 通知系统管理员 - # TODO delete - # async def notify_admin(self, message: str): - # await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))])) - - # async def notify_admin_message_chain(self, message: mirai.MessageChain): - # if self.ap.system_cfg.data['admin-sessions'] != []: - - # admin_list = [] - # for admin in self.ap.system_cfg.data['admin-sessions']: - # admin_list.append(admin) - - # for adm in admin_list: - # self.adapter.send_message( - # adm.split("_")[0], - # adm.split("_")[1], - # message - # ) - async def run(self): try: tasks = [] diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/nakuru.py index 0a419a06..0b3b8c09 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -24,6 +24,8 @@ def yiri2target(message_chain: mirai.MessageChain) -> list: msg_list = message_chain.__root__ elif type(message_chain) is list: msg_list = message_chain + elif type(message_chain) is str: + msg_list = [mirai.Plain(message_chain)] else: raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain))) diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/qqbotpy.py index 6d74d0ea..313249a0 100644 --- a/pkg/platform/sources/qqbotpy.py +++ b/pkg/platform/sources/qqbotpy.py @@ -89,6 +89,8 @@ def yiri2target(message_chain: mirai.MessageChain): msg_list = message_chain.__root__ elif type(message_chain) is list: msg_list = message_chain + elif type(message_chain) is str: + msg_list = [mirai.Plain(text=message_chain)] else: raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain))) From 4347ddd42ac6417237037fe98c05f2a20d8774f9 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Fri, 8 Mar 2024 20:31:22 +0800 Subject: [PATCH 4/9] =?UTF-8?q?feat:=20=E9=95=BF=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E5=A4=84=E7=90=86=E7=AD=96=E7=95=A5=E5=8F=AF=E6=89=A9=E5=B1=95?= =?UTF-8?q?=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/pipeline/cntfilter/cntfilter.py | 6 ++--- .../cntfilter/filters/baiduexamine.py | 2 +- pkg/pipeline/cntfilter/filters/banwords.py | 2 +- pkg/pipeline/cntfilter/filters/cntignore.py | 2 +- pkg/pipeline/longtext/longtext.py | 13 +++++++---- pkg/pipeline/longtext/strategies/forward.py | 1 + pkg/pipeline/longtext/strategies/image.py | 1 + pkg/pipeline/longtext/strategy.py | 23 +++++++++++++++++++ 8 files changed, 39 insertions(+), 11 deletions(-) diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 5e2aa4d2..fee2cd3f 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -25,14 +25,14 @@ def __init__(self, ap: app.Application): async def initialize(self): filters_required = [ - "ContentIgnore" + "content-filter" ] if self.ap.pipeline_cfg.data['check-sensitive-words']: - filters_required.append("BanWordFilter") + filters_required.append("ban-word-filter") if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']: - filters_required.append("BaiduCloudExamine") + filters_required.append("baidu-cloud-examine") for filter in filter_model.preregistered_filters: if filter.name in filters_required: diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py index faa4bb6b..8c5b77cd 100644 --- a/pkg/pipeline/cntfilter/filters/baiduexamine.py +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -10,7 +10,7 @@ BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" -@filter_model.filter_class("BaiduCloudExamine") +@filter_model.filter_class("baidu-cloud-examine") class BaiduCloudExamine(filter_model.ContentFilter): """百度云内容审核""" diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index c94374c8..9391971c 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -6,7 +6,7 @@ from ....config import manager as cfg_mgr -@filter_model.filter_class("BanWordFilter") +@filter_model.filter_class("ban-word-filter") class BanWordFilter(filter_model.ContentFilter): """根据内容禁言""" diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py index baafeef0..781f6397 100644 --- a/pkg/pipeline/cntfilter/filters/cntignore.py +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -5,7 +5,7 @@ from .. import filter as filter_model -@filter_model.filter_class("ContentIgnore") +@filter_model.filter_class("content-ignore") class ContentIgnore(filter_model.ContentFilter): """根据内容忽略消息""" diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index 2962ae28..2095845d 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -45,11 +45,14 @@ async def initialize(self): self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" - - if config['strategy'] == 'image': - self.strategy_impl = image.Text2ImageStrategy(self.ap) - elif config['strategy'] == 'forward': - self.strategy_impl = forward.ForwardComponentStrategy(self.ap) + + for strategy_cls in strategy.preregistered_strategies: + if strategy_cls.name == config['strategy']: + self.strategy_impl = strategy_cls(self.ap) + break + else: + raise ValueError(f"未找到名为 {config['strategy']} 的长消息处理策略") + await self.strategy_impl.initialize() async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: diff --git a/pkg/pipeline/longtext/strategies/forward.py b/pkg/pipeline/longtext/strategies/forward.py index cfab49d9..4a790313 100644 --- a/pkg/pipeline/longtext/strategies/forward.py +++ b/pkg/pipeline/longtext/strategies/forward.py @@ -36,6 +36,7 @@ def __str__(self): return '[聊天记录]' +@strategy_model.strategy_class("forward") class ForwardComponentStrategy(strategy_model.LongTextStrategy): async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py index af34f4e6..f96f03c5 100644 --- a/pkg/pipeline/longtext/strategies/image.py +++ b/pkg/pipeline/longtext/strategies/image.py @@ -15,6 +15,7 @@ from ....core import entities as core_entities +@strategy_model.strategy_class("image") class Text2ImageStrategy(strategy_model.LongTextStrategy): text_render_font: ImageFont.FreeTypeFont diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py index a1f8a94f..296c5b4c 100644 --- a/pkg/pipeline/longtext/strategy.py +++ b/pkg/pipeline/longtext/strategy.py @@ -9,7 +9,30 @@ from ...core import entities as core_entities +preregistered_strategies: list[typing.Type[LongTextStrategy]] = [] + + +def strategy_class( + name: str +) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: + def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]: + assert issubclass(cls, LongTextStrategy) + + cls.name = name + + preregistered_strategies.append(cls) + + return cls + + return decorator + + class LongTextStrategy(metaclass=abc.ABCMeta): + """长文本处理策略抽象类 + """ + + name: str + ap: app.Application def __init__(self, ap: app.Application): From a398c6f311f5120e539a665034f533fcf7a8e34e Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Fri, 8 Mar 2024 20:40:54 +0800 Subject: [PATCH 5/9] =?UTF-8?q?feat:=20=E6=B6=88=E6=81=AF=E5=B9=B3?= =?UTF-8?q?=E5=8F=B0=E9=80=82=E9=85=8D=E5=99=A8=E5=8F=AF=E6=89=A9=E5=B1=95?= =?UTF-8?q?=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/platform/adapter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index 38c31fe2..5ce1db18 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -22,6 +22,8 @@ def decorator(cls: typing.Type[MessageSourceAdapter]) -> typing.Type[MessageSour class MessageSourceAdapter(metaclass=abc.ABCMeta): + """消息平台适配器基类""" + name: str bot_account_id: int @@ -40,7 +42,7 @@ async def send_message( target_id: str, message: mirai.MessageChain ): - """发送消息 + """主动发送消息 Args: target_type (str): 目标类型,`person`或`group` From 1d963d0f0c3335a36a78b6c66bde1cf010a5ad29 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Tue, 12 Mar 2024 16:04:11 +0800 Subject: [PATCH 6/9] =?UTF-8?q?feat:=20=E4=B8=8D=E5=86=8D=E9=A2=84?= =?UTF-8?q?=E5=85=88=E8=AE=A1=E7=AE=97=E5=89=8D=E6=96=87token=E6=95=B0?= =?UTF-8?q?=E8=80=8C=E6=98=AF=E5=9C=A8=E6=8A=A5=E9=94=99=E6=97=B6=E6=8F=90?= =?UTF-8?q?=E9=86=92=E7=94=A8=E6=88=B7=E9=87=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/app.py | 2 +- pkg/core/entities.py | 2 +- pkg/core/stages/build_app.py | 2 +- pkg/pipeline/preproc/preproc.py | 22 -------- pkg/pipeline/process/handlers/chat.py | 2 - .../{requester => modelmgr}/__init__.py | 0 pkg/provider/{requester => modelmgr}/api.py | 14 +++++ .../{requester => modelmgr}/apis/__init__.py | 0 .../{requester => modelmgr}/apis/chatcmpl.py | 6 ++- .../{requester => modelmgr}/entities.py | 6 +-- .../{requester => modelmgr}/errors.py | 0 .../{requester => modelmgr}/modelmgr.py | 53 +------------------ pkg/provider/{requester => modelmgr}/token.py | 0 pkg/provider/requester/tokenizer.py | 30 ----------- pkg/provider/requester/tokenizers/__init__.py | 0 pkg/provider/requester/tokenizers/tiktoken.py | 30 ----------- 16 files changed, 25 insertions(+), 144 deletions(-) rename pkg/provider/{requester => modelmgr}/__init__.py (100%) rename pkg/provider/{requester => modelmgr}/api.py (65%) rename pkg/provider/{requester => modelmgr}/apis/__init__.py (100%) rename pkg/provider/{requester => modelmgr}/apis/chatcmpl.py (94%) rename pkg/provider/{requester => modelmgr}/entities.py (76%) rename pkg/provider/{requester => modelmgr}/errors.py (100%) rename pkg/provider/{requester => modelmgr}/modelmgr.py (76%) rename pkg/provider/{requester => modelmgr}/token.py (100%) delete mode 100644 pkg/provider/requester/tokenizer.py delete mode 100644 pkg/provider/requester/tokenizers/__init__.py delete mode 100644 pkg/provider/requester/tokenizers/tiktoken.py diff --git a/pkg/core/app.py b/pkg/core/app.py index ed035e5d..0d726a44 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -6,7 +6,7 @@ from ..platform import manager as im_mgr from ..provider.session import sessionmgr as llm_session_mgr -from ..provider.requester import modelmgr as llm_model_mgr +from ..provider.modelmgr import modelmgr as llm_model_mgr from ..provider.sysprompt import sysprompt as llm_prompt_mgr from ..provider.tools import toolmgr as llm_tool_mgr from ..config import manager as config_mgr diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 78bcf1fe..8bf1ff2e 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -9,7 +9,7 @@ import mirai from ..provider import entities as llm_entities -from ..provider.requester import entities +from ..provider.modelmgr import entities from ..provider.sysprompt import entities as sysprompt_entities from ..provider.tools import entities as tools_entities from ..platform import adapter as msadapter diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index a6c0fe3c..09b4342b 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -10,7 +10,7 @@ from ...plugin import manager as plugin_mgr from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr -from ...provider.requester import modelmgr as llm_model_mgr +from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.sysprompt import sysprompt as llm_prompt_mgr from ...provider.tools import toolmgr as llm_tool_mgr from ...platform import manager as im_mgr diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index c0eb92d6..cedc030f 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -51,28 +51,6 @@ async def process( query.prompt.messages = event_ctx.event.default_prompt query.messages = event_ctx.event.prompt - # 根据模型max_tokens剪裁 - max_tokens = min(query.use_model.max_tokens, self.ap.pipeline_cfg.data['submit-messages-tokens']) - - test_messages = query.prompt.messages + query.messages + [query.user_message] - - while await query.use_model.tokenizer.count_token(test_messages, query.use_model) > max_tokens: - # 前文都pop完了,还是大于max_tokens,由于prompt和user_messages不能删减,报错 - if len(query.prompt.messages) == 0: - return entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, - new_query=query, - user_notice='输入内容过长,请减少情景预设或者输入内容长度', - console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 submit-messages-tokens 项(但不能超过所用模型最大tokens数)' - ) - - query.messages.pop(0) # pop第一个肯定是role=user的 - # 继续pop到第二个role=user前一个 - while len(query.messages) > 0 and query.messages[0].role != 'user': - query.messages.pop(0) - - test_messages = query.prompt.messages + query.messages + [query.user_message] - return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index b3e8fa18..33dedb04 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -21,8 +21,6 @@ async def handle( ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: """处理 """ - # 取session - # 取conversation # 调API # 生成器 diff --git a/pkg/provider/requester/__init__.py b/pkg/provider/modelmgr/__init__.py similarity index 100% rename from pkg/provider/requester/__init__.py rename to pkg/provider/modelmgr/__init__.py diff --git a/pkg/provider/requester/api.py b/pkg/provider/modelmgr/api.py similarity index 65% rename from pkg/provider/requester/api.py rename to pkg/provider/modelmgr/api.py index 88ba78cd..da362468 100644 --- a/pkg/provider/requester/api.py +++ b/pkg/provider/modelmgr/api.py @@ -7,9 +7,23 @@ from ...core import entities as core_entities from .. import entities as llm_entities + +preregistered_requesters: list[typing.Type[LLMAPIRequester]] = [] + +def requester_class(name: str): + + def decorator(cls: typing.Type[LLMAPIRequester]) -> typing.Type[LLMAPIRequester]: + cls.name = name + preregistered_requesters.append(cls) + return cls + + return decorator + + class LLMAPIRequester(metaclass=abc.ABCMeta): """LLM API请求器 """ + name: str = None ap: app.Application diff --git a/pkg/provider/requester/apis/__init__.py b/pkg/provider/modelmgr/apis/__init__.py similarity index 100% rename from pkg/provider/requester/apis/__init__.py rename to pkg/provider/modelmgr/apis/__init__.py diff --git a/pkg/provider/requester/apis/chatcmpl.py b/pkg/provider/modelmgr/apis/chatcmpl.py similarity index 94% rename from pkg/provider/requester/apis/chatcmpl.py rename to pkg/provider/modelmgr/apis/chatcmpl.py index 2d520017..4965acf7 100644 --- a/pkg/provider/requester/apis/chatcmpl.py +++ b/pkg/provider/modelmgr/apis/chatcmpl.py @@ -17,6 +17,7 @@ from ...tools import entities as tools_entities +@api.requester_class("openai-chat-completion") class OpenAIChatCompletion(api.LLMAPIRequester): """OpenAI ChatCompletion API 请求器""" @@ -133,7 +134,10 @@ async def request(self, query: core_entities.Query) -> AsyncGenerator[Message, N except asyncio.TimeoutError: raise errors.RequesterError('请求超时') except openai.BadRequestError as e: - raise errors.RequesterError(f'请求错误: {e.message}') + if 'context_length_exceeded' in e.message: + raise errors.RequesterError(f'上文过长,请重置会话: {e.message}') + else: + raise errors.RequesterError(f'请求参数错误: {e.message}') except openai.AuthenticationError as e: raise errors.RequesterError(f'无效的 api-key: {e.message}') except openai.NotFoundError as e: diff --git a/pkg/provider/requester/entities.py b/pkg/provider/modelmgr/entities.py similarity index 76% rename from pkg/provider/requester/entities.py rename to pkg/provider/modelmgr/entities.py index d4c51d6f..277f125a 100644 --- a/pkg/provider/requester/entities.py +++ b/pkg/provider/modelmgr/entities.py @@ -5,7 +5,7 @@ import pydantic from . import api -from . import token, tokenizer +from . import token class LLMModelInfo(pydantic.BaseModel): @@ -19,11 +19,7 @@ class LLMModelInfo(pydantic.BaseModel): requester: api.LLMAPIRequester - tokenizer: 'tokenizer.LLMTokenizer' - tool_call_supported: typing.Optional[bool] = False - max_tokens: typing.Optional[int] = 2048 - class Config: arbitrary_types_allowed = True diff --git a/pkg/provider/requester/errors.py b/pkg/provider/modelmgr/errors.py similarity index 100% rename from pkg/provider/requester/errors.py rename to pkg/provider/modelmgr/errors.py diff --git a/pkg/provider/requester/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py similarity index 76% rename from pkg/provider/requester/modelmgr.py rename to pkg/provider/modelmgr/modelmgr.py index e1a48bc2..a91c3110 100644 --- a/pkg/provider/requester/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -3,9 +3,8 @@ from . import entities from ...core import app -from .apis import chatcmpl from . import token -from .tokenizers import tiktoken +from .apis import chatcmpl class ModelManager: @@ -30,9 +29,7 @@ async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: async def initialize(self): openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap) await openai_chat_completion.initialize() - openai_token_mgr = token.TokenManager(self.ap, list(self.ap.provider_cfg.data['openai-config']['api-keys'])) - - tiktoken_tokenizer = tiktoken.Tiktoken(self.ap) + openai_token_mgr = token.TokenManager("openai", list(self.ap.provider_cfg.data['openai-config']['api-keys'])) model_list = [ entities.LLMModelInfo( @@ -40,48 +37,36 @@ async def initialize(self): token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=4096 ), entities.LLMModelInfo( name="gpt-3.5-turbo-1106", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=16385 ), entities.LLMModelInfo( name="gpt-3.5-turbo-16k", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=16385 ), entities.LLMModelInfo( name="gpt-3.5-turbo-0613", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=4096 ), entities.LLMModelInfo( name="gpt-3.5-turbo-16k-0613", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=16385 ), entities.LLMModelInfo( name="gpt-3.5-turbo-0301", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=4096 ) ] @@ -93,64 +78,48 @@ async def initialize(self): token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="gpt-4-turbo-preview", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="gpt-4-1106-preview", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="gpt-4-vision-preview", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="gpt-4", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=8192 ), entities.LLMModelInfo( name="gpt-4-0613", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=8192 ), entities.LLMModelInfo( name="gpt-4-32k", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=32768 ), entities.LLMModelInfo( name="gpt-4-32k-0613", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=32768 ) ] @@ -163,8 +132,6 @@ async def initialize(self): token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=8192 ), entities.LLMModelInfo( name="OneAPI/chatglm_pro", @@ -172,8 +139,6 @@ async def initialize(self): token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="OneAPI/chatglm_std", @@ -181,8 +146,6 @@ async def initialize(self): token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="OneAPI/chatglm_lite", @@ -190,8 +153,6 @@ async def initialize(self): token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="OneAPI/qwen-v1", @@ -199,8 +160,6 @@ async def initialize(self): token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=6000 ), entities.LLMModelInfo( name="OneAPI/qwen-plus-v1", @@ -208,8 +167,6 @@ async def initialize(self): token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=30000 ), entities.LLMModelInfo( name="OneAPI/ERNIE-Bot", @@ -217,8 +174,6 @@ async def initialize(self): token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=2000 ), entities.LLMModelInfo( name="OneAPI/ERNIE-Bot-turbo", @@ -226,8 +181,6 @@ async def initialize(self): token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=7000 ), entities.LLMModelInfo( name="OneAPI/gemini-pro", @@ -235,8 +188,6 @@ async def initialize(self): token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=30720 ), ] diff --git a/pkg/provider/requester/token.py b/pkg/provider/modelmgr/token.py similarity index 100% rename from pkg/provider/requester/token.py rename to pkg/provider/modelmgr/token.py diff --git a/pkg/provider/requester/tokenizer.py b/pkg/provider/requester/tokenizer.py deleted file mode 100644 index cdd91470..00000000 --- a/pkg/provider/requester/tokenizer.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -import abc -import typing - -from ...core import app -from .. import entities as llm_entities -from . import entities - - -class LLMTokenizer(metaclass=abc.ABCMeta): - """LLM分词器抽象类""" - - ap: app.Application - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - """初始化分词器 - """ - pass - - @abc.abstractmethod - async def count_token( - self, - messages: list[llm_entities.Message], - model: entities.LLMModelInfo - ) -> int: - pass diff --git a/pkg/provider/requester/tokenizers/__init__.py b/pkg/provider/requester/tokenizers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/provider/requester/tokenizers/tiktoken.py b/pkg/provider/requester/tokenizers/tiktoken.py deleted file mode 100644 index 24d2d8b6..00000000 --- a/pkg/provider/requester/tokenizers/tiktoken.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -import tiktoken - -from .. import tokenizer -from ... import entities as llm_entities -from .. import entities - - -class Tiktoken(tokenizer.LLMTokenizer): - """TikToken分词器 - """ - - async def count_token( - self, - messages: list[llm_entities.Message], - model: entities.LLMModelInfo - ) -> int: - try: - encoding = tiktoken.encoding_for_model(model.name) - except KeyError: - # print("Warning: model not found. Using cl100k_base encoding.") - encoding = tiktoken.get_encoding("cl100k_base") - - num_tokens = 0 - for message in messages: - num_tokens += len(encoding.encode(message.role)) - num_tokens += len(encoding.encode(message.content if message.content is not None else '')) - num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> - return num_tokens From 8c6ce1f030ac7d1376d22fb04c9ea7fa586ed943 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Tue, 12 Mar 2024 23:34:13 +0800 Subject: [PATCH 7/9] =?UTF-8?q?feat:=20=E7=BE=A4=E5=93=8D=E5=BA=94?= =?UTF-8?q?=E8=A7=84=E5=88=99=E7=9A=84=E6=89=A9=E5=B1=95=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/pipeline/resprule/resprule.py | 16 +++++++--------- pkg/pipeline/resprule/rule.py | 12 ++++++++++++ pkg/pipeline/resprule/rules/atbot.py | 1 + pkg/pipeline/resprule/rules/prefix.py | 1 + pkg/pipeline/resprule/rules/random.py | 1 + pkg/pipeline/resprule/rules/regexp.py | 1 + 6 files changed, 23 insertions(+), 9 deletions(-) diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py index 8f418729..d795d056 100644 --- a/pkg/pipeline/resprule/resprule.py +++ b/pkg/pipeline/resprule/resprule.py @@ -21,15 +21,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): async def initialize(self): """初始化检查器 """ - self.rule_matchers = [ - atbot.AtBotRule(self.ap), - prefix.PrefixRule(self.ap), - regexp.RegExpRule(self.ap), - random.RandomRespRule(self.ap), - ] - - for rule_matcher in self.rule_matchers: - await rule_matcher.initialize() + + self.rule_matchers = [] + + for rule_matcher in rule.preregisetered_rules: + rule_inst = rule_matcher(self.ap) + await rule_inst.initialize() + self.rule_matchers.append(rule_inst) async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: diff --git a/pkg/pipeline/resprule/rule.py b/pkg/pipeline/resprule/rule.py index cde9ec3d..bfab4152 100644 --- a/pkg/pipeline/resprule/rule.py +++ b/pkg/pipeline/resprule/rule.py @@ -1,5 +1,6 @@ from __future__ import annotations import abc +import typing import mirai @@ -7,9 +8,20 @@ from . import entities +preregisetered_rules: list[typing.Type[GroupRespondRule]] = [] + +def rule_class(name: str): + def decorator(cls: typing.Type[GroupRespondRule]) -> typing.Type[GroupRespondRule]: + cls.name = name + preregisetered_rules.append(cls) + return cls + return decorator + + class GroupRespondRule(metaclass=abc.ABCMeta): """群组响应规则的抽象类 """ + name: str ap: app.Application diff --git a/pkg/pipeline/resprule/rules/atbot.py b/pkg/pipeline/resprule/rules/atbot.py index 692bee72..293cfd96 100644 --- a/pkg/pipeline/resprule/rules/atbot.py +++ b/pkg/pipeline/resprule/rules/atbot.py @@ -7,6 +7,7 @@ from ....core import entities as core_entities +@rule_model.rule_class("at-bot") class AtBotRule(rule_model.GroupRespondRule): async def match( diff --git a/pkg/pipeline/resprule/rules/prefix.py b/pkg/pipeline/resprule/rules/prefix.py index 1b61c138..99dcd4f9 100644 --- a/pkg/pipeline/resprule/rules/prefix.py +++ b/pkg/pipeline/resprule/rules/prefix.py @@ -5,6 +5,7 @@ from ....core import entities as core_entities +@rule_model.rule_class("prefix") class PrefixRule(rule_model.GroupRespondRule): async def match( diff --git a/pkg/pipeline/resprule/rules/random.py b/pkg/pipeline/resprule/rules/random.py index 185e03ec..80acf6a5 100644 --- a/pkg/pipeline/resprule/rules/random.py +++ b/pkg/pipeline/resprule/rules/random.py @@ -7,6 +7,7 @@ from ....core import entities as core_entities +@rule_model.rule_class("random") class RandomRespRule(rule_model.GroupRespondRule): async def match( diff --git a/pkg/pipeline/resprule/rules/regexp.py b/pkg/pipeline/resprule/rules/regexp.py index 4e39d432..aaa46449 100644 --- a/pkg/pipeline/resprule/rules/regexp.py +++ b/pkg/pipeline/resprule/rules/regexp.py @@ -7,6 +7,7 @@ from ....core import entities as core_entities +@rule_model.rule_class("regexp") class RegExpRule(rule_model.GroupRespondRule): async def match( From b9fa11c0c39236691d018d248d0391a8a1792a07 Mon Sep 17 00:00:00 2001 From: Junyan Qin <1010553892@qq.com> Date: Tue, 12 Mar 2024 16:22:07 +0000 Subject: [PATCH 8/9] =?UTF-8?q?feat:=20prompt=20=E5=8A=A0=E8=BD=BD?= =?UTF-8?q?=E5=99=A8=E7=9A=84=E6=89=A9=E5=B1=95=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/provider/sysprompt/loader.py | 14 ++++++++++++++ pkg/provider/sysprompt/loaders/scenario.py | 1 + pkg/provider/sysprompt/loaders/single.py | 1 + pkg/provider/sysprompt/sysprompt.py | 14 ++++++++------ 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/pkg/provider/sysprompt/loader.py b/pkg/provider/sysprompt/loader.py index ca9e8730..9e0a6144 100644 --- a/pkg/provider/sysprompt/loader.py +++ b/pkg/provider/sysprompt/loader.py @@ -1,13 +1,27 @@ from __future__ import annotations import abc +import typing from ...core import app from . import entities +preregistered_loaders: list[typing.Type[PromptLoader]] = [] + +def loader_class(name: str): + + def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]: + cls.name = name + preregistered_loaders.append(cls) + return cls + + return decorator + + class PromptLoader(metaclass=abc.ABCMeta): """Prompt加载器抽象类 """ + name: str ap: app.Application diff --git a/pkg/provider/sysprompt/loaders/scenario.py b/pkg/provider/sysprompt/loaders/scenario.py index a559ff73..9c19d963 100644 --- a/pkg/provider/sysprompt/loaders/scenario.py +++ b/pkg/provider/sysprompt/loaders/scenario.py @@ -8,6 +8,7 @@ from ....provider import entities as llm_entities +@loader.loader_class("full_scenario") class ScenarioPromptLoader(loader.PromptLoader): """加载scenario目录下的json""" diff --git a/pkg/provider/sysprompt/loaders/single.py b/pkg/provider/sysprompt/loaders/single.py index 57e06ed2..3ac9c262 100644 --- a/pkg/provider/sysprompt/loaders/single.py +++ b/pkg/provider/sysprompt/loaders/single.py @@ -6,6 +6,7 @@ from ....provider import entities as llm_entities +@loader.loader_class("normal") class SingleSystemPromptLoader(loader.PromptLoader): """配置文件中的单条system prompt的prompt加载器 """ diff --git a/pkg/provider/sysprompt/sysprompt.py b/pkg/provider/sysprompt/sysprompt.py index eb89e8ab..61c598ed 100644 --- a/pkg/provider/sysprompt/sysprompt.py +++ b/pkg/provider/sysprompt/sysprompt.py @@ -20,12 +20,14 @@ def __init__(self, ap: app.Application): async def initialize(self): - loader_map = { - "normal": single.SingleSystemPromptLoader, - "full_scenario": scenario.ScenarioPromptLoader - } - - loader_cls = loader_map[self.ap.provider_cfg.data['prompt-mode']] + mode_name = self.ap.provider_cfg.data['prompt-mode'] + + for loader_cls in loader.preregistered_loaders: + if loader_cls.name == mode_name: + loader_cls = loader_cls + break + else: + raise ValueError(f'未知的 Prompt 加载器: {mode_name}') self.loader_inst: loader.PromptLoader = loader_cls(self.ap) From 13393b66242519ab385075264d923ee98bb3df49 Mon Sep 17 00:00:00 2001 From: Junyan Qin <1010553892@qq.com> Date: Tue, 12 Mar 2024 16:31:54 +0000 Subject: [PATCH 9/9] =?UTF-8?q?feat:=20=E9=99=90=E9=80=9F=E7=AE=97?= =?UTF-8?q?=E6=B3=95=E7=9A=84=E6=89=A9=E5=B1=95=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/pipeline/ratelimit/algo.py | 15 +++++++++++++++ pkg/pipeline/ratelimit/algos/fixedwin.py | 1 + pkg/pipeline/ratelimit/ratelimit.py | 14 +++++++++++++- pkg/provider/sysprompt/sysprompt.py | 6 ++++-- 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/pkg/pipeline/ratelimit/algo.py b/pkg/pipeline/ratelimit/algo.py index b6d9ba7b..448ae384 100644 --- a/pkg/pipeline/ratelimit/algo.py +++ b/pkg/pipeline/ratelimit/algo.py @@ -1,11 +1,26 @@ from __future__ import annotations import abc +import typing from ...core import app +preregistered_algos: list[typing.Type[ReteLimitAlgo]] = [] + +def algo_class(name: str): + + def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]: + cls.name = name + preregistered_algos.append(cls) + return cls + + return decorator + + class ReteLimitAlgo(metaclass=abc.ABCMeta): + name: str = None + ap: app.Application def __init__(self, ap: app.Application): diff --git a/pkg/pipeline/ratelimit/algos/fixedwin.py b/pkg/pipeline/ratelimit/algos/fixedwin.py index bb69b0dd..aa380291 100644 --- a/pkg/pipeline/ratelimit/algos/fixedwin.py +++ b/pkg/pipeline/ratelimit/algos/fixedwin.py @@ -19,6 +19,7 @@ def __init__(self): self.records = {} +@algo.algo_class("fixwin") class FixedWindowAlgo(algo.ReteLimitAlgo): containers_lock: asyncio.Lock diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py index a9e29799..f43c8b06 100644 --- a/pkg/pipeline/ratelimit/ratelimit.py +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -16,7 +16,19 @@ class RateLimit(stage.PipelineStage): algo: algo.ReteLimitAlgo async def initialize(self): - self.algo = fixedwin.FixedWindowAlgo(self.ap) + + algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo'] + + algo_class = None + + for algo_cls in algo.preregistered_algos: + if algo_cls.name == algo_name: + algo_class = algo_cls + break + else: + raise ValueError(f'未知的限速算法: {algo_name}') + + self.algo = algo_class(self.ap) await self.algo.initialize() async def process( diff --git a/pkg/provider/sysprompt/sysprompt.py b/pkg/provider/sysprompt/sysprompt.py index 61c598ed..c7695f5a 100644 --- a/pkg/provider/sysprompt/sysprompt.py +++ b/pkg/provider/sysprompt/sysprompt.py @@ -22,14 +22,16 @@ async def initialize(self): mode_name = self.ap.provider_cfg.data['prompt-mode'] + loader_class = None + for loader_cls in loader.preregistered_loaders: if loader_cls.name == mode_name: - loader_cls = loader_cls + loader_class = loader_cls break else: raise ValueError(f'未知的 Prompt 加载器: {mode_name}') - self.loader_inst: loader.PromptLoader = loader_cls(self.ap) + self.loader_inst: loader.PromptLoader = loader_class(self.ap) await self.loader_inst.initialize() await self.loader_inst.load()