From fa823de6b000ae8fdf79e74b0877343a8d21426f Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 20 Mar 2024 14:20:56 +0800 Subject: [PATCH 1/8] =?UTF-8?q?perf:=20=E5=88=9D=E5=A7=8B=E5=8C=96config?= =?UTF-8?q?=E5=AF=B9=E8=B1=A1=E6=97=B6=E6=94=AF=E6=8C=81=E4=BC=A0=E9=80=92?= =?UTF-8?q?dict=E4=BD=9C=E4=B8=BA=E6=A8=A1=E6=9D=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/config/impls/json.py | 32 +++++++++++++++----------------- pkg/config/manager.py | 5 +++-- pkg/config/model.py | 3 +++ 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/pkg/config/impls/json.py b/pkg/config/impls/json.py index aecf686f..a9aab6a9 100644 --- a/pkg/config/impls/json.py +++ b/pkg/config/impls/json.py @@ -8,15 +8,12 @@ class JSONConfigFile(file_model.ConfigFile): """JSON配置文件""" - config_file_name: str = None - """配置文件名""" - - template_file_name: str = None - """模板文件名""" - - def __init__(self, config_file_name: str, template_file_name: str) -> None: + def __init__( + self, config_file_name: str, template_file_name: str = None, template_data: dict = None + ) -> None: self.config_file_name = config_file_name self.template_file_name = template_file_name + self.template_data = template_data def exists(self) -> bool: return os.path.exists(self.config_file_name) @@ -29,23 +26,24 @@ async def load(self) -> dict: if not self.exists(): await self.create() - with open(self.config_file_name, 'r', encoding='utf-8') as f: - cfg = json.load(f) + if self.template_file_name is not None: + with open(self.config_file_name, "r", encoding="utf-8") as f: + cfg = json.load(f) # 从模板文件中进行补全 - with open(self.template_file_name, 'r', encoding='utf-8') as f: - template_cfg = json.load(f) + with open(self.template_file_name, "r", encoding="utf-8") as f: + self.template_data = json.load(f) - for key in template_cfg: + for key in self.template_data: if key not in cfg: - cfg[key] = template_cfg[key] + cfg[key] = self.template_data[key] return cfg - + async def save(self, cfg: dict): - with open(self.config_file_name, 'w', encoding='utf-8') as f: + with open(self.config_file_name, "w", encoding="utf-8") as f: json.dump(cfg, f, indent=4, ensure_ascii=False) def save_sync(self, cfg: dict): - with open(self.config_file_name, 'w', encoding='utf-8') as f: - json.dump(cfg, f, indent=4, ensure_ascii=False) \ No newline at end of file + with open(self.config_file_name, "w", encoding="utf-8") as f: + json.dump(cfg, f, indent=4, ensure_ascii=False) diff --git a/pkg/config/manager.py b/pkg/config/manager.py index aae827e1..f9e93c81 100644 --- a/pkg/config/manager.py +++ b/pkg/config/manager.py @@ -43,11 +43,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con return cfg_mgr -async def load_json_config(config_name: str, template_name: str) -> ConfigManager: +async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager: """加载JSON配置文件""" cfg_inst = json_file.JSONConfigFile( config_name, - template_name + template_name, + template_data ) cfg_mgr = ConfigManager(cfg_inst) diff --git a/pkg/config/model.py b/pkg/config/model.py index 9be6f0f6..d209093c 100644 --- a/pkg/config/model.py +++ b/pkg/config/model.py @@ -10,6 +10,9 @@ class ConfigFile(metaclass=abc.ABCMeta): template_file_name: str = None """模板文件名""" + template_data: dict = None + """模板数据""" + @abc.abstractmethod def exists(self) -> bool: pass From 52a7c25540dd78be87ac63ea7053c92ad270f6e4 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 20 Mar 2024 15:09:47 +0800 Subject: [PATCH 2/8] =?UTF-8?q?feat:=20=E5=BC=82=E6=AD=A5=E9=A3=8E?= =?UTF-8?q?=E6=A0=BC=E6=8F=92=E4=BB=B6=E6=96=B9=E6=B3=95=E6=B3=A8=E5=86=8C?= =?UTF-8?q?=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/bootutils/deps.py | 1 - pkg/plugin/context.py | 13 +++ pkg/plugin/events.py | 2 + pkg/plugin/loaders/{legacy.py => classic.py} | 53 ++++++++- pkg/plugin/manager.py | 7 +- pkg/plugin/models.py | 12 ++ pkg/provider/tools/toolmgr.py | 14 ++- pkg/utils/funcschema.py | 116 +++++++++++++++++++ requirements.txt | 1 - 9 files changed, 210 insertions(+), 9 deletions(-) rename pkg/plugin/loaders/{legacy.py => classic.py} (74%) create mode 100644 pkg/utils/funcschema.py diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py index 41097f27..3b44e9cc 100644 --- a/pkg/core/bootutils/deps.py +++ b/pkg/core/bootutils/deps.py @@ -10,7 +10,6 @@ "botpy": "qq-botpy", "PIL": "pillow", "nakuru": "nakuru-project-idk", - "CallingGPT": "CallingGPT", "tiktoken": "tiktoken", "yaml": "pyyaml", "aiohttp": "aiohttp", diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py index a982232f..329914f0 100644 --- a/pkg/plugin/context.py +++ b/pkg/plugin/context.py @@ -13,6 +13,17 @@ class BasePlugin(metaclass=abc.ABCMeta): """插件基类""" host: APIHost + """API宿主""" + + ap: app.Application + """应用程序对象""" + + def __init__(self, host: APIHost): + self.host = host + + async def initialize(self): + """初始化插件""" + pass class APIHost: @@ -61,8 +72,10 @@ class EventContext: """事件编号""" host: APIHost = None + """API宿主""" event: events.BaseEventModel = None + """此次事件的对象,具体类型为handler注册时指定监听的类型,可查看events.py中的定义""" __prevent_default__ = False """是否阻止默认行为""" diff --git a/pkg/plugin/events.py b/pkg/plugin/events.py index fe67a82d..bcb8e5c8 100644 --- a/pkg/plugin/events.py +++ b/pkg/plugin/events.py @@ -10,8 +10,10 @@ class BaseEventModel(pydantic.BaseModel): + """事件模型基类""" query: typing.Union[core_entities.Query, None] + """此次请求的query对象,可能为None""" class Config: arbitrary_types_allowed = True diff --git a/pkg/plugin/loaders/legacy.py b/pkg/plugin/loaders/classic.py similarity index 74% rename from pkg/plugin/loaders/legacy.py rename to pkg/plugin/loaders/classic.py index 9bbee7c0..d5be6ace 100644 --- a/pkg/plugin/loaders/legacy.py +++ b/pkg/plugin/loaders/classic.py @@ -5,11 +5,10 @@ import importlib import traceback -from CallingGPT.entities.namespace import get_func_schema - from .. import loader, events, context, models, host from ...core import entities as core_entities from ...provider.tools import entities as tools_entities +from ...utils import funcschema class PluginLoader(loader.PluginLoader): @@ -29,6 +28,9 @@ async def initialize(self): setattr(models, 'on', self.on) setattr(models, 'func', self.func) + setattr(models, 'handler', self.handler) + setattr(models, 'llm_func', self.llm_func) + def register( self, name: str, @@ -57,6 +59,8 @@ def wrapper(cls: context.BasePlugin) -> typing.Type[context.BasePlugin]: return wrapper + # 过时 + # 最早将于 v3.4 版本移除 def on( self, event: typing.Type[events.BaseEventModel] @@ -83,6 +87,8 @@ async def handler(plugin: context.BasePlugin, ctx: context.EventContext) -> None return wrapper + # 过时 + # 最早将于 v3.4 版本移除 def func( self, name: str=None, @@ -91,10 +97,11 @@ def func( self.ap.logger.debug(f'注册内容函数 {name}') def wrapper(func: typing.Callable) -> typing.Callable: - function_schema = get_func_schema(func) + function_schema = funcschema.get_func_schema(func) function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) async def handler( + plugin: context.BasePlugin, query: core_entities.Query, *args, **kwargs @@ -116,6 +123,46 @@ async def handler( return wrapper + def handler( + self, + event: typing.Type[events.BaseEventModel] + ) -> typing.Callable[[typing.Callable], typing.Callable]: + """注册事件处理器""" + self.ap.logger.debug(f'注册事件处理器 {event.__name__}') + def wrapper(func: typing.Callable) -> typing.Callable: + + self._current_container.event_handlers[event] = func + + return func + + return wrapper + + def llm_func( + self, + name: str=None, + ) -> typing.Callable: + """注册内容函数""" + self.ap.logger.debug(f'注册内容函数 {name}') + def wrapper(func: typing.Callable) -> typing.Callable: + + function_schema = funcschema.get_func_schema(func) + function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) + + llm_function = tools_entities.LLMFunction( + name=function_name, + human_desc='', + description=function_schema['description'], + enable=True, + parameters=function_schema['parameters'], + func=func, + ) + + self._current_container.content_functions.append(llm_function) + + return func + + return wrapper + async def _walk_plugin_path( self, module, diff --git a/pkg/plugin/manager.py b/pkg/plugin/manager.py index 06e94f98..13a114a5 100644 --- a/pkg/plugin/manager.py +++ b/pkg/plugin/manager.py @@ -5,7 +5,7 @@ from ..core import app from . import context, loader, events, installer, setting, models -from .loaders import legacy +from .loaders import classic from .installers import github @@ -26,7 +26,7 @@ class PluginManager: def __init__(self, ap: app.Application): self.ap = ap - self.loader = legacy.PluginLoader(ap) + self.loader = classic.PluginLoader(ap) self.installer = github.GitHubRepoInstaller(ap) self.setting = setting.SettingManager(ap) self.api_host = context.APIHost(ap) @@ -52,6 +52,9 @@ async def initialize_plugins(self): for plugin in self.plugins: try: plugin.plugin_inst = plugin.plugin_class(self.api_host) + plugin.plugin_inst.ap = self.ap + plugin.plugin_inst.host = self.api_host + await plugin.plugin_inst.initialize() except Exception as e: self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}') self.ap.logger.exception(e) diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py index 972eed11..642305e4 100644 --- a/pkg/plugin/models.py +++ b/pkg/plugin/models.py @@ -24,3 +24,15 @@ def func( name: str=None, ) -> typing.Callable: pass + + +def handler( + event: typing.Type[BaseEventModel] +) -> typing.Callable[[typing.Callable], typing.Callable]: + pass + + +def llm_func( + name: str=None, +) -> typing.Callable: + pass \ No newline at end of file diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index 72c892bb..616de713 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -5,6 +5,7 @@ from ...core import app, entities as core_entities from . import entities +from ...plugin import context as plugin_context class ToolManager: @@ -28,6 +29,15 @@ async def get_function(self, name: str) -> entities.LLMFunction: return function return None + async def get_function_and_plugin(self, name: str) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]: + """获取函数和插件 + """ + for plugin in self.ap.plugin_mgr.plugins: + for function in plugin.content_functions: + if function.name == name: + return function, plugin + return None, None + async def get_all_functions(self) -> list[entities.LLMFunction]: """获取所有函数 """ @@ -68,7 +78,7 @@ async def execute_func_call( try: - function = await self.get_function(name) + function, plugin = await self.get_function_and_plugin(name) if function is None: return None @@ -79,7 +89,7 @@ async def execute_func_call( **parameters } - return await function.func(**parameters) + return await function.func(plugin, **parameters) except Exception as e: self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}') traceback.print_exc() diff --git a/pkg/utils/funcschema.py b/pkg/utils/funcschema.py new file mode 100644 index 00000000..c39b4886 --- /dev/null +++ b/pkg/utils/funcschema.py @@ -0,0 +1,116 @@ +import sys +import re +import inspect + + +def get_func_schema(function: callable) -> dict: + """ + Return the data schema of a function. + { + "function": function, + "description": "function description", + "parameters": { + "type": "object", + "properties": { + "parameter_a": { + "type": "str", + "description": "parameter_a description" + }, + "parameter_b": { + "type": "int", + "description": "parameter_b description" + }, + "parameter_c": { + "type": "str", + "description": "parameter_c description", + "enum": ["a", "b", "c"] + }, + }, + "required": ["parameter_a", "parameter_b"] + } + } + """ + func_doc = function.__doc__ + # Google Style Docstring + if func_doc is None: + raise Exception("Function {} has no docstring.".format(function.__name__)) + func_doc = func_doc.strip().replace(' ','').replace('\t', '') + # extract doc of args from docstring + doc_spt = func_doc.split('\n\n') + desc = doc_spt[0] + args = doc_spt[1] if len(doc_spt) > 1 else "" + returns = doc_spt[2] if len(doc_spt) > 2 else "" + + # extract args + # delete the first line of args + arg_lines = args.split('\n')[1:] + arg_doc_list = re.findall(r'(\w+)(\((\w+)\))?:\s*(.*)', args) + args_doc = {} + for arg_line in arg_lines: + doc_tuple = re.findall(r'(\w+)(\(([\w\[\]]+)\))?:\s*(.*)', arg_line) + if len(doc_tuple) == 0: + continue + args_doc[doc_tuple[0][0]] = doc_tuple[0][3] + + # extract returns + return_doc_list = re.findall(r'(\w+):\s*(.*)', returns) + + params = enumerate(inspect.signature(function).parameters.values()) + parameters = { + "type": "object", + "required": [], + "properties": {}, + } + + + for i, param in params: + + # 排除 self, query + if param.name in ['self', 'query']: + continue + + param_type = param.annotation.__name__ + + type_name_mapping = { + "str": "string", + "int": "integer", + "float": "number", + "bool": "boolean", + "list": "array", + "dict": "object", + } + + if param_type in type_name_mapping: + param_type = type_name_mapping[param_type] + + parameters['properties'][param.name] = { + "type": param_type, + "description": args_doc[param.name], + } + + # add schema for array + if param_type == "array": + # extract type of array, the int of list[int] + # use re + array_type_tuple = re.findall(r'list\[(\w+)\]', str(param.annotation)) + + array_type = 'string' + + if len(array_type_tuple) > 0: + array_type = array_type_tuple[0] + + if array_type in type_name_mapping: + array_type = type_name_mapping[array_type] + + parameters['properties'][param.name]["items"] = { + "type": array_type, + } + + if param.default is inspect.Parameter.empty: + parameters["required"].append(param.name) + + return { + "function": function, + "description": desc, + "parameters": parameters, + } \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 6a3e718c..28c0ecb6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,6 @@ aiocqhttp qq-botpy nakuru-project-idk Pillow -CallingGPT tiktoken PyYaml aiohttp From 9855c6b8f5c99499823fa85e372ac1487a00d0c5 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 20 Mar 2024 15:48:11 +0800 Subject: [PATCH 3/8] =?UTF-8?q?feat:=20=E6=96=B0=E7=9A=84=E5=BC=95?= =?UTF-8?q?=E5=85=A5=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/app.py | 3 +-- pkg/plugin/context.py | 20 ++++++++++++++++++++ pkg/plugin/host.py | 4 ++++ pkg/plugin/loaders/classic.py | 5 +++-- pkg/plugin/models.py | 16 ++++------------ 5 files changed, 32 insertions(+), 16 deletions(-) diff --git a/pkg/core/app.py b/pkg/core/app.py index f21ed5b0..7c9868bf 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -85,8 +85,7 @@ async def run(self): tasks = [] try: - - + tasks = [ asyncio.create_task(self.im_mgr.run()), asyncio.create_task(self.ctrl.run()) diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py index 329914f0..3907489d 100644 --- a/pkg/plugin/context.py +++ b/pkg/plugin/context.py @@ -9,6 +9,26 @@ from ..core import app +def register( + name: str, + description: str, + version: str, + author +) -> typing.Callable[[typing.Type[BasePlugin]], typing.Type[BasePlugin]]: + pass + +def handler( + event: typing.Type[events.BaseEventModel] +) -> typing.Callable[[typing.Callable], typing.Callable]: + pass + + +def llm_func( + name: str=None, +) -> typing.Callable: + pass + + class BasePlugin(metaclass=abc.ABCMeta): """插件基类""" diff --git a/pkg/plugin/host.py b/pkg/plugin/host.py index 6149da62..2868875d 100644 --- a/pkg/plugin/host.py +++ b/pkg/plugin/host.py @@ -1,3 +1,7 @@ +# 此模块已过时 +# 请从 pkg.plugin.context 引入 BasePlugin, EventContext 和 APIHost +# 最早将于 v3.4 移除此模块 + from . events import * from . context import EventContext, APIHost as PluginHost diff --git a/pkg/plugin/loaders/classic.py b/pkg/plugin/loaders/classic.py index d5be6ace..b5a733f8 100644 --- a/pkg/plugin/loaders/classic.py +++ b/pkg/plugin/loaders/classic.py @@ -28,8 +28,9 @@ async def initialize(self): setattr(models, 'on', self.on) setattr(models, 'func', self.func) - setattr(models, 'handler', self.handler) - setattr(models, 'llm_func', self.llm_func) + setattr(context, 'register', self.register) + setattr(context, 'handler', self.handler) + setattr(context, 'llm_func', self.llm_func) def register( self, diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py index 642305e4..b8b499f5 100644 --- a/pkg/plugin/models.py +++ b/pkg/plugin/models.py @@ -1,3 +1,7 @@ +# 此模块已过时,请引入 pkg.plugin.context 中的 register, handler 和 llm_func 来注册插件、事件处理函数和内容函数 +# 各个事件模型请从 pkg.plugin.events 引入 +# 最早将于 v3.4 移除此模块 + from __future__ import annotations import typing @@ -24,15 +28,3 @@ def func( name: str=None, ) -> typing.Callable: pass - - -def handler( - event: typing.Type[BaseEventModel] -) -> typing.Callable[[typing.Callable], typing.Callable]: - pass - - -def llm_func( - name: str=None, -) -> typing.Callable: - pass \ No newline at end of file From 0752698c1d654afd6e0c047b96659213985ec9b3 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 20 Mar 2024 18:43:52 +0800 Subject: [PATCH 4/8] =?UTF-8?q?chore:=20=E5=AE=8C=E5=96=84plugin=E5=AF=B9?= =?UTF-8?q?=E5=A4=96=E5=AF=B9=E8=B1=A1=E7=9A=84=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/plugin/context.py | 47 ++++++++++++++++++++++++++++++++++++++++++- pkg/plugin/events.py | 2 +- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py index 3907489d..b0e2ef96 100644 --- a/pkg/plugin/context.py +++ b/pkg/plugin/context.py @@ -13,19 +13,64 @@ def register( name: str, description: str, version: str, - author + author: str ) -> typing.Callable[[typing.Type[BasePlugin]], typing.Type[BasePlugin]]: + """注册插件类 + + 使用示例: + + @register( + name="插件名称", + description="插件描述", + version="插件版本", + author="插件作者" + ) + class MyPlugin(BasePlugin): + pass + """ pass def handler( event: typing.Type[events.BaseEventModel] ) -> typing.Callable[[typing.Callable], typing.Callable]: + """注册事件监听器 + + 使用示例: + + class MyPlugin(BasePlugin): + + @handler(NormalMessageResponded) + async def on_normal_message_responded(self, ctx: EventContext): + pass + """ pass def llm_func( name: str=None, ) -> typing.Callable: + """注册内容函数 + + 使用示例: + + class MyPlugin(BasePlugin): + + @llm_func("access_the_web_page") + async def _(self, query, url: str, brief_len: int): + \"""Call this function to search about the question before you answer any questions. + - Do not search through google.com at any time. + - If you need to search somthing, visit https://www.sogou.com/web?query=. + - If user ask you to open a url (start with http:// or https://), visit it directly. + - Summary the plain content result by yourself, DO NOT directly output anything in the result you got. + + Args: + url(str): url to visit + brief_len(int): max length of the plain text content, recommend 1024-4096, prefer 4096 + + Returns: + str: plain text content of the web page or error message(starts with 'error:') + \""" + """ pass diff --git a/pkg/plugin/events.py b/pkg/plugin/events.py index bcb8e5c8..b5919762 100644 --- a/pkg/plugin/events.py +++ b/pkg/plugin/events.py @@ -13,7 +13,7 @@ class BaseEventModel(pydantic.BaseModel): """事件模型基类""" query: typing.Union[core_entities.Query, None] - """此次请求的query对象,可能为None""" + """此次请求的query对象,非请求过程的事件时为None""" class Config: arbitrary_types_allowed = True From d0b0f2209aa5ef1ce8be83e2834eb98e6447c593 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 20 Mar 2024 23:32:28 +0800 Subject: [PATCH 5/8] =?UTF-8?q?fix:=20chat=E5=A4=84=E7=90=86=E8=BF=87?= =?UTF-8?q?=E7=A8=8B=E7=9A=84=E6=8F=92=E4=BB=B6=E8=BF=94=E5=9B=9E=E5=80=BC?= =?UTF-8?q?=E7=9B=AE=E6=A0=87=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/pipeline/process/handlers/chat.py | 9 ++++++++- pkg/pipeline/wrapper/wrapper.py | 7 +++++++ pkg/provider/entities.py | 2 +- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 33dedb04..58b7a830 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -39,7 +39,14 @@ async def handle( if event_ctx.is_prevented_default(): if event_ctx.event.reply is not None: - query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) + mc = mirai.MessageChain(event_ctx.event.reply) + + query.resp_messages.append( + llm_entities.Message( + role='plugin', + content=str(mc), + ) + ) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index 0278b603..80277a0f 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -29,6 +29,13 @@ async def process( if query.resp_messages[-1].role == 'command': query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content) + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + elif query.resp_messages[-1].role == 'plugin': + query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content) + yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 2a555311..0dffc636 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -22,7 +22,7 @@ class ToolCall(pydantic.BaseModel): class Message(pydantic.BaseModel): """消息""" - role: str # user, system, assistant, tool, command + role: str # user, system, assistant, tool, command, plugin name: typing.Optional[str] = None From 5f138de75bc88da8235234e034def68d29b27d14 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Fri, 22 Mar 2024 11:05:58 +0800 Subject: [PATCH 6/8] =?UTF-8?q?doc:=20=E5=AE=8C=E5=96=84query=E5=AF=B9?= =?UTF-8?q?=E8=B1=A1=E7=9A=84=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/entities.py | 26 +++++++++++++------------- pkg/platform/manager.py | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 8bf1ff2e..f0f3f151 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -32,43 +32,43 @@ class Query(pydantic.BaseModel): """请求ID,添加进请求池时生成""" launcher_type: LauncherTypes - """会话类型,platform设置""" + """会话类型,platform处理阶段设置""" launcher_id: int - """会话ID,platform设置""" + """会话ID,platform处理阶段设置""" sender_id: int - """发送者ID,platform设置""" + """发送者ID,platform处理阶段设置""" message_event: mirai.MessageEvent - """事件,platform收到的事件""" + """事件,platform收到的原始事件""" message_chain: mirai.MessageChain - """消息链,platform收到的消息链""" + """消息链,platform收到的原始消息链""" adapter: msadapter.MessageSourceAdapter - """适配器对象""" + """消息平台适配器对象,单个app中可能启用了多个消息平台适配器,此对象表明发起此query的适配器""" session: typing.Optional[Session] = None - """会话对象,由前置处理器设置""" + """会话对象,由前置处理器阶段设置""" messages: typing.Optional[list[llm_entities.Message]] = [] - """历史消息列表,由前置处理器设置""" + """历史消息列表,由前置处理器阶段设置""" prompt: typing.Optional[sysprompt_entities.Prompt] = None - """情景预设内容,由前置处理器设置""" + """情景预设内容,由前置处理器阶段设置""" user_message: typing.Optional[llm_entities.Message] = None - """此次请求的用户消息对象,由前置处理器设置""" + """此次请求的用户消息对象,由前置处理器阶段设置""" use_model: typing.Optional[entities.LLMModelInfo] = None - """使用的模型,由前置处理器设置""" + """使用的模型,由前置处理器阶段设置""" use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None - """使用的函数,由前置处理器设置""" + """使用的函数,由前置处理器阶段设置""" resp_messages: typing.Optional[list[llm_entities.Message]] = [] - """由provider生成的回复消息对象列表""" + """由Process阶段生成的回复消息对象列表""" resp_message_chain: typing.Optional[mirai.MessageChain] = None """回复消息链,从resp_messages包装而得""" diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 7b40f2ab..fb23cb5a 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -146,7 +146,7 @@ async def on_group_message(event: GroupMessage, adapter: msadapter.MessageSource if len(self.adapters) == 0: self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。') - async def send(self, event, msg, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True): + async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True): if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage): From bd6a32e08e1bd6a6d0d2d566e8d247148843ca2a Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Fri, 22 Mar 2024 16:41:46 +0800 Subject: [PATCH 7/8] =?UTF-8?q?doc:=20=E4=B8=BA=E5=8F=AF=E6=89=A9=E5=B1=95?= =?UTF-8?q?=E7=BB=84=E4=BB=B6=E6=B7=BB=E5=8A=A0=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/entities.py | 25 +++++++++++++++++++++++++ pkg/command/operator.py | 20 +++++++++++++++++++- pkg/config/migration.py | 1 + pkg/core/stage.py | 5 ++++- pkg/pipeline/cntfilter/entities.py | 15 ++++++++++++--- pkg/pipeline/cntfilter/filter.py | 14 ++++++++++++++ pkg/pipeline/longtext/strategy.py | 20 ++++++++++++++++++++ pkg/pipeline/ratelimit/algo.py | 19 +++++++++++++++++++ pkg/platform/adapter.py | 15 +++++++++++++++ pkg/provider/entities.py | 5 +++++ pkg/provider/modelmgr/api.py | 11 ++++++++++- pkg/provider/sysprompt/entities.py | 2 ++ pkg/provider/sysprompt/loader.py | 2 +- 13 files changed, 147 insertions(+), 7 deletions(-) diff --git a/pkg/command/entities.py b/pkg/command/entities.py index ee698b24..27cb5962 100644 --- a/pkg/command/entities.py +++ b/pkg/command/entities.py @@ -20,6 +20,8 @@ class CommandReturn(pydantic.BaseModel): image: typing.Optional[mirai.Image] error: typing.Optional[errors.CommandError]= None + """错误 + """ class Config: arbitrary_types_allowed = True @@ -30,17 +32,40 @@ class ExecuteContext(pydantic.BaseModel): """ query: core_entities.Query + """本次消息的请求对象""" session: core_entities.Session + """本次消息所属的会话对象""" command_text: str + """命令完整文本""" command: str + """命令名称""" crt_command: str + """当前命令 + + 多级命令中crt_command为当前命令,command为根命令。 + 例如:!plugin on Webwlkr + 处理到plugin时,command为plugin,crt_command为plugin + 处理到on时,command为plugin,crt_command为on + """ params: list[str] + """命令参数 + + 整个命令以空格分割后的参数列表 + """ crt_params: list[str] + """当前命令参数 + + 多级命令中crt_params为当前命令参数,params为根命令参数。 + 例如:!plugin on Webwlkr + 处理到plugin时,params为['on', 'Webwlkr'],crt_params为['on', 'Webwlkr'] + 处理到on时,params为['on', 'Webwlkr'],crt_params为['Webwlkr'] + """ privilege: int + """发起人权限""" diff --git a/pkg/command/operator.py b/pkg/command/operator.py index 307e9fbe..5e3b1a8f 100644 --- a/pkg/command/operator.py +++ b/pkg/command/operator.py @@ -52,6 +52,11 @@ def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator] class CommandOperator(metaclass=abc.ABCMeta): """命令算子抽象类 + + 以下的参数均不需要在子类中设置,只需要在使用装饰器注册类时作为参数传递即可。 + 命令支持级联,即一个命令可以有多个子命令,子命令可以有子命令,以此类推。 + 处理命令时,若有子命令,会以当前参数列表的第一个参数去匹配子命令,若匹配成功,则转移到子命令中执行。 + 若没有匹配成功或没有子命令,则执行当前命令。 """ ap: app.Application @@ -60,7 +65,8 @@ class CommandOperator(metaclass=abc.ABCMeta): """名称,搜索到时若符合则使用""" path: str - """路径,所有父节点的name的连接,用于定义命令权限""" + """路径,所有父节点的name的连接,用于定义命令权限,由管理器在初始化时自动设置。 + """ alias: list[str] """同name""" @@ -69,6 +75,7 @@ class CommandOperator(metaclass=abc.ABCMeta): """此节点的帮助信息""" usage: str = None + """用法""" parent_class: typing.Union[typing.Type[CommandOperator], None] = None """父节点类。标记以供管理器在初始化时编织父子关系。""" @@ -92,4 +99,15 @@ async def execute( self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """实现此方法以执行命令 + + 支持多次yield以返回多个结果。 + 例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。 + + Args: + context (entities.ExecuteContext): 命令执行上下文 + + Yields: + entities.CommandReturn: 命令返回封装 + """ pass diff --git a/pkg/config/migration.py b/pkg/config/migration.py index 3a6650b2..e84a59cb 100644 --- a/pkg/config/migration.py +++ b/pkg/config/migration.py @@ -7,6 +7,7 @@ preregistered_migrations: list[typing.Type[Migration]] = [] +"""当前阶段暂不支持扩展""" def migration_class(name: str, number: int): """注册一个迁移 diff --git a/pkg/core/stage.py b/pkg/core/stage.py index a4ff7af1..f1c65295 100644 --- a/pkg/core/stage.py +++ b/pkg/core/stage.py @@ -7,7 +7,10 @@ preregistered_stages: dict[str, typing.Type[BootingStage]] = {} -"""预注册的请求处理阶段。在初始化时,所有请求处理阶段类会被注册到此字典中。""" +"""预注册的请求处理阶段。在初始化时,所有请求处理阶段类会被注册到此字典中。 + +当前阶段暂不支持扩展 +""" def stage_class( name: str diff --git a/pkg/pipeline/cntfilter/entities.py b/pkg/pipeline/cntfilter/entities.py index 7ab05675..8ff581fb 100644 --- a/pkg/pipeline/cntfilter/entities.py +++ b/pkg/pipeline/cntfilter/entities.py @@ -31,15 +31,24 @@ class EnableStage(enum.Enum): class FilterResult(pydantic.BaseModel): level: ResultLevel + """结果等级 + + 对于前置处理阶段,只要有任意一个返回 非PASS 的内容过滤器结果,就会中断处理。 + 对于后置处理阶段,当且内容过滤器返回 BLOCK 时,会中断处理。 + """ replacement: str - """替换后的消息""" + """替换后的消息 + + 内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。 + 若没有修改内容,也需要返回原消息。 + """ user_notice: str - """不通过时,用户提示消息""" + """不通过时,若此值不为空,将对用户提示消息""" console_notice: str - """不通过时,控制台提示消息""" + """不通过时,若此值不为空,将在控制台提示消息""" class ManagerResultLevel(enum.Enum): diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index 23471392..8b34e0c5 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -46,6 +46,11 @@ def __init__(self, ap: app.Application): @property def enable_stages(self): """启用的阶段 + + 默认为消息请求AI前后的两个阶段。 + + entity.EnableStage.PRE: 消息请求AI前,此时需要检查的内容是用户的输入消息。 + entity.EnableStage.POST: 消息请求AI后,此时需要检查的内容是AI的回复消息。 """ return [ entities.EnableStage.PRE, @@ -60,5 +65,14 @@ async def initialize(self): @abc.abstractmethod async def process(self, message: str) -> entities.FilterResult: """处理消息 + + 分为前后阶段,具体取决于 enable_stages 的值。 + 对于内容过滤器来说,不需要考虑消息所处的阶段,只需要检查消息内容即可。 + + Args: + message (str): 需要检查的内容 + + Returns: + entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档 """ raise NotImplementedError diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py index 296c5b4c..5d7e24fb 100644 --- a/pkg/pipeline/longtext/strategy.py +++ b/pkg/pipeline/longtext/strategy.py @@ -15,6 +15,15 @@ def strategy_class( name: str ) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: + """长文本处理策略类装饰器 + + Args: + name (str): 策略名称 + + Returns: + typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: 装饰器 + """ + def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]: assert issubclass(cls, LongTextStrategy) @@ -43,4 +52,15 @@ async def initialize(self): @abc.abstractmethod async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: + """处理长文本 + + 在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法 + + Args: + message (str): 消息 + query (core_entities.Query): 此次请求的上下文对象 + + Returns: + list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表 + """ return [] diff --git a/pkg/pipeline/ratelimit/algo.py b/pkg/pipeline/ratelimit/algo.py index 448ae384..af4def16 100644 --- a/pkg/pipeline/ratelimit/algo.py +++ b/pkg/pipeline/ratelimit/algo.py @@ -18,6 +18,7 @@ def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]: class ReteLimitAlgo(metaclass=abc.ABCMeta): + """限流算法抽象类""" name: str = None @@ -31,9 +32,27 @@ async def initialize(self): @abc.abstractmethod async def require_access(self, launcher_type: str, launcher_id: int) -> bool: + """进入处理流程 + + 这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。 + + Args: + launcher_type (str): 请求者类型 群聊为 group 私聊为 person + launcher_id (int): 请求者ID + + Returns: + bool: 是否允许进入处理流程,若返回false,则直接丢弃该请求 + """ raise NotImplementedError @abc.abstractmethod async def release_access(self, launcher_type: str, launcher_id: int): + """退出处理流程 + + Args: + launcher_type (str): 请求者类型 群聊为 group 私聊为 person + launcher_id (int): 请求者ID + """ + raise NotImplementedError \ No newline at end of file diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index 5ce1db18..4b159b79 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -14,6 +14,14 @@ def adapter_class( name: str ): + """消息平台适配器类装饰器 + + Args: + name (str): 适配器名称 + + Returns: + typing.Callable[[typing.Type[MessageSourceAdapter]], typing.Type[MessageSourceAdapter]]: 装饰器 + """ def decorator(cls: typing.Type[MessageSourceAdapter]) -> typing.Type[MessageSourceAdapter]: cls.name = name preregistered_adapters.append(cls) @@ -27,12 +35,19 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta): name: str bot_account_id: int + """机器人账号ID,需要在初始化时设置""" config: dict ap: app.Application def __init__(self, config: dict, ap: app.Application): + """初始化适配器 + + Args: + config (dict): 对应的配置 + ap (app.Application): 应用上下文 + """ self.config = config self.ap = ap diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 0dffc636..8c0c76bc 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -23,14 +23,19 @@ class Message(pydantic.BaseModel): """消息""" role: str # user, system, assistant, tool, command, plugin + """消息的角色""" name: typing.Optional[str] = None + """名称,仅函数调用返回时设置""" content: typing.Optional[str] = None + """内容""" function_call: typing.Optional[FunctionCall] = None + """函数调用,不再受支持,请使用tool_calls""" tool_calls: typing.Optional[list[ToolCall]] = None + """工具调用""" tool_call_id: typing.Optional[str] = None diff --git a/pkg/provider/modelmgr/api.py b/pkg/provider/modelmgr/api.py index da362468..63021bed 100644 --- a/pkg/provider/modelmgr/api.py +++ b/pkg/provider/modelmgr/api.py @@ -38,6 +38,15 @@ async def request( self, query: core_entities.Query, ) -> typing.AsyncGenerator[llm_entities.Message, None]: - """请求 + """请求API + + 对话前文可以从 query 对象中获取。 + 可以多次yield消息对象。 + + Args: + query (core_entities.Query): 本次请求的上下文对象 + + Yields: + pkg.provider.entities.Message: 返回消息对象 """ raise NotImplementedError diff --git a/pkg/provider/sysprompt/entities.py b/pkg/provider/sysprompt/entities.py index 31ca199a..326ea787 100644 --- a/pkg/provider/sysprompt/entities.py +++ b/pkg/provider/sysprompt/entities.py @@ -10,5 +10,7 @@ class Prompt(pydantic.BaseModel): """供AI使用的Prompt""" name: str + """名称""" messages: list[entities.Message] + """消息列表""" diff --git a/pkg/provider/sysprompt/loader.py b/pkg/provider/sysprompt/loader.py index 9e0a6144..855728e2 100644 --- a/pkg/provider/sysprompt/loader.py +++ b/pkg/provider/sysprompt/loader.py @@ -36,7 +36,7 @@ async def initialize(self): @abc.abstractmethod async def load(self): - """加载Prompt + """加载Prompt,存放到prompts列表中 """ raise NotImplementedError From 80258e9182fccbf25dd1e180177a9c45f9006f96 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Fri, 22 Mar 2024 17:09:43 +0800 Subject: [PATCH 8/8] =?UTF-8?q?perf:=20=E4=BF=AE=E6=94=B9platform=5Fmgr?= =?UTF-8?q?=E5=90=8D=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/app.py | 4 ++-- pkg/core/stages/build_app.py | 2 +- pkg/pipeline/controller.py | 2 +- pkg/pipeline/respback/respback.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/core/app.py b/pkg/core/app.py index 7c9868bf..1ed53042 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -21,7 +21,7 @@ class Application: """运行时应用对象和上下文""" - im_mgr: im_mgr.PlatformManager = None + platform_mgr: im_mgr.PlatformManager = None cmd_mgr: cmdmgr.CommandManager = None @@ -87,7 +87,7 @@ async def run(self): try: tasks = [ - asyncio.create_task(self.im_mgr.run()), + asyncio.create_task(self.platform_mgr.run()), asyncio.create_task(self.ctrl.run()) ] diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 22c5c186..d5365826 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -86,7 +86,7 @@ async def run(self, ap: app.Application): im_mgr_inst = im_mgr.PlatformManager(ap=ap) await im_mgr_inst.initialize() - ap.im_mgr = im_mgr_inst + ap.platform_mgr = im_mgr_inst stage_mgr = stagemgr.StageManager(ap) await stage_mgr.initialize() diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index 645014dc..5fe4167b 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -68,7 +68,7 @@ async def _check_output(self, query: entities.Query, result: pipeline_entities.S """检查输出 """ if result.user_notice: - await self.ap.im_mgr.send( + await self.ap.platform_mgr.send( query.message_event, result.user_notice, query.adapter diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index 10b2cbac..36a73291 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -29,7 +29,7 @@ async def process(self, query: core_entities.Query, stage_inst_name: str) -> ent await asyncio.sleep(random_delay) - await self.ap.im_mgr.send( + await self.ap.platform_mgr.send( query.message_event, query.resp_message_chain, adapter=query.adapter