Skip to content

Commit

Permalink
Merge pull request #712 from RockChinQ/feat/component-extensibility
Browse files Browse the repository at this point in the history
Feat: 更多组件的可扩展性
  • Loading branch information
RockChinQ authored Mar 12, 2024
2 parents 0ee383b + 13393b6 commit 63303bb
Show file tree
Hide file tree
Showing 43 changed files with 209 additions and 193 deletions.
20 changes: 18 additions & 2 deletions pkg/command/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,34 @@


preregistered_operators: list[typing.Type[CommandOperator]] = []
"""预注册算子列表。在初始化时,所有算子类会被注册到此列表中。"""
"""预注册命令算子列表。在初始化时,所有算子类会被注册到此列表中。"""


def operator_class(
name: str,
help: str,
help: str = "",
usage: str = None,
alias: list[str] = [],
privilege: int=1, # 1为普通用户,2为管理员
parent_class: typing.Type[CommandOperator] = None
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
"""命令类装饰器
Args:
name (str): 名称
help (str, optional): 帮助信息. Defaults to "".
usage (str, optional): 使用说明. Defaults to None.
alias (list[str], optional): 别名. Defaults to [].
privilege (int, optional): 权限,1为普通用户可用,2为仅管理员可用. Defaults to 1.
parent_class (typing.Type[CommandOperator], optional): 父节点,若为None则为顶级命令. Defaults to None.
Returns:
typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 装饰器
"""

def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
assert issubclass(cls, CommandOperator)

cls.name = name
cls.alias = alias
cls.help = help
Expand Down
2 changes: 1 addition & 1 deletion pkg/core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.requester import modelmgr as llm_model_mgr
from ..provider.modelmgr import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr
Expand Down
2 changes: 1 addition & 1 deletion pkg/core/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import mirai

from ..provider import entities as llm_entities
from ..provider.requester import entities
from ..provider.modelmgr import entities
from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
Expand Down
2 changes: 1 addition & 1 deletion pkg/core/stages/build_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ...plugin import manager as plugin_mgr
from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.requester import modelmgr as llm_model_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...platform import manager as im_mgr
Expand Down
21 changes: 15 additions & 6 deletions pkg/pipeline/cntfilter/cntfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter, entities as filter_entities
from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine


Expand All @@ -16,20 +16,29 @@
class ContentFilterStage(stage.PipelineStage):
"""内容过滤阶段"""

filter_chain: list[filter.ContentFilter]
filter_chain: list[filter_model.ContentFilter]

def __init__(self, ap: app.Application):
self.filter_chain = []
super().__init__(ap)

async def initialize(self):
self.filter_chain.append(cntignore.ContentIgnore(self.ap))

filters_required = [
"content-filter"
]

if self.ap.pipeline_cfg.data['check-sensitive-words']:
self.filter_chain.append(banwords.BanWordFilter(self.ap))
filters_required.append("ban-word-filter")

if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
filters_required.append("baidu-cloud-examine")

for filter in filter_model.preregistered_filters:
if filter.name in filters_required:
self.filter_chain.append(
filter(self.ap)
)

for filter in self.filter_chain:
await filter.initialize()
Expand Down
30 changes: 30 additions & 0 deletions pkg/pipeline/cntfilter/filter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,42 @@
# 内容过滤器的抽象类
from __future__ import annotations
import abc
import typing

from ...core import app
from . import entities


preregistered_filters: list[typing.Type[ContentFilter]] = []


def filter_class(
name: str
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
"""内容过滤器类装饰器
Args:
name (str): 过滤器名称
Returns:
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器
"""
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
assert issubclass(cls, ContentFilter)

cls.name = name

preregistered_filters.append(cls)

return cls

return decorator


class ContentFilter(metaclass=abc.ABCMeta):
"""内容过滤器抽象类"""

name: str

ap: app.Application

Expand Down
1 change: 1 addition & 0 deletions pkg/pipeline/cntfilter/filters/baiduexamine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"


@filter_model.filter_class("baidu-cloud-examine")
class BaiduCloudExamine(filter_model.ContentFilter):
"""百度云内容审核"""

Expand Down
1 change: 1 addition & 0 deletions pkg/pipeline/cntfilter/filters/banwords.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ....config import manager as cfg_mgr


@filter_model.filter_class("ban-word-filter")
class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言"""

Expand Down
1 change: 1 addition & 0 deletions pkg/pipeline/cntfilter/filters/cntignore.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .. import filter as filter_model


@filter_model.filter_class("content-ignore")
class ContentIgnore(filter_model.ContentFilter):
"""根据内容忽略消息"""

Expand Down
13 changes: 8 additions & 5 deletions pkg/pipeline/longtext/longtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@ async def initialize(self):
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font))

self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"

if config['strategy'] == 'image':
self.strategy_impl = image.Text2ImageStrategy(self.ap)
elif config['strategy'] == 'forward':
self.strategy_impl = forward.ForwardComponentStrategy(self.ap)

for strategy_cls in strategy.preregistered_strategies:
if strategy_cls.name == config['strategy']:
self.strategy_impl = strategy_cls(self.ap)
break
else:
raise ValueError(f"未找到名为 {config['strategy']} 的长消息处理策略")

await self.strategy_impl.initialize()

async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
Expand Down
1 change: 1 addition & 0 deletions pkg/pipeline/longtext/strategies/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __str__(self):
return '[聊天记录]'


@strategy_model.strategy_class("forward")
class ForwardComponentStrategy(strategy_model.LongTextStrategy):

async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
Expand Down
1 change: 1 addition & 0 deletions pkg/pipeline/longtext/strategies/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ....core import entities as core_entities


@strategy_model.strategy_class("image")
class Text2ImageStrategy(strategy_model.LongTextStrategy):

text_render_font: ImageFont.FreeTypeFont
Expand Down
23 changes: 23 additions & 0 deletions pkg/pipeline/longtext/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,30 @@
from ...core import entities as core_entities


preregistered_strategies: list[typing.Type[LongTextStrategy]] = []


def strategy_class(
name: str
) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]:
def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]:
assert issubclass(cls, LongTextStrategy)

cls.name = name

preregistered_strategies.append(cls)

return cls

return decorator


class LongTextStrategy(metaclass=abc.ABCMeta):
"""长文本处理策略抽象类
"""

name: str

ap: app.Application

def __init__(self, ap: app.Application):
Expand Down
22 changes: 0 additions & 22 deletions pkg/pipeline/preproc/preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,28 +51,6 @@ async def process(
query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt

# 根据模型max_tokens剪裁
max_tokens = min(query.use_model.max_tokens, self.ap.pipeline_cfg.data['submit-messages-tokens'])

test_messages = query.prompt.messages + query.messages + [query.user_message]

while await query.use_model.tokenizer.count_token(test_messages, query.use_model) > max_tokens:
# 前文都pop完了,还是大于max_tokens,由于prompt和user_messages不能删减,报错
if len(query.prompt.messages) == 0:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='输入内容过长,请减少情景预设或者输入内容长度',
console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 submit-messages-tokens 项(但不能超过所用模型最大tokens数)'
)

query.messages.pop(0) # pop第一个肯定是role=user的
# 继续pop到第二个role=user前一个
while len(query.messages) > 0 and query.messages[0].role != 'user':
query.messages.pop(0)

test_messages = query.prompt.messages + query.messages + [query.user_message]

return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
Expand Down
2 changes: 0 additions & 2 deletions pkg/pipeline/process/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ async def handle(
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
# 取session
# 取conversation
# 调API
# 生成器

Expand Down
15 changes: 15 additions & 0 deletions pkg/pipeline/ratelimit/algo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
from __future__ import annotations
import abc
import typing

from ...core import app


preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []

def algo_class(name: str):

def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]:
cls.name = name
preregistered_algos.append(cls)
return cls

return decorator


class ReteLimitAlgo(metaclass=abc.ABCMeta):

name: str = None

ap: app.Application

def __init__(self, ap: app.Application):
Expand Down
1 change: 1 addition & 0 deletions pkg/pipeline/ratelimit/algos/fixedwin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self):
self.records = {}


@algo.algo_class("fixwin")
class FixedWindowAlgo(algo.ReteLimitAlgo):

containers_lock: asyncio.Lock
Expand Down
14 changes: 13 additions & 1 deletion pkg/pipeline/ratelimit/ratelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,19 @@ class RateLimit(stage.PipelineStage):
algo: algo.ReteLimitAlgo

async def initialize(self):
self.algo = fixedwin.FixedWindowAlgo(self.ap)

algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']

algo_class = None

for algo_cls in algo.preregistered_algos:
if algo_cls.name == algo_name:
algo_class = algo_cls
break
else:
raise ValueError(f'未知的限速算法: {algo_name}')

self.algo = algo_class(self.ap)
await self.algo.initialize()

async def process(
Expand Down
16 changes: 7 additions & 9 deletions pkg/pipeline/resprule/resprule.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
async def initialize(self):
"""初始化检查器
"""
self.rule_matchers = [
atbot.AtBotRule(self.ap),
prefix.PrefixRule(self.ap),
regexp.RegExpRule(self.ap),
random.RandomRespRule(self.ap),
]

for rule_matcher in self.rule_matchers:
await rule_matcher.initialize()

self.rule_matchers = []

for rule_matcher in rule.preregisetered_rules:
rule_inst = rule_matcher(self.ap)
await rule_inst.initialize()
self.rule_matchers.append(rule_inst)

async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:

Expand Down
12 changes: 12 additions & 0 deletions pkg/pipeline/resprule/rule.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
from __future__ import annotations
import abc
import typing

import mirai

from ...core import app, entities as core_entities
from . import entities


preregisetered_rules: list[typing.Type[GroupRespondRule]] = []

def rule_class(name: str):
def decorator(cls: typing.Type[GroupRespondRule]) -> typing.Type[GroupRespondRule]:
cls.name = name
preregisetered_rules.append(cls)
return cls
return decorator


class GroupRespondRule(metaclass=abc.ABCMeta):
"""群组响应规则的抽象类
"""
name: str

ap: app.Application

Expand Down
1 change: 1 addition & 0 deletions pkg/pipeline/resprule/rules/atbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ....core import entities as core_entities


@rule_model.rule_class("at-bot")
class AtBotRule(rule_model.GroupRespondRule):

async def match(
Expand Down
Loading

0 comments on commit 63303bb

Please sign in to comment.