diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 55e4efcb..d1973eb3 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -11,5 +11,4 @@ updates: interval: "weekly" allow: - dependency-name: "yiri-mirai-rc" - - dependency-name: "dulwich" - dependency-name: "openai" diff --git a/.github/workflows/build_docker_image.yml b/.github/workflows/build-docker-image.yml similarity index 100% rename from .github/workflows/build_docker_image.yml rename to .github/workflows/build-docker-image.yml diff --git a/.github/workflows/update-cmdpriv-template.yml b/.github/workflows/update-cmdpriv-template.yml deleted file mode 100644 index 7493f332..00000000 --- a/.github/workflows/update-cmdpriv-template.yml +++ /dev/null @@ -1,58 +0,0 @@ -name: Update cmdpriv-template - -on: - push: - paths: - - 'pkg/qqbot/cmds/**' - pull_request: - types: [closed] - paths: - - 'pkg/qqbot/cmds/**' - -jobs: - update-cmdpriv-template: - if: github.event.pull_request.merged == true || github.event_name == 'push' - runs-on: ubuntu-latest - - steps: - - name: Checkout repository - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.10.13 - - - name: Install dependencies - run: | - python -m pip install --upgrade yiri-mirai-rc openai>=1.0.0 colorlog func_timeout dulwich Pillow CallingGPT tiktoken - python -m pip install -U openai>=1.0.0 - - - name: Copy Scripts - run: | - cp res/scripts/generate_cmdpriv_template.py . 
- - - name: Generate Files - run: | - python main.py - - - name: Run generate_cmdpriv_template.py - run: python3 generate_cmdpriv_template.py - - - name: Check for changes in cmdpriv-template.json - id: check_changes - run: | - if git diff --name-only | grep -q "res/templates/cmdpriv-template.json"; then - echo "::set-output name=changes_detected::true" - else - echo "::set-output name=changes_detected::false" - fi - - - name: Commit changes to cmdpriv-template.json - if: steps.check_changes.outputs.changes_detected == 'true' - run: | - git config --global user.name "GitHub Actions Bot" - git config --global user.email "" - git add res/templates/cmdpriv-template.json - git commit -m "Update cmdpriv-template.json" - git push diff --git a/.github/workflows/update-override-all.yml b/.github/workflows/update-override-all.yml deleted file mode 100644 index 83ef6a6b..00000000 --- a/.github/workflows/update-override-all.yml +++ /dev/null @@ -1,52 +0,0 @@ -name: Check and Update override_all - -on: - push: - paths: - - 'config-template.py' - pull_request: - types: - - closed - branches: - - master - paths: - - 'config-template.py' - -jobs: - update-override-all: - name: check and update - if: github.event.pull_request.merged == true || github.event_name == 'push' - runs-on: ubuntu-latest - steps: - - name: Checkout repository - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.x - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - - - name: Copy Scripts - run: | - cp res/scripts/generate_override_all.py . 
- - - name: Run generate_override_all.py - run: python3 generate_override_all.py - - - name: Check for changes in override-all.json - id: check_changes - run: | - git diff --exit-code override-all.json || echo "::set-output name=changes_detected::true" - - - name: Commit and push changes - if: steps.check_changes.outputs.changes_detected == 'true' - run: | - git config --global user.email "github-actions[bot]@users.noreply.github.com" - git config --global user.name "GitHub Actions" - git add override-all.json - git commit -m "Update override-all.json" - git push diff --git a/.gitignore b/.gitignore index b9069a7f..2af326f5 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,5 @@ bard.json !/docker-compose.yaml res/instance_id.json .DS_Store +/data +botpy.log \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index d977b91a..1dfd0058 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,15 +1,8 @@ -FROM python:3.10.13-bullseye -WORKDIR /QChatGPT +FROM python:3.10.13-slim +WORKDIR /app -COPY . /QChatGPT/ +COPY . . -RUN ls - -RUN python -m pip install -r requirements.txt && \ - python -m pip install -U websockets==10.0 && \ - python -m pip install -U httpcore httpx openai - -# 生成配置文件 -RUN python main.py +RUN python -m pip install -r requirements.txt CMD [ "python", "main.py" ] \ No newline at end of file diff --git a/README_en.md b/README_en.md deleted file mode 100644 index dbd9e6a1..00000000 --- a/README_en.md +++ /dev/null @@ -1,215 +0,0 @@ -# QChatGPT🤖 - -

-QChatGPT -

- -English | [简体中文](README.md) - -[![GitHub release (latest by date)](https://img.shields.io/github/v/release/RockChinQ/QChatGPT?style=flat-square)](https://github.com/RockChinQ/QChatGPT/releases/latest) -![Wakapi Count](https://wakapi.dev/api/badge/RockChinQ/interval:any/project:QChatGPT) - -- Refer to [Wiki](https://github.com/RockChinQ/QChatGPT/wiki) to get further information. -- Official QQ group: 656285629 -- Community QQ group: 362515018 -- QQ channel robot: [QQChannelChatGPT](https://github.com/Soulter/QQChannelChatGPT) -- Any contribution is welcome, please refer to [CONTRIBUTING.md](CONTRIBUTING.md) - -## 🍺List of supported models - -
-Details - -### Chat - -- OpenAI GPT-3.5 (ChatGPT API), default model -- OpenAI GPT-3, supported natively, switch to it in `config.py` -- OpenAI GPT-4, supported natively, qualification for internal testing required, switch to it in `config.py` -- ChatGPT website edition (GPT-3.5), see [revLibs plugin](https://github.com/RockChinQ/revLibs) -- ChatGPT website edition (GPT-4), ChatGPT plus subscription required, see [revLibs plugin](https://github.com/RockChinQ/revLibs) -- New Bing, see [revLibs plugin](https://github.com/RockChinQ/revLibs) -- HuggingChat, see [revLibs plugin](https://github.com/RockChinQ/revLibs), English only - -### Story - -- NovelAI API, see [QCPNovelAi plugin](https://github.com/dominoar/QCPNovelAi) - -### Image - -- OpenAI DALL·E, supported natively, see [Wiki(cn)](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E5%8A%9F%E8%83%BD%E7%82%B9%E5%88%97%E4%B8%BE) -- NovelAI API, see [QCPNovelAi plugin](https://github.com/dominoar/QCPNovelAi) - -### Voice - -- TTS+VITS, see [QChatPlugins](https://github.com/dominoar/QChatPlugins) -- Plachta/VITS-Umamusume-voice-synthesizer, see [chat_voice plugin](https://github.com/oliverkirk-sudo/chat_voice) - - -
- -Install this [plugin](https://github.com/RockChinQ/Switcher) to switch between different models. - -## ✅Features - -
-Details - - - ✅Sensitive word filtering, avoid being banned - - ✅Multiple responding rules, including regular expression matching - - ✅Multiple api-key management, automatic switching when exceeding - - ✅Support for customizing the preset prompt text - - ✅Chat, story, image, voice, etc. models are supported - - ✅Support for hot reloading and hot updating - - ✅Support for plugin loading - - ✅Blacklist mechanism for private chat and group chat - - ✅Excellent long message processing strategy - - ✅Reply rate limitation - - ✅Support for network proxy - - ✅Support for customizing the output format -
- -More details, see [Wiki(cn)](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E5%8A%9F%E8%83%BD%E7%82%B9%E5%88%97%E4%B8%BE) - -## 🔩Deployment - -**If you encounter any problems during deployment, please search in the issue of [QChatGPT](https://github.com/RockChinQ/QChatGPT/issues) or [qcg-installer](https://github.com/RockChinQ/qcg-installer/issues) first.** - -### - Register OpenAI account - -> If you want to use a model other than OpenAI (such as New Bing), you can skip this step and directly refer to following steps, and then configure it according to the relevant plugin documentation. - -To register OpenAI account, please refer to the following articles(in Chinese): - -> [国内注册ChatGPT的方法(100%可用)](https://www.pythonthree.com/register-openai-chatgpt/) -> [手把手教你如何注册ChatGPT,超级详细](https://guxiaobei.com/51461) - -Check your api-key in [personal center](https://beta.openai.com/account/api-keys) after registration, and then follow the following steps to deploy. - -### - Deploy Automatically - -
-Details - -#### Docker - -See [this document(cn)](res/docs/docker_deploy.md) -Contributed by [@mikumifa](https://github.com/mikumifa) - -#### Installer - -Use [this installer](https://github.com/RockChinQ/qcg-installer) to deploy. - -- The installer currently only supports some platforms, please refer to the repository document for details, and manually deploy for other platforms - -
- -### - Deploy Manually -
-Manually deployment supports any platforms - -- Python 3.9.x or higher - -#### 配置QQ登录框架 - -Currently supports mirai and go-cqhttp, configure either one - -
-mirai - -Follow [this tutorial(cn)](https://yiri-mirai.wybxc.cc/tutorials/01/configuration) to configure Mirai and YiriMirai. -After starting mirai-console, use the `login` command to log in to the QQ account, and keep the mirai-console running. - -
- -
-go-cqhttp - -1. Follow [this tutorial(cn)](https://github.com/RockChinQ/QChatGPT/wiki/go-cqhttp%E9%85%8D%E7%BD%AE) to configure go-cqhttp. -2. Start go-cqhttp, make sure it is logged in and running. - -
- -#### Configure QChatGPT - -1. Clone the repository - -```bash -git clone https://github.com/RockChinQ/QChatGPT -cd QChatGPT -``` - -2. Install dependencies - -```bash -pip3 install requests yiri-mirai-rc openai colorlog func_timeout dulwich Pillow nakuru-project-idk -``` - -3. Generate `config.py` - -```bash -python3 main.py -``` - -4. Edit `config.py` - -5. Run - -```bash -python3 main.py -``` - -Any problems, please refer to the issues page. - -
- -## 🚀Usage - -**After deployment, please read: [Commands(cn)](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E6%9C%BA%E5%99%A8%E4%BA%BA%E6%8C%87%E4%BB%A4)** - -**For more details, please refer to the [Wiki(cn)](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E4%BD%BF%E7%94%A8%E6%96%B9%E5%BC%8F)** - - -## 🧩Plugin Ecosystem - -Plugin [usage](https://github.com/RockChinQ/QChatGPT/wiki/%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8) and [development](https://github.com/RockChinQ/QChatGPT/wiki/%E6%8F%92%E4%BB%B6%E5%BC%80%E5%8F%91) are supported. - -
-List of plugins (cn) - -### Examples - -在`tests/plugin_examples`目录下,将其整个目录复制到`plugins`目录下即可使用 - -- `cmdcn` - 主程序命令中文形式 -- `hello_plugin` - 在收到消息`hello`时回复相应消息 -- `urlikethisijustsix` - 收到冒犯性消息时回复相应消息 - -### More Plugins - -欢迎提交新的插件 - -- [revLibs](https://github.com/RockChinQ/revLibs) - 将ChatGPT网页版接入此项目,关于[官方接口和网页版有什么区别](https://github.com/RockChinQ/QChatGPT/wiki/%E5%AE%98%E6%96%B9%E6%8E%A5%E5%8F%A3%E4%B8%8EChatGPT%E7%BD%91%E9%A1%B5%E7%89%88) -- [Switcher](https://github.com/RockChinQ/Switcher) - 支持通过命令切换使用的模型 -- [hello_plugin](https://github.com/RockChinQ/hello_plugin) - `hello_plugin` 的储存库形式,插件开发模板 -- [dominoar/QChatPlugins](https://github.com/dominoar/QchatPlugins) - dominoar编写的诸多新功能插件(语音输出、Ranimg、屏蔽词规则等) -- [dominoar/QCP-NovelAi](https://github.com/dominoar/QCP-NovelAi) - NovelAI 故事叙述与绘画 -- [oliverkirk-sudo/chat_voice](https://github.com/oliverkirk-sudo/chat_voice) - 文字转语音输出,使用HuggingFace上的[VITS-Umamusume-voice-synthesizer模型](https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer) -- [RockChinQ/WaitYiYan](https://github.com/RockChinQ/WaitYiYan) - 实时获取百度`文心一言`等待列表人数 -- [chordfish-k/QChartGPT_Emoticon_Plugin](https://github.com/chordfish-k/QChartGPT_Emoticon_Plugin) - 使机器人根据回复内容发送表情包 -- [oliverkirk-sudo/ChatPoeBot](https://github.com/oliverkirk-sudo/ChatPoeBot) - 接入[Poe](https://poe.com/)上的机器人 -- [lieyanqzu/WeatherPlugin](https://github.com/lieyanqzu/WeatherPlugin) - 天气查询插件 -
- -## 😘Thanks - -- [@the-lazy-me](https://github.com/the-lazy-me) video tutorial creator -- [@mikumifa](https://github.com/mikumifa) Docker deployment -- [@dominoar](https://github.com/dominoar) Plugin development -- [@万神的星空](https://github.com/qq255204159) Packages publisher -- [@ljcduo](https://github.com/ljcduo) GPT-4 API internal test account - -And all [contributors](https://github.com/RockChinQ/QChatGPT/graphs/contributors) and other friends who support this project. - - diff --git a/config-template.py b/config-template.py deleted file mode 100644 index fb60ff1e..00000000 --- a/config-template.py +++ /dev/null @@ -1,370 +0,0 @@ -# 配置文件: 注释里标[必需]的参数必须修改, 其他参数根据需要修改, 但请勿删除 -import logging - -# 消息处理协议适配器 -# 目前支持以下适配器: -# - "yirimirai": mirai的通信框架,YiriMirai框架适配器, 请同时填写下方mirai_http_api_config -# - "nakuru": go-cqhttp通信框架,请同时填写下方nakuru_config -msg_source_adapter = "yirimirai" - -# [必需(与nakuru二选一,取决于msg_source_adapter)] Mirai的配置 -# 请到配置mirai的步骤中的教程查看每个字段的信息 -# adapter: 选择适配器,目前支持HTTPAdapter和WebSocketAdapter -# host: 运行mirai的主机地址 -# port: 运行mirai的主机端口 -# verifyKey: mirai-api-http的verifyKey -# qq: 机器人的QQ号 -# -# 注意: QQ机器人配置不支持热重载及热更新 -mirai_http_api_config = { - "adapter": "WebSocketAdapter", - "host": "localhost", - "port": 8080, - "verifyKey": "yirimirai", - "qq": 1234567890 -} - -# [必需(与mirai二选一,取决于msg_source_adapter)] -# 使用nakuru-project框架连接go-cqhttp的配置 -nakuru_config = { - "host": "localhost", # go-cqhttp的地址 - "port": 6700, # go-cqhttp的正向websocket端口 - "http_port": 5700, # go-cqhttp的正向http端口 - "token": "" # 若在go-cqhttp的config.yml设置了access_token, 则填写此处 -} - -# [必需] OpenAI的配置 -# api_key: OpenAI的API Key -# http_proxy: 请求OpenAI时使用的代理,None为不使用,https和socks5暂不能使用 -# 若只有一个api-key,请直接修改以下内容中的"openai_api_key"为你的api-key -# -# 如准备了多个api-key,可以以字典的形式填写,程序会自动选择可用的api-key -# 例如 -# openai_config = { -# "api_key": { -# "default": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", -# "key1": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", -# "key2": 
"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", -# }, -# "http_proxy": "http://127.0.0.1:12345" -# } -# -# 现已支持反向代理,可以添加reverse_proxy字段以使用反向代理 -# 使用反向代理可以在国内使用OpenAI的API,反向代理的配置请参考 -# https://github.com/Ice-Hazymoon/openai-scf-proxy -# -# 反向代理填写示例: -# openai_config = { -# "api_key": { -# "default": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", -# "key1": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", -# "key2": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", -# }, -# "reverse_proxy": "http://example.com:12345/v1" -# } -# -# 作者开设公用反向代理地址: https://api.openai.rockchin.top/v1 -# 随时可能关闭,仅供测试使用,有条件建议使用正向代理或者自建反向代理 -openai_config = { - "api_key": { - "default": "openai_api_key" - }, - "http_proxy": None, - "reverse_proxy": None -} - -# api-key切换策略 -# active:每次请求时都会切换api-key -# passive:仅当api-key超额时才会切换api-key -switch_strategy = "active" - -# [必需] 管理员QQ号,用于接收报错等通知及执行管理员级别命令 -# 支持多个管理员,可以使用list形式设置,例如: -# admin_qq = [12345678, 87654321] -admin_qq = 0 - -# 情景预设(机器人人格) -# 每个会话的预设信息,影响所有会话,无视命令重置 -# 可以通过这个字段指定某些情况的回复,可直接用自然语言描述指令 -# 例如: -# default_prompt = "如果我之后想获取帮助,请你说“输入!help获取帮助”" -# 这样用户在不知所措的时候机器人就会提示其输入!help获取帮助 -# 可参考 https://github.com/PlexPt/awesome-chatgpt-prompts-zh -# -# 如果需要多个情景预设,并在运行期间方便切换,请使用字典的形式填写,例如 -# default_prompt = { -# "default": "如果我之后想获取帮助,请你说“输入!help获取帮助”", -# "linux-terminal": "我想让你充当 Linux 终端。我将输入命令,您将回复终端应显示的内容。", -# "en-dict": "我想让你充当英英词典,对于给出的英文单词,你要给出其中文意思以及英文解释,并且给出一个例句,此外不要有其他反馈。", -# } -# -# 在使用期间即可通过命令: -# !reset [名称] -# 来使用指定的情景预设重置会话 -# 例如: -# !reset linux-terminal -# 若不指定名称,则使用默认情景预设 -# -# 也可以使用命令: -# !default <名称> -# 将指定的情景预设设置为默认情景预设 -# 例如: -# !default linux-terminal -# 之后的会话重置时若不指定名称,则使用linux-terminal情景预设 -# -# 还可以加载文件中的预设文字,使用方法请查看:https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E9%A2%84%E8%AE%BE%E6%96%87%E5%AD%97 -default_prompt = { - "default": "如果用户之后想获取帮助,请你说“输入!help获取帮助”。", -} - -# 情景预设格式 -# 参考值:默认方式:normal | 完整情景:full_scenario -# 默认方式 
的格式为上述default_prompt中的内容,或prompts目录下的文件名 -# 完整情景方式 的格式为JSON,在scenario目录下的JSON文件中列出对话的每个回合,编写方法见scenario/default-template.json -# 编写方法请查看:https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E9%A2%84%E8%AE%BE%E6%96%87%E5%AD%97full_scenario%E6%A8%A1%E5%BC%8F -preset_mode = "normal" - -# 群内响应规则 -# 符合此消息的群内消息即使不包含at机器人也会响应 -# 支持消息前缀匹配及正则表达式匹配 -# 支持设置是否响应at消息、随机响应概率 -# 注意:由消息前缀(prefix)匹配的消息中将会删除此前缀,正则表达式(regexp)匹配的消息不会删除匹配的部分 -# 前缀匹配优先级高于正则表达式匹配 -# 正则表达式简明教程:https://www.runoob.com/regexp/regexp-tutorial.html -# -# 支持针对不同群设置不同的响应规则,例如: -# response_rules = { -# "default": { -# "at": True, -# "prefix": ["/ai", "!ai", "!ai", "ai"], -# "regexp": [], -# "random_rate": 0.0, -# }, -# "12345678": { -# "at": False, -# "prefix": ["/ai", "!ai", "!ai", "ai"], -# "regexp": [], -# "random_rate": 0.0, -# }, -# } -# -# 以上设置将会在群号为12345678的群中关闭at响应 -# 未单独设置的群将使用default规则 -response_rules = { - "default": { - "at": True, # 是否响应at机器人的消息 - "prefix": ["/ai", "!ai", "!ai", "ai"], - "regexp": [], # "为什么.*", "怎么?样.*", "怎么.*", "如何.*", "[Hh]ow to.*", "[Ww]hy not.*", "[Ww]hat is.*", ".*怎么办", ".*咋办" - "random_rate": 0.0, # 随机响应概率,0.0-1.0,0.0为不随机响应,1.0为响应所有消息, 仅在前几项判断不通过时生效 - }, -} - - -# 消息忽略规则 -# 适用于私聊及群聊 -# 符合此规则的消息将不会被响应 -# 支持消息前缀匹配及正则表达式匹配 -# 此设置优先级高于response_rules -# 用以过滤mirai等其他层级的命令 -# @see https://github.com/RockChinQ/QChatGPT/issues/165 -ignore_rules = { - "prefix": ["/"], - "regexp": [] -} - -# 是否检查收到的消息中是否包含敏感词 -# 若收到的消息无法通过下方指定的敏感词检查策略,则发送提示信息 -income_msg_check = False - -# 敏感词过滤开关,以同样数量的*代替敏感词回复 -# 请在sensitive.json中添加敏感词 -sensitive_word_filter = True - -# 是否启用百度云内容安全审核 -# 注册方式查看 https://cloud.baidu.com/doc/ANTIPORN/s/Wkhu9d5iy -baidu_check = False - -# 百度云API_KEY 24位英文数字字符串 -baidu_api_key = "" - -# 百度云SECRET_KEY 32位的英文数字字符串 -baidu_secret_key = "" - -# 不合规消息自定义返回 -inappropriate_message_tips = "[百度云]请珍惜机器人,当前返回内容不合规" - -# 启动时是否发送赞赏码 -# 仅当使用量已经超过2048字时发送 -encourage_sponsor_at_start = True - -# 每次向OpenAI接口发送对话记录上下文的字符数 -# 最大不超过(4096 - 
max_tokens)个字符,max_tokens为下方completion_api_params中的max_tokens -# 注意:较大的prompt_submit_length会导致OpenAI账户额度消耗更快 -prompt_submit_length = 3072 - -# 是否在token超限报错时自动重置会话 -# 可在tips.py中编辑提示语 -auto_reset = True - -# OpenAI补全API的参数 -# 请在下方填写模型,程序自动选择接口 -# 模型文档:https://platform.openai.com/docs/models -# 现已支持的模型有: -# -# ChatCompletions 接口: -# # GPT 4 系列 -# "gpt-4-1106-preview", -# "gpt-4-vision-preview", -# "gpt-4", -# "gpt-4-32k", -# "gpt-4-0613", -# "gpt-4-32k-0613", -# "gpt-4-0314", # legacy -# "gpt-4-32k-0314", # legacy -# # GPT 3.5 系列 -# "gpt-3.5-turbo-1106", -# "gpt-3.5-turbo", -# "gpt-3.5-turbo-16k", -# "gpt-3.5-turbo-0613", # legacy -# "gpt-3.5-turbo-16k-0613", # legacy -# "gpt-3.5-turbo-0301", # legacy -# -# Completions接口: -# "gpt-3.5-turbo-instruct", -# -# 具体请查看OpenAI的文档: https://beta.openai.com/docs/api-reference/completions/create -# 请将内容修改到config.py中,请勿修改config-template.py -# -# 支持通过 One API 接入多种模型,请在上方的openai_config中设置One API的代理地址, -# 并在此填写您要使用的模型名称,详细请参考:https://github.com/songquanpeng/one-api -# -# 支持的 One API 模型: -# "SparkDesk", -# "chatglm_pro", -# "chatglm_std", -# "chatglm_lite", -# "qwen-v1", -# "qwen-plus-v1", -# "ERNIE-Bot", -# "ERNIE-Bot-turbo", -# "gemini-pro", -completion_api_params = { - "model": "gpt-3.5-turbo", - "temperature": 0.9, # 数值越低得到的回答越理性,取值范围[0, 1] -} - -# OpenAI的Image API的参数 -# 具体请查看OpenAI的文档: https://platform.openai.com/docs/api-reference/images/create -image_api_params = { - "model": "dall-e-2", # 默认使用 dall-e-2 模型,也可以改为 dall-e-3 - # 图片尺寸 - # dall-e-2 模型支持 256x256, 512x512, 1024x1024 - # dall-e-3 模型支持 1024x1024, 1792x1024, 1024x1792 - "size": "256x256", -} - -# 跟踪函数调用 -# 为True时,在每次GPT进行Function Calling时都会输出发送一条回复给用户 -# 同时,一次提问内所有的Function Calling和普通回复消息都会单独发送给用户 -trace_function_calls = False - -# 群内回复消息时是否引用原消息 -quote_origin = False - -# 群内回复消息时是否at发送者 -at_sender = False - -# 回复绘图时是否包含图片描述 -include_image_description = True - -# 消息处理的超时时间,单位为秒 -process_message_timeout = 120 - -# 回复消息时是否显示[GPT]前缀 -show_prefix = False - -# 
回复前的强制延迟时间,降低机器人被腾讯风控概率 -# *此机制对命令和消息、私聊及群聊均生效 -# 每次处理时从以下的范围取一个随机秒数, -# 当此次消息处理时间低于此秒数时,将会强制延迟至此秒数 -# 例如:[1.5, 3],则每次处理时会随机取一个1.5-3秒的随机数,若处理时间低于此随机数,则强制延迟至此随机秒数 -# 若您不需要此功能,请将force_delay_range设置为[0, 0] -force_delay_range = [0, 0] - -# 应用长消息处理策略的阈值 -# 当回复消息长度超过此值时,将使用长消息处理策略 -blob_message_threshold = 256 - -# 长消息处理策略 -# - "image": 将长消息转换为图片发送 -# - "forward": 将长消息转换为转发消息组件发送 -blob_message_strategy = "forward" - -# 允许等待 -# 同一会话内,是否等待上一条消息处理完成后再处理下一条消息 -# 若设置为False,若上一条未处理完时收到了新消息,将会丢弃新消息 -# 丢弃消息时的提示信息可以在tips.py中修改 -wait_last_done = True - -# 文字转图片时使用的字体文件路径 -# 当策略为"image"时生效 -# 若在Windows系统下,程序会自动使用Windows自带的微软雅黑字体 -# 若未填写或不存在且不是Windows,将禁用文字转图片功能,改为使用转发消息组件 -font_path = "" - -# 消息处理超时重试次数 -retry_times = 3 - -# 消息处理出错时是否向用户隐藏错误详细信息 -# 设置为True时,仅向管理员发送错误详细信息 -# 设置为False时,向用户及管理员发送错误详细信息 -hide_exce_info_to_user = False - -# 每个会话的过期时间,单位为秒 -# 默认值20分钟 -session_expire_time = 1200 - -# 会话限速 -# 单会话内每分钟可进行的对话次数 -# 若不需要限速,可以设置为一个很大的值 -# 默认值60次,基本上不会触发限速 -# -# 若要设置针对某特定群的限速,请使用如下格式: -# { -# "group_<群号>": 60, -# "default": 60, -# } -# 若要设置针对某特定用户私聊的限速,请使用如下格式: -# { -# "person_<用户QQ>": 60, -# "default": 60, -# } -# 同时设置多个群和私聊的限速,示例: -# { -# "group_12345678": 60, -# "group_87654321": 60, -# "person_234567890": 60, -# "person_345678901": 60, -# "default": 60, -# } -# -# 注意: 未指定的都使用default的限速值,default不可删除 -rate_limitation = { - "default": 60, -} - -# 会话限速策略 -# - "wait": 每次对话获取到回复时,等待一定时间再发送回复,保证其不会超过限速均值 -# - "drop": 此分钟内,若对话次数超过限速次数,则丢弃之后的对话,每自然分钟重置 -rate_limit_strategy = "drop" - -# 是否在启动时进行依赖库更新 -upgrade_dependencies = False - -# 是否上报统计信息 -# 用于统计机器人的使用情况,数据不公开,不会收集任何敏感信息。 -# 仅实例识别UUID、上报时间、字数使用量、绘图使用量、插件使用情况、用户信息,其他信息不会上报 -report_usage = True - -# 日志级别 -logging_level = logging.INFO diff --git a/docker-compose.yaml b/docker-compose.yaml index f2dc6887..bd6067cd 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -4,15 +4,7 @@ services: qchatgpt: image: rockchin/qchatgpt:latest volumes: - - ./config.py:/QChatGPT/config.py - - ./banlist.py:/QChatGPT/banlist.py - - 
./cmdpriv.json:/QChatGPT/cmdpriv.json - - ./sensitive.json:/QChatGPT/sensitive.json - - ./tips.py:/QChatGPT/tips.py - # 目录映射 - - ./plugins:/QChatGPT/plugins - - ./scenario:/QChatGPT/scenario - - ./temp:/QChatGPT/temp - - ./logs:/QChatGPT/logs - restart: always + - ./data:/app/data + - ./plugins:/app/plugins + restart: on-failure # 根据具体环境配置网络 \ No newline at end of file diff --git a/main.py b/main.py index 99b89035..4f3373df 100644 --- a/main.py +++ b/main.py @@ -1,496 +1,18 @@ -import importlib -import json -import os -import shutil -import threading -import time - -import logging -import sys -import traceback import asyncio -sys.path.append(".") - - -def check_file(): - # 检查是否有banlist.py,如果没有就把banlist-template.py复制一份 - if not os.path.exists('banlist.py'): - shutil.copy('res/templates/banlist-template.py', 'banlist.py') - - # 检查是否有sensitive.json - if not os.path.exists("sensitive.json"): - shutil.copy("res/templates/sensitive-template.json", "sensitive.json") - - # 检查是否有scenario/default.json - if not os.path.exists("scenario/default.json"): - shutil.copy("scenario/default-template.json", "scenario/default.json") - - # 检查cmdpriv.json - if not os.path.exists("cmdpriv.json"): - shutil.copy("res/templates/cmdpriv-template.json", "cmdpriv.json") - - # 检查tips_custom - if not os.path.exists("tips.py"): - shutil.copy("tips-custom-template.py", "tips.py") - - # 检查temp目录 - if not os.path.exists("temp/"): - os.mkdir("temp/") - - # 检查并创建plugins、prompts目录 - check_path = ["plugins", "prompts"] - for path in check_path: - if not os.path.exists(path): - os.mkdir(path) - - # 配置文件存在性校验 - if not os.path.exists('config.py'): - shutil.copy('config-template.py', 'config.py') - print('请先在config.py中填写配置') - sys.exit(0) - - -# 初始化相关文件 -check_file() - -from pkg.utils.log import init_runtime_log_file, reset_logging -from pkg.config import manager as config_mgr -from pkg.config.impls import pymodule as pymodule_cfg - - -try: - import colorlog -except ImportError: - # 尝试安装 - import 
pkg.utils.pkgmgr as pkgmgr - try: - pkgmgr.install_requirements("requirements.txt") - import colorlog - except ImportError: - print("依赖不满足,请查看 https://github.com/RockChinQ/qcg-installer/issues/15") - sys.exit(1) -import colorlog - -import requests -import websockets.exceptions -from urllib3.exceptions import InsecureRequestWarning -import pkg.utils.context - - -# 是否使用override.json覆盖配置 -# 仅在启动时提供 --override 或 -r 参数时生效 -use_override = False - - -def init_db(): - import pkg.database.manager - database = pkg.database.manager.DatabaseManager() - - database.initialize_database() - - -def ensure_dependencies(): - import pkg.utils.pkgmgr as pkgmgr - pkgmgr.run_pip(["install", "openai", "Pillow", "nakuru-project-idk", "CallingGPT", "tiktoken", "--upgrade", - "-i", "https://pypi.tuna.tsinghua.edu.cn/simple", - "--trusted-host", "pypi.tuna.tsinghua.edu.cn"]) - - -known_exception_caught = False - - -def override_config_manager(): - config = pkg.utils.context.get_config_manager().data - - if os.path.exists("override.json") and use_override: - override_json = json.load(open("override.json", "r", encoding="utf-8")) - overrided = [] - for key in override_json: - if key in config: - config[key] = override_json[key] - # logging.info("覆写配置[{}]为[{}]".format(key, override_json[key])) - overrided.append(key) - else: - logging.error("无法覆写配置[{}]为[{}],该配置不存在,请检查override.json是否正确".format(key, override_json[key])) - if len(overrided) > 0: - logging.info("已根据override.json覆写配置项: {}".format(", ".join(overrided))) - - -def complete_tips(): - """根据tips-custom-template模块补全tips模块的属性""" - non_exist_keys = [] - - is_integrity = True - logging.debug("检查tips模块完整性.") - tips_template = importlib.import_module('tips-custom-template') - tips = importlib.import_module('tips') - for key in dir(tips_template): - if not key.startswith("__") and not hasattr(tips, key): - setattr(tips, key, getattr(tips_template, key)) - # logging.warning("[{}]不存在".format(key)) - non_exist_keys.append(key) - is_integrity = False 
- - if not is_integrity: - logging.warning("以下提示语字段不存在: {}".format(", ".join(non_exist_keys))) - logging.warning("tips模块不完整,您可以依据tips-custom-template.py检查tips.py") - logging.warning("以上配置已被设为默认值,将在3秒后继续启动... ") - time.sleep(3) - - -async def start_process(first_time_init=False): - """启动流程,reload之后会被执行""" - - global known_exception_caught - import pkg.utils.context - - # 计算host和instance标识符 - import pkg.audit.identifier - pkg.audit.identifier.init() - - # 加载配置 - cfg_inst: pymodule_cfg.PythonModuleConfigFile = pymodule_cfg.PythonModuleConfigFile( - 'config.py', - 'config-template.py' - ) - await config_mgr.ConfigManager(cfg_inst).load_config() - - override_config_manager() - - # 检查tips模块 - complete_tips() - - cfg = pkg.utils.context.get_config_manager().data - - # 更新openai库到最新版本 - if 'upgrade_dependencies' not in cfg or cfg['upgrade_dependencies']: - print("正在更新依赖库,请等待...") - if 'upgrade_dependencies' not in cfg: - print("这个操作不是必须的,如果不想更新,请在config.py中添加upgrade_dependencies=False") - else: - print("这个操作不是必须的,如果不想更新,请在config.py中将upgrade_dependencies设置为False") - try: - ensure_dependencies() - except Exception as e: - print("更新openai库失败:{}, 请忽略或自行更新".format(e)) - - known_exception_caught = False - try: - try: - - sh = reset_logging() - pkg.utils.context.context['logger_handler'] = sh - - # 初始化文字转图片 - from pkg.utils import text2img - text2img.initialize() - - # 检查是否设置了管理员 - if cfg['admin_qq'] == 0: - # logging.warning("未设置管理员QQ,管理员权限命令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段") - while True: - try: - cfg['admin_qq'] = int(input("未设置管理员QQ,管理员权限命令及运行告警将无法使用,请输入管理员QQ号: ")) - # 写入到文件 - - # 读取文件 - config_file_str = "" - with open("config.py", "r", encoding="utf-8") as f: - config_file_str = f.read() - # 替换 - config_file_str = config_file_str.replace("admin_qq = 0", "admin_qq = " + str(cfg['admin_qq'])) - # 写入 - with open("config.py", "w", encoding="utf-8") as f: - f.write(config_file_str) - - print("管理员QQ已设置,如需修改请修改config.py中的admin_qq字段") - time.sleep(4) - break - except 
ValueError: - print("请输入数字") - - # 初始化中央服务器 API 交互实例 - from pkg.utils.center import apigroup - from pkg.utils.center import v2 as center_v2 - - center_v2_api = center_v2.V2CenterAPI( - basic_info={ - "host_id": pkg.audit.identifier.identifier['host_id'], - "instance_id": pkg.audit.identifier.identifier['instance_id'], - "semantic_version": pkg.utils.updater.get_current_tag(), - "platform": sys.platform, - }, - runtime_info={ - "admin_id": "{}".format(cfg['admin_qq']), - "msg_source": cfg['msg_source_adapter'], - } - ) - pkg.utils.context.set_center_v2_api(center_v2_api) - - import pkg.openai.manager - import pkg.database.manager - import pkg.openai.session - import pkg.qqbot.manager - import pkg.openai.dprompt - import pkg.qqbot.cmds.aamgr - - try: - pkg.openai.dprompt.register_all() - pkg.qqbot.cmds.aamgr.register_all() - pkg.qqbot.cmds.aamgr.apply_privileges() - except Exception as e: - logging.error(e) - traceback.print_exc() - - # 配置OpenAI proxy - import openai - openai.proxies = None # 先重置,因为重载后可能需要清除proxy - if "http_proxy" in cfg['openai_config'] and cfg['openai_config']["http_proxy"] is not None: - openai.proxies = { - "http": cfg['openai_config']["http_proxy"], - "https": cfg['openai_config']["http_proxy"] - } - - # 配置openai api_base - if "reverse_proxy" in cfg['openai_config'] and cfg['openai_config']["reverse_proxy"] is not None: - logging.debug("设置反向代理: "+cfg['openai_config']['reverse_proxy']) - openai.base_url = cfg['openai_config']["reverse_proxy"] - - # 主启动流程 - database = pkg.database.manager.DatabaseManager() - - database.initialize_database() - - openai_interact = pkg.openai.manager.OpenAIInteract(cfg['openai_config']['api_key']) - - # 加载所有未超时的session - pkg.openai.session.load_sessions() - - # 初始化qq机器人 - qqbot = pkg.qqbot.manager.QQBotManager(first_time_init=first_time_init) - - # 加载插件 - import pkg.plugin.host - pkg.plugin.host.load_plugins() - - pkg.plugin.host.initialize_plugins() - - if first_time_init: # 不是热重载之后的启动,则启动新的bot线程 - - import 
mirai.exceptions - - def run_bot_wrapper(): - global known_exception_caught - try: - logging.debug("使用账号: {}".format(qqbot.bot_account_id)) - qqbot.adapter.run_sync() - except TypeError as e: - if str(e).__contains__("argument 'debug'"): - logging.error( - "连接bot失败:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/82".format(e)) - known_exception_caught = True - elif str(e).__contains__("As of 3.10, the *loop*"): - logging.error( - "Websockets版本过低:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/5".format(e)) - known_exception_caught = True - - except websockets.exceptions.InvalidStatus as e: - logging.error( - "mirai-api-http端口无法使用:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/22".format( - e)) - known_exception_caught = True - except mirai.exceptions.NetworkError as e: - logging.error("连接mirai-api-http失败:{}, 请检查是否已按照文档启动mirai".format(e)) - known_exception_caught = True - except Exception as e: - if str(e).__contains__("404"): - logging.error( - "mirai-api-http端口无法使用:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/22".format( - e)) - known_exception_caught = True - elif str(e).__contains__("signal only works in main thread"): - logging.error( - "hypercorn异常:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/86".format( - e)) - known_exception_caught = True - elif str(e).__contains__("did not receive a valid HTTP"): - logging.error( - "mirai-api-http端口无法使用:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/22".format( - e)) - else: - import traceback - traceback.print_exc() - logging.error( - "捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/QChatGPT/issues 查找或提issue".format(e)) - known_exception_caught = True - raise e - finally: - time.sleep(12) - threading.Thread( - target=run_bot_wrapper - ).start() - except Exception as e: - traceback.print_exc() - if isinstance(e, KeyboardInterrupt): - logging.info("程序被用户中止") - sys.exit(0) - elif isinstance(e, SyntaxError): - logging.error("配置文件存在语法错误,请检查配置文件:\n1. 是否存在中文符号\n2. 
是否已按照文件中的说明填写正确") - sys.exit(1) - else: - logging.error("初始化失败:{}".format(e)) - sys.exit(1) - finally: - # 判断若是Windows,输出选择模式可能会暂停程序的警告 - if os.name == 'nt': - time.sleep(2) - logging.info("您正在使用Windows系统,若命令行窗口处于“选择”模式,程序可能会被暂停,此时请右键点击窗口空白区域使其取消选择模式。") - - time.sleep(12) - - if first_time_init: - if not known_exception_caught: - if cfg['msg_source_adapter'] == "yirimirai": - logging.info("QQ: {}, MAH: {}".format(cfg['mirai_http_api_config']['qq'], cfg['mirai_http_api_config']['host']+":"+str(cfg['mirai_http_api_config']['port']))) - logging.critical('程序启动完成,如长时间未显示 "成功登录到账号xxxxx" ,并且不回复消息,解决办法(请勿到群里问): ' - 'https://github.com/RockChinQ/QChatGPT/issues/37') - elif cfg['msg_source_adapter'] == 'nakuru': - logging.info("host: {}, port: {}, http_port: {}".format(cfg['nakuru_config']['host'], cfg['nakuru_config']['port'], cfg['nakuru_config']['http_port'])) - logging.critical('程序启动完成,如长时间未显示 "Protocol: connected" ,并且不回复消息,请检查config.py中的nakuru_config是否正确') - else: - sys.exit(1) - else: - logging.info('热重载完成') - - # 发送赞赏码 - if cfg['encourage_sponsor_at_start'] \ - and pkg.utils.context.get_openai_manager().audit_mgr.get_total_text_length() >= 2048: - - logging.info("发送赞赏码") - from mirai import MessageChain, Plain, Image - import pkg.utils.constants - message_chain = MessageChain([ - Plain("自2022年12月初以来,开发者已经花费了大量时间和精力来维护本项目,如果您觉得本项目对您有帮助,欢迎赞赏开发者," - "以支持项目稳定运行😘"), - Image(base64=pkg.utils.constants.alipay_qr_b64), - Image(base64=pkg.utils.constants.wechat_qr_b64), - Plain("BTC: 3N4Azee63vbBB9boGv9Rjf4N5SocMe5eCq\nXMR: 89LS21EKQuDGkyQoe2nDupiuWXk4TVD6FALvSKv5owfmeJEPFpHeMsZLYtLiJ6GxLrhsRe5gMs6MyMSDn4GNQAse2Mae4KE\n\n"), - Plain("(本消息仅在启动时发送至管理员,如果您不想再看到此消息,请在config.py中将encourage_sponsor_at_start设置为False)") - ]) - pkg.utils.context.get_qqbot_manager().notify_admin_message_chain(message_chain) - - time.sleep(5) - import pkg.utils.updater - try: - if pkg.utils.updater.is_new_version_available(): - logging.info("新版本可用,请发送 !update 
进行自动更新\n更新日志:\n{}".format("\n".join(pkg.utils.updater.get_rls_notes()))) - else: - # logging.info("当前已是最新版本") - pass - - except Exception as e: - logging.warning("检查更新失败:{}".format(e)) - - try: - import pkg.utils.announcement as announcement - new_announcement = announcement.fetch_new() - if len(new_announcement) > 0: - for announcement in new_announcement: - logging.critical("[公告]<{}> {}".format(announcement['time'], announcement['content'])) - - # 发送统计数据 - pkg.utils.context.get_center_v2_api().main.post_announcement_showed( - [announcement['id'] for announcement in new_announcement] - ) - - except Exception as e: - logging.warning("获取公告失败:{}".format(e)) - - return qqbot - -def stop(): - import pkg.qqbot.manager - import pkg.openai.session - try: - import pkg.plugin.host - pkg.plugin.host.unload_plugins() - - qqbot_inst = pkg.utils.context.get_qqbot_manager() - assert isinstance(qqbot_inst, pkg.qqbot.manager.QQBotManager) - - for session in pkg.openai.session.sessions: - logging.info('持久化session: %s', session) - pkg.openai.session.sessions[session].persistence() - pkg.utils.context.get_database_manager().close() - except Exception as e: - if not isinstance(e, KeyboardInterrupt): - raise e - - -def main(): - global use_override - # 检查是否携带了 --override 或 -r 参数 - if '--override' in sys.argv or '-r' in sys.argv: - use_override = True - - # 初始化logging - init_runtime_log_file() - pkg.utils.context.context['logger_handler'] = reset_logging() - - # 配置线程池 - from pkg.utils import ThreadCtl - thread_ctl = ThreadCtl( - sys_pool_num=8, - admin_pool_num=4, - user_pool_num=8 - ) - # 存进上下文 - pkg.utils.context.set_thread_ctl(thread_ctl) - - # 启动指令处理 - if len(sys.argv) > 1 and sys.argv[1] == 'init_db': - init_db() - sys.exit(0) - - elif len(sys.argv) > 1 and sys.argv[1] == 'update': - print("正在进行程序更新...") - import pkg.utils.updater as updater - updater.update_all(cli=True) - sys.exit(0) - - # 关闭urllib的http警告 - requests.packages.urllib3.disable_warnings(InsecureRequestWarning) - - 
def run_wrapper(): - asyncio.run(start_process(True)) - - pkg.utils.context.get_thread_ctl().submit_sys_task( - run_wrapper - ) - - # 主线程循环 - while True: - try: - time.sleep(0xFF) - except: - stop() - pkg.utils.context.get_thread_ctl().shutdown() - - launch_args = sys.argv.copy() - if "--cov-report" not in launch_args: - import platform - if platform.system() == 'Windows': - cmd = "taskkill /F /PID {}".format(os.getpid()) - elif platform.system() in ['Linux', 'Darwin']: - cmd = "kill -9 {}".format(os.getpid()) - os.system(cmd) - else: - print("正常退出以生成覆盖率报告") - sys.exit(0) +asciiart = r""" + ___ ___ _ _ ___ ___ _____ + / _ \ / __| |_ __ _| |_ / __| _ \_ _| +| (_) | (__| ' \/ _` | _| (_ | _/ | | + \__\_\\___|_||_\__,_|\__|\___|_| |_| +⭐️开源地址: https://github.com/RockChinQ/QChatGPT +📖文档地址: https://q.rkcn.top +""" if __name__ == '__main__': - main() + print(asciiart) + from pkg.core import boot + asyncio.run(boot.main()) diff --git a/override-all.json b/override-all.json deleted file mode 100644 index ae3b7e10..00000000 --- a/override-all.json +++ /dev/null @@ -1,90 +0,0 @@ -{ - "comment": "这是override.json支持的字段全集, 关于override.json机制, 请查看https://github.com/RockChinQ/QChatGPT/pull/271", - "msg_source_adapter": "yirimirai", - "mirai_http_api_config": { - "adapter": "WebSocketAdapter", - "host": "localhost", - "port": 8080, - "verifyKey": "yirimirai", - "qq": 1234567890 - }, - "nakuru_config": { - "host": "localhost", - "port": 6700, - "http_port": 5700, - "token": "" - }, - "openai_config": { - "api_key": { - "default": "openai_api_key" - }, - "http_proxy": null, - "reverse_proxy": null - }, - "switch_strategy": "active", - "admin_qq": 0, - "default_prompt": { - "default": "如果用户之后想获取帮助,请你说“输入!help获取帮助”。" - }, - "preset_mode": "normal", - "response_rules": { - "default": { - "at": true, - "prefix": [ - "/ai", - "!ai", - "!ai", - "ai" - ], - "regexp": [], - "random_rate": 0.0 - } - }, - "ignore_rules": { - "prefix": [ - "/" - ], - "regexp": [] - }, - "income_msg_check": 
false, - "sensitive_word_filter": true, - "baidu_check": false, - "baidu_api_key": "", - "baidu_secret_key": "", - "inappropriate_message_tips": "[百度云]请珍惜机器人,当前返回内容不合规", - "encourage_sponsor_at_start": true, - "prompt_submit_length": 3072, - "auto_reset": true, - "completion_api_params": { - "model": "gpt-3.5-turbo", - "temperature": 0.9 - }, - "image_api_params": { - "model": "dall-e-2", - "size": "256x256" - }, - "trace_function_calls": false, - "quote_origin": false, - "at_sender": false, - "include_image_description": true, - "process_message_timeout": 120, - "show_prefix": false, - "force_delay_range": [ - 0, - 0 - ], - "blob_message_threshold": 256, - "blob_message_strategy": "forward", - "wait_last_done": true, - "font_path": "", - "retry_times": 3, - "hide_exce_info_to_user": false, - "session_expire_time": 1200, - "rate_limitation": { - "default": 60 - }, - "rate_limit_strategy": "drop", - "upgrade_dependencies": false, - "report_usage": true, - "logging_level": 20 -} \ No newline at end of file diff --git a/pkg/openai/api/__init__.py b/pkg/audit/center/__init__.py similarity index 100% rename from pkg/openai/api/__init__.py rename to pkg/audit/center/__init__.py diff --git a/pkg/audit/center/apigroup.py b/pkg/audit/center/apigroup.py new file mode 100644 index 00000000..10b6d8dd --- /dev/null +++ b/pkg/audit/center/apigroup.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import abc +import uuid +import json +import logging +import asyncio + +import aiohttp +import requests + +from ...core import app + + +class APIGroup(metaclass=abc.ABCMeta): + """API 组抽象类""" + _basic_info: dict = None + _runtime_info: dict = None + + prefix = None + + ap: app.Application + + def __init__(self, prefix: str, ap: app.Application): + self.prefix = prefix + self.ap = ap + + async def _do( + self, + method: str, + path: str, + data: dict = None, + params: dict = None, + headers: dict = {}, + **kwargs + ): + self._runtime_info['account_id'] = "-1" + + url = 
self.prefix + path + data = json.dumps(data) + headers['Content-Type'] = 'application/json' + + try: + async with aiohttp.ClientSession() as session: + async with session.request( + method, + url, + data=data, + params=params, + headers=headers, + **kwargs + ) as resp: + self.ap.logger.debug("data: %s", data) + self.ap.logger.debug("ret: %s", await resp.text()) + + except Exception as e: + self.ap.logger.debug(f'上报失败: {e}') + + async def do( + self, + method: str, + path: str, + data: dict = None, + params: dict = None, + headers: dict = {}, + **kwargs + ) -> asyncio.Task: + """执行请求""" + asyncio.create_task(self._do(method, path, data, params, headers, **kwargs)) + + def gen_rid( + self + ): + """生成一个请求 ID""" + return str(uuid.uuid4()) + + def basic_info( + self + ): + """获取基本信息""" + basic_info = APIGroup._basic_info.copy() + basic_info['rid'] = self.gen_rid() + return basic_info + + def runtime_info( + self + ): + """获取运行时信息""" + return APIGroup._runtime_info diff --git a/pkg/qqbot/__init__.py b/pkg/audit/center/groups/__init__.py similarity index 100% rename from pkg/qqbot/__init__.py rename to pkg/audit/center/groups/__init__.py diff --git a/pkg/utils/center/groups/main.py b/pkg/audit/center/groups/main.py similarity index 70% rename from pkg/utils/center/groups/main.py rename to pkg/audit/center/groups/main.py index a4e5414a..3a31a65b 100644 --- a/pkg/utils/center/groups/main.py +++ b/pkg/audit/center/groups/main.py @@ -1,22 +1,22 @@ from __future__ import annotations from .. import apigroup -from ... 
import context +from ....core import app class V2MainDataAPI(apigroup.APIGroup): """主程序相关 数据API""" - def __init__(self, prefix: str): - super().__init__(prefix+"/main") + def __init__(self, prefix: str, ap: app.Application): + self.ap = ap + super().__init__(prefix+"/main", ap) - def do(self, *args, **kwargs): - config = context.get_config_manager().data - if not config['report_usage']: + async def do(self, *args, **kwargs): + if not self.ap.system_cfg.data['report-usage']: return None - return super().do(*args, **kwargs) + return await super().do(*args, **kwargs) - def post_update_record( + async def post_update_record( self, spent_seconds: int, infer_reason: str, @@ -24,7 +24,7 @@ def post_update_record( new_version: str, ): """提交更新记录""" - return self.do( + return await self.do( "POST", "/update", data={ @@ -38,12 +38,12 @@ def post_update_record( } ) - def post_announcement_showed( + async def post_announcement_showed( self, ids: list[int], ): """提交公告已阅""" - return self.do( + return await self.do( "POST", "/announcement", data={ diff --git a/pkg/utils/center/groups/plugin.py b/pkg/audit/center/groups/plugin.py similarity index 69% rename from pkg/utils/center/groups/plugin.py rename to pkg/audit/center/groups/plugin.py index c7881b9d..627b116c 100644 --- a/pkg/utils/center/groups/plugin.py +++ b/pkg/audit/center/groups/plugin.py @@ -1,27 +1,27 @@ from __future__ import annotations +from ....core import app from .. import apigroup -from ... 
import context class V2PluginDataAPI(apigroup.APIGroup): """插件数据相关 API""" - def __init__(self, prefix: str): - super().__init__(prefix+"/plugin") + def __init__(self, prefix: str, ap: app.Application): + self.ap = ap + super().__init__(prefix+"/plugin", ap) - def do(self, *args, **kwargs): - config = context.get_config_manager().data - if not config['report_usage']: + async def do(self, *args, **kwargs): + if not self.ap.system_cfg.data['report-usage']: return None - return super().do(*args, **kwargs) + return await super().do(*args, **kwargs) - def post_install_record( + async def post_install_record( self, plugin: dict ): """提交插件安装记录""" - return self.do( + return await self.do( "POST", "/install", data={ @@ -30,12 +30,12 @@ def post_install_record( } ) - def post_remove_record( + async def post_remove_record( self, plugin: dict ): """提交插件卸载记录""" - return self.do( + return await self.do( "POST", "/remove", data={ @@ -44,14 +44,14 @@ def post_remove_record( } ) - def post_update_record( + async def post_update_record( self, plugin: dict, old_version: str, new_version: str, ): """提交插件更新记录""" - return self.do( + return await self.do( "POST", "/update", data={ diff --git a/pkg/utils/center/groups/usage.py b/pkg/audit/center/groups/usage.py similarity index 79% rename from pkg/utils/center/groups/usage.py rename to pkg/audit/center/groups/usage.py index f966add4..8a8bdf04 100644 --- a/pkg/utils/center/groups/usage.py +++ b/pkg/audit/center/groups/usage.py @@ -1,22 +1,22 @@ from __future__ import annotations from .. import apigroup -from ... 
import context +from ....core import app class V2UsageDataAPI(apigroup.APIGroup): """使用量数据相关 API""" - def __init__(self, prefix: str): - super().__init__(prefix+"/usage") + def __init__(self, prefix: str, ap: app.Application): + self.ap = ap + super().__init__(prefix+"/usage", ap) - def do(self, *args, **kwargs): - config = context.get_config_manager().data - if not config['report_usage']: + async def do(self, *args, **kwargs): + if not self.ap.system_cfg.data['report-usage']: return None - return super().do(*args, **kwargs) - - def post_query_record( + return await super().do(*args, **kwargs) + + async def post_query_record( self, session_type: str, session_id: str, @@ -27,7 +27,7 @@ def post_query_record( retry_times: int, ): """提交请求记录""" - return self.do( + return await self.do( "POST", "/query", data={ @@ -47,13 +47,13 @@ def post_query_record( } ) - def post_event_record( + async def post_event_record( self, plugins: list[dict], event_name: str, ): """提交事件触发记录""" - return self.do( + return await self.do( "POST", "/event", data={ @@ -66,14 +66,14 @@ def post_event_record( } ) - def post_function_record( + async def post_function_record( self, plugin: dict, function_name: str, function_description: str, ): """提交内容函数使用记录""" - return self.do( + return await self.do( "POST", "/function", data={ diff --git a/pkg/utils/center/v2.py b/pkg/audit/center/v2.py similarity index 70% rename from pkg/utils/center/v2.py rename to pkg/audit/center/v2.py index b1c0a3e6..70d51384 100644 --- a/pkg/utils/center/v2.py +++ b/pkg/audit/center/v2.py @@ -6,6 +6,7 @@ from .groups import main from .groups import usage from .groups import plugin +from ...core import app BACKEND_URL = "https://api.qchatgpt.rockchin.top/api/v2" @@ -22,7 +23,7 @@ class V2CenterAPI: plugin: plugin.V2PluginDataAPI = None """插件 API 组""" - def __init__(self, basic_info: dict = None, runtime_info: dict = None): + def __init__(self, ap: app.Application, basic_info: dict = None, runtime_info: dict = None): 
"""初始化""" logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info) @@ -30,6 +31,7 @@ def __init__(self, basic_info: dict = None, runtime_info: dict = None): apigroup.APIGroup._basic_info = basic_info apigroup.APIGroup._runtime_info = runtime_info - self.main = main.V2MainDataAPI(BACKEND_URL) - self.usage = usage.V2UsageDataAPI(BACKEND_URL) - self.plugin = plugin.V2PluginDataAPI(BACKEND_URL) + self.main = main.V2MainDataAPI(BACKEND_URL, ap) + self.usage = usage.V2UsageDataAPI(BACKEND_URL, ap) + self.plugin = plugin.V2PluginDataAPI(BACKEND_URL, ap) + diff --git a/pkg/audit/gatherer.py b/pkg/audit/gatherer.py deleted file mode 100644 index 01bb7f2d..00000000 --- a/pkg/audit/gatherer.py +++ /dev/null @@ -1,114 +0,0 @@ -""" -使用量统计以及数据上报功能实现 -""" - -import hashlib -import json -import logging -import threading - -import requests - -from ..utils import context -from ..utils import updater - - -class DataGatherer: - """数据收集器""" - - usage = {} - """各api-key的使用量 - - 以key值md5为key,{ - "text": { - "gpt-3.5-turbo": 文字量:int, - }, - "image": { - "256x256": 图片数量:int, - } - }为值的字典""" - - version_str = "undetermined" - - def __init__(self): - self.load_from_db() - try: - self.version_str = updater.get_current_tag() # 从updater模块获取版本号 - except: - pass - - def get_usage(self, key_md5): - return self.usage[key_md5] if key_md5 in self.usage else {} - - def report_text_model_usage(self, model, total_tokens): - """调用方报告文字模型请求文字使用量""" - - key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存 - - if key_md5 not in self.usage: - self.usage[key_md5] = {} - - if "text" not in self.usage[key_md5]: - self.usage[key_md5]["text"] = {} - - if model not in self.usage[key_md5]["text"]: - self.usage[key_md5]["text"][model] = 0 - - length = total_tokens - self.usage[key_md5]["text"][model] += length - self.dump_to_db() - - def report_image_model_usage(self, size): - """调用方报告图片模型请求图片使用量""" - - key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5() 
- - if key_md5 not in self.usage: - self.usage[key_md5] = {} - - if "image" not in self.usage[key_md5]: - self.usage[key_md5]["image"] = {} - - if size not in self.usage[key_md5]["image"]: - self.usage[key_md5]["image"][size] = 0 - - self.usage[key_md5]["image"][size] += 1 - self.dump_to_db() - - def get_text_length_of_key(self, key): - """获取指定api-key (明文) 的文字总使用量(本地记录)""" - key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest() - if key_md5 not in self.usage: - return 0 - if "text" not in self.usage[key_md5]: - return 0 - # 遍历其中所有模型,求和 - return sum(self.usage[key_md5]["text"].values()) - - def get_image_count_of_key(self, key): - """获取指定api-key (明文) 的图片总使用量(本地记录)""" - - key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest() - if key_md5 not in self.usage: - return 0 - if "image" not in self.usage[key_md5]: - return 0 - # 遍历其中所有模型,求和 - return sum(self.usage[key_md5]["image"].values()) - - def get_total_text_length(self): - """获取所有api-key的文字总使用量(本地记录)""" - total = 0 - for key in self.usage: - if "text" not in self.usage[key]: - continue - total += sum(self.usage[key]["text"].values()) - return total - - def dump_to_db(self): - context.get_database_manager().dump_usage_json(self.usage) - - def load_from_db(self): - json_str = context.get_database_manager().load_usage_json() - if json_str is not None: - self.usage = json.loads(json_str) diff --git a/pkg/qqbot/cmds/__init__.py b/pkg/command/__init__.py similarity index 100% rename from pkg/qqbot/cmds/__init__.py rename to pkg/command/__init__.py diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py new file mode 100644 index 00000000..1e622b97 --- /dev/null +++ b/pkg/command/cmdmgr.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import typing + +from ..core import app, entities as core_entities +from ..provider import entities as llm_entities +from . 
import entities, operator, errors +from ..config import manager as cfg_mgr + +from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update + + +class CommandManager: + """命令管理器 + """ + + ap: app.Application + + cmd_list: list[operator.CommandOperator] + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + + # 设置各个类的路径 + def set_path(cls: operator.CommandOperator, ancestors: list[str]): + cls.path = '.'.join(ancestors + [cls.name]) + for op in operator.preregistered_operators: + if op.parent_class == cls: + set_path(op, ancestors + [cls.name]) + + for cls in operator.preregistered_operators: + if cls.parent_class is None: + set_path(cls, []) + + # 应用命令权限配置 + for cls in operator.preregistered_operators: + if cls.path in self.ap.command_cfg.data['privilege']: + cls.lowest_privilege = self.ap.command_cfg.data['privilege'][cls.path] + + # 实例化所有类 + self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators] + + # 设置所有类的子节点 + for cmd in self.cmd_list: + cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__] + + # 初始化所有类 + for cmd in self.cmd_list: + await cmd.initialize() + + async def _execute( + self, + context: entities.ExecuteContext, + operator_list: list[operator.CommandOperator], + operator: operator.CommandOperator = None + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行命令 + """ + + found = False + if len(context.crt_params) > 0: + for oper in operator_list: + if (context.crt_params[0] == oper.name \ + or context.crt_params[0] in oper.alias) \ + and (oper.parent_class is None or oper.parent_class == operator.__class__): + found = True + + context.crt_command = context.crt_params[0] + context.crt_params = context.crt_params[1:] + + async for ret in self._execute( + context, + oper.children, + oper + ): + yield ret + break + + if not found: + if operator is None: + yield 
entities.CommandReturn( + error=errors.CommandNotFoundError(context.crt_params[0]) + ) + else: + if operator.lowest_privilege > context.privilege: + yield entities.CommandReturn( + error=errors.CommandPrivilegeError(operator.name) + ) + else: + async for ret in operator.execute(context): + yield ret + + + async def execute( + self, + command_text: str, + query: core_entities.Query, + session: core_entities.Session + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行命令 + """ + + privilege = 1 + + if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['admin-sessions']: + privilege = 2 + + ctx = entities.ExecuteContext( + query=query, + session=session, + command_text=command_text, + command='', + crt_command='', + params=command_text.split(' '), + crt_params=command_text.split(' '), + privilege=privilege + ) + + async for ret in self._execute( + ctx, + self.cmd_list + ): + yield ret diff --git a/pkg/command/entities.py b/pkg/command/entities.py new file mode 100644 index 00000000..f5f8bef5 --- /dev/null +++ b/pkg/command/entities.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import typing + +import pydantic +import mirai + +from ..core import app, entities as core_entities +from . 
import errors, operator + + +class CommandReturn(pydantic.BaseModel): + + text: typing.Optional[str] + """文本 + """ + + image: typing.Optional[mirai.Image] + + error: typing.Optional[errors.CommandError]= None + + class Config: + arbitrary_types_allowed = True + + +class ExecuteContext(pydantic.BaseModel): + + query: core_entities.Query + + session: core_entities.Session + + command_text: str + + command: str + + crt_command: str + + params: list[str] + + crt_params: list[str] + + privilege: int diff --git a/pkg/command/errors.py b/pkg/command/errors.py new file mode 100644 index 00000000..5bc253f6 --- /dev/null +++ b/pkg/command/errors.py @@ -0,0 +1,33 @@ + + +class CommandError(Exception): + + def __init__(self, message: str = None): + self.message = message + + def __str__(self): + return self.message + + +class CommandNotFoundError(CommandError): + + def __init__(self, message: str = None): + super().__init__("未知命令: "+message) + + +class CommandPrivilegeError(CommandError): + + def __init__(self, message: str = None): + super().__init__("权限不足: "+message) + + +class ParamNotEnoughError(CommandError): + + def __init__(self, message: str = None): + super().__init__("参数不足: "+message) + + +class CommandOperationError(CommandError): + + def __init__(self, message: str = None): + super().__init__("操作失败: "+message) diff --git a/pkg/command/operator.py b/pkg/command/operator.py new file mode 100644 index 00000000..48ca8daf --- /dev/null +++ b/pkg/command/operator.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import typing +import abc + +from ..core import app, entities as core_entities +from . 
import entities + + +preregistered_operators: list[typing.Type[CommandOperator]] = [] + + +def operator_class( + name: str, + help: str, + usage: str = None, + alias: list[str] = [], + privilege: int=1, # 1为普通用户,2为管理员 + parent_class: typing.Type[CommandOperator] = None +) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: + def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]: + cls.name = name + cls.alias = alias + cls.help = help + cls.usage = usage + cls.parent_class = parent_class + cls.lowest_privilege = privilege + + preregistered_operators.append(cls) + + return cls + + return decorator + + +class CommandOperator(metaclass=abc.ABCMeta): + """命令算子 + """ + + ap: app.Application + + name: str + """名称,搜索到时若符合则使用""" + + path: str + """路径,所有父节点的name的连接,用于定义命令权限""" + + alias: list[str] + """同name""" + + help: str + """此节点的帮助信息""" + + usage: str = None + + parent_class: typing.Type[CommandOperator] | None = None + """父节点类。标记以供管理器在初始化时编织父子关系。""" + + lowest_privilege: int = 0 + """最低权限。若权限低于此值,则不予执行。""" + + children: list[CommandOperator] + """子节点。解析命令时,若节点有子节点,则以下一个参数去匹配子节点, + 若有匹配中的,转移到子节点中执行,若没有匹配中的或没有子节点,执行此节点。""" + + def __init__(self, ap: app.Application): + self.ap = ap + self.children = [] + + async def initialize(self): + pass + + @abc.abstractmethod + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + pass diff --git a/pkg/qqbot/cmds/funcs/__init__.py b/pkg/command/operators/__init__.py similarity index 100% rename from pkg/qqbot/cmds/funcs/__init__.py rename to pkg/command/operators/__init__.py diff --git a/pkg/command/operators/cmd.py b/pkg/command/operators/cmd.py new file mode 100644 index 00000000..17b5ed08 --- /dev/null +++ b/pkg/command/operators/cmd.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import typing + +from .. 
import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="cmd", + help='显示命令列表', + usage='!cmd\n!cmd <命令名称>' +) +class CmdOperator(operator.CommandOperator): + """命令列表 + """ + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行 + """ + if len(context.crt_params) == 0: + reply_str = "当前所有命令: \n\n" + + for cmd in self.ap.cmd_mgr.cmd_list: + if cmd.parent_class is None: + reply_str += f"{cmd.name}: {cmd.help}\n" + + reply_str += "\n使用 !cmd <命令名称> 查看命令的详细帮助" + + yield entities.CommandReturn(text=reply_str.strip()) + + else: + cmd_name = context.crt_params[0] + + cmd = None + + for _cmd in self.ap.cmd_mgr.cmd_list: + if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None): + cmd = _cmd + break + + if cmd is None: + yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name)) + else: + reply_str = f"{cmd.name}: {cmd.help}\n\n" + reply_str += f"使用方法: \n{cmd.usage}" + + yield entities.CommandReturn(text=reply_str.strip()) diff --git a/pkg/command/operators/default.py b/pkg/command/operators/default.py new file mode 100644 index 00000000..ca7e404d --- /dev/null +++ b/pkg/command/operators/default.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import typing +import traceback + +from .. 
import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="default", + help="操作情景预设", + usage='!default\n!default set <指定情景预设为默认>' +) +class DefaultOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + reply_str = "当前所有情景预设: \n\n" + + for prompt in self.ap.prompt_mgr.get_all_prompts(): + + content = "" + for msg in prompt.messages: + content += f" {msg.role}: {msg.content}" + + reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n" + + reply_str += f"当前会话使用的是: {context.session.use_prompt_name}" + + yield entities.CommandReturn(text=reply_str.strip()) + + +@operator.operator_class( + name="set", + help="设置当前会话默认情景预设", + parent_class=DefaultOperator +) +class DefaultSetOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称')) + else: + prompt_name = context.crt_params[0] + + try: + prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name) + if prompt is None: + yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name))) + else: + context.session.use_prompt_name = prompt.name + yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效") + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e))) diff --git a/pkg/command/operators/delc.py b/pkg/command/operators/delc.py new file mode 100644 index 00000000..db865ff7 --- /dev/null +++ b/pkg/command/operators/delc.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import typing +import datetime + +from .. 
import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="del", + help="删除当前会话的历史记录", + usage='!del <序号>\n!del all' +) +class DelOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if context.session.conversations: + delete_index = 0 + if len(context.crt_params) > 0: + try: + delete_index = int(context.crt_params[0]) + except: + yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数')) + return + + if delete_index < 0 or delete_index >= len(context.session.conversations): + yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围')) + return + + # 倒序 + to_delete_index = len(context.session.conversations)-1-delete_index + + if context.session.conversations[to_delete_index] == context.session.using_conversation: + context.session.using_conversation = None + + del context.session.conversations[to_delete_index] + + yield entities.CommandReturn(text=f"已删除对话: {delete_index}") + else: + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) + + +@operator.operator_class( + name="all", + help="删除此会话的所有历史记录", + parent_class=DelOperator +) +class DelAllOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + context.session.conversations = [] + context.session.using_conversation = None + + yield entities.CommandReturn(text="已删除所有对话") \ No newline at end of file diff --git a/pkg/command/operators/func.py b/pkg/command/operators/func.py new file mode 100644 index 00000000..33031bfb --- /dev/null +++ b/pkg/command/operators/func.py @@ -0,0 +1,27 @@ +from __future__ import annotations +from typing import AsyncGenerator + +from .. 
import operator, entities, cmdmgr + + +@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func') +class FuncOperator(operator.CommandOperator): + async def execute( + self, context: entities.ExecuteContext + ) -> AsyncGenerator[entities.CommandReturn, None]: + reply_str = "当前已加载的内容函数: \n\n" + + index = 1 + + all_functions = await self.ap.tool_mgr.get_all_functions() + + for func in all_functions: + reply_str += "{}. {}{}:\n{}\n\n".format( + index, + ("(已禁用) " if not func.enable else ""), + func.name, + func.description, + ) + index += 1 + + yield entities.CommandReturn(text=reply_str) diff --git a/pkg/command/operators/help.py b/pkg/command/operators/help.py new file mode 100644 index 00000000..570e103c --- /dev/null +++ b/pkg/command/operators/help.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import typing + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name='help', + help='显示帮助', + usage='!help\n!help <命令名称>' +) +class HelpOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + help = self.ap.system_cfg.data['help-message'] + + help += '\n发送命令 !cmd 可查看命令列表' + + yield entities.CommandReturn(text=help) diff --git a/pkg/command/operators/last.py b/pkg/command/operators/last.py new file mode 100644 index 00000000..8e3a5231 --- /dev/null +++ b/pkg/command/operators/last.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import typing +import datetime + + +from .. 
import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="last", + help="切换到前一个对话", + usage='!last' +) +class LastOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if context.session.conversations: + # 找到当前会话的上一个会话 + for index in range(len(context.session.conversations)-1, -1, -1): + if context.session.conversations[index] == context.session.using_conversation: + if index == 0: + yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了')) + return + else: + context.session.using_conversation = context.session.conversations[index-1] + time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") + + yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}") + return + else: + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) \ No newline at end of file diff --git a/pkg/command/operators/list.py b/pkg/command/operators/list.py new file mode 100644 index 00000000..258e0ee2 --- /dev/null +++ b/pkg/command/operators/list.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import typing +import datetime + +from .. 
import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="list", + help="列出此会话中的所有历史对话", + usage='!list\n!list <页码>' +) +class ListOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + page = 0 + + if len(context.crt_params) > 0: + try: + page = int(context.crt_params[0]-1) + except: + yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数')) + return + + record_per_page = 10 + + content = '' + + index = 0 + + using_conv_index = 0 + + for conv in context.session.conversations[::-1]: + time_str = conv.create_time.strftime("%Y-%m-%d %H:%M:%S") + + if conv == context.session.using_conversation: + using_conv_index = index + + if index >= page * record_per_page and index < (page + 1) * record_per_page: + content += f"{index} {time_str}: {conv.messages[0].content if len(conv.messages) > 0 else '无内容'}\n" + index += 1 + + if content == '': + content = '无' + else: + if context.session.using_conversation is None: + content += "\n当前处于新会话" + else: + content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content if len(context.session.using_conversation.messages) > 0 else '无内容'}" + + yield entities.CommandReturn(text=f"第 {page + 1} 页 (时间倒序):\n{content}") diff --git a/pkg/command/operators/next.py b/pkg/command/operators/next.py new file mode 100644 index 00000000..8f4b5a5a --- /dev/null +++ b/pkg/command/operators/next.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import typing +import datetime + +from .. 
import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="next", + help="切换到后一个对话", + usage='!next' +) +class NextOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if context.session.conversations: + # 找到当前会话的下一个会话 + for index in range(len(context.session.conversations)): + if context.session.conversations[index] == context.session.using_conversation: + if index == len(context.session.conversations)-1: + yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了')) + return + else: + context.session.using_conversation = context.session.conversations[index+1] + time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") + + yield entities.CommandReturn(text=f"已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}") + return + else: + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) \ No newline at end of file diff --git a/pkg/command/operators/plugin.py b/pkg/command/operators/plugin.py new file mode 100644 index 00000000..b1cf6ee1 --- /dev/null +++ b/pkg/command/operators/plugin.py @@ -0,0 +1,237 @@ +from __future__ import annotations +import typing +import traceback + +from .. 
import operator, entities, cmdmgr, errors +from ...core import app + + +@operator.operator_class( + name="plugin", + help="插件操作", + usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>" +) +class PluginOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + plugin_list = self.ap.plugin_mgr.plugins + reply_str = "所有插件({}):\n".format(len(plugin_list)) + idx = 0 + for plugin in plugin_list: + reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\ + .format((idx+1), plugin.plugin_name, + "[已禁用]" if not plugin.enabled else "", + plugin.plugin_description, + plugin.plugin_version, plugin.plugin_author) + + # TODO 从元数据调远程地址 + + idx += 1 + + yield entities.CommandReturn(text=reply_str) + + +@operator.operator_class( + name="get", + help="安装插件", + privilege=2, + parent_class=PluginOperator +) +class PluginGetOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址')) + else: + repo = context.crt_params[0] + + yield entities.CommandReturn(text="正在安装插件...") + + try: + await self.ap.plugin_mgr.install_plugin(repo) + yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件") + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件安装失败: "+str(e))) + + +@operator.operator_class( + name="update", + help="更新插件", + privilege=2, + parent_class=PluginOperator +) +class PluginUpdateOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + else: + 
plugin_name = context.crt_params[0] + + try: + plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) + + if plugin_container is not None: + yield entities.CommandReturn(text="正在更新插件...") + await self.ap.plugin_mgr.update_plugin(plugin_name) + yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件") + else: + yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件")) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) + +@operator.operator_class( + name="all", + help="更新所有插件", + privilege=2, + parent_class=PluginUpdateOperator +) +class PluginUpdateAllOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + try: + plugins = [ + p.plugin_name + for p in self.ap.plugin_mgr.plugins + ] + + if plugins: + yield entities.CommandReturn(text="正在更新插件...") + updated = [] + try: + for plugin_name in plugins: + await self.ap.plugin_mgr.update_plugin(plugin_name) + updated.append(plugin_name) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) + yield entities.CommandReturn(text="已更新插件: {}".format(", ".join(updated))) + else: + yield entities.CommandReturn(text="没有可更新的插件") + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) + + +@operator.operator_class( + name="del", + help="删除插件", + privilege=2, + parent_class=PluginOperator +) +class PluginDelOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + else: + plugin_name = context.crt_params[0] + + try: + plugin_container = 
self.ap.plugin_mgr.get_plugin_by_name(plugin_name) + + if plugin_container is not None: + yield entities.CommandReturn(text="正在删除插件...") + await self.ap.plugin_mgr.uninstall_plugin(plugin_name) + yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件") + else: + yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件")) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e))) + + +async def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application): + if ap.plugin_mgr.get_plugin_by_name(plugin_name) is not None: + for plugin in ap.plugin_mgr.plugins: + if plugin.plugin_name == plugin_name: + plugin.enabled = new_status + + for func in plugin.content_functions: + func.enable = new_status + + await ap.plugin_mgr.setting.dump_container_setting(ap.plugin_mgr.plugins) + + break + + return True + else: + return False + + +@operator.operator_class( + name="on", + help="启用插件", + privilege=2, + parent_class=PluginOperator +) +class PluginEnableOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + else: + plugin_name = context.crt_params[0] + + try: + if await update_plugin_status(plugin_name, True, self.ap): + yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name)) + else: + yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e))) + + +@operator.operator_class( + name="off", + help="禁用插件", + privilege=2, + parent_class=PluginOperator +) +class PluginDisableOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> 
typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + else: + plugin_name = context.crt_params[0] + + try: + if await update_plugin_status(plugin_name, False, self.ap): + yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name)) + else: + yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e))) diff --git a/pkg/command/operators/prompt.py b/pkg/command/operators/prompt.py new file mode 100644 index 00000000..29d688a6 --- /dev/null +++ b/pkg/command/operators/prompt.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import typing + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="prompt", + help="查看当前对话的前文", + usage='!prompt' +) +class PromptOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行 + """ + if context.session.using_conversation is None: + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) + else: + reply_str = '当前对话所有内容:\n\n' + + for msg in context.session.using_conversation.messages: + reply_str += f"{msg.role}: {msg.content}\n" + + yield entities.CommandReturn(text=reply_str) \ No newline at end of file diff --git a/pkg/command/operators/resend.py b/pkg/command/operators/resend.py new file mode 100644 index 00000000..6d930413 --- /dev/null +++ b/pkg/command/operators/resend.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import typing + +from .. 
from .. import operator, entities, cmdmgr, errors


@operator.operator_class(
    name="resend",
    help="重发当前会话的最后一条消息",
    usage='!resend'
)
class ResendOperator(operator.CommandOperator):
    """!resend 命令:回滚当前对话到最后一条用户消息之前。"""

    async def execute(
        self,
        context: entities.ExecuteContext
    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
        """执行 !resend:删除最后一次请求的用户消息及其后的全部回复。"""
        conversation = context.session.using_conversation

        if conversation is None:
            yield entities.CommandReturn(error=errors.CommandError("当前没有对话"))
            return

        history = conversation.messages

        # 从尾部依次移除非用户消息,直到定位最后一条用户消息
        while history and history[-1].role != 'user':
            history.pop()

        # 再把这条用户消息本身移除
        if history:
            history.pop()

        # 不重发了,提示用户已删除就行了
        yield entities.CommandReturn(text="已删除最后一次请求记录")


@operator.operator_class(
    name="reset",
    help="重置当前会话",
    usage='!reset'
)
class ResetOperator(operator.CommandOperator):
    """!reset 命令:丢弃当前正在使用的对话。"""

    async def execute(
        self,
        context: entities.ExecuteContext
    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
        """执行 !reset:将当前对话置空,下次请求将开启新对话。"""
        context.session.using_conversation = None

        yield entities.CommandReturn(text="已重置当前会话")
from .. import operator, entities, cmdmgr, errors


@operator.operator_class(
    name="update",
    help="更新程序",
    usage='!update',
    privilege=2
)
class UpdateCommand(operator.CommandOperator):
    """!update 命令:拉取并应用程序更新(需要管理员权限)。"""

    async def execute(
        self,
        context: entities.ExecuteContext
    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
        """执行 !update:委托版本管理器完成更新并反馈结果。"""
        try:
            yield entities.CommandReturn(text="正在进行更新...")
            updated = await self.ap.ver_mgr.update_all()
            if updated:
                yield entities.CommandReturn(text="更新完成,请重启程序以应用更新")
            else:
                yield entities.CommandReturn(text="当前已是最新版本")
        except Exception as e:
            traceback.print_exc()
            yield entities.CommandReturn(error=errors.CommandError("更新失败: " + str(e)))


@operator.operator_class(
    name="version",
    help="显示版本信息",
    usage='!version'
)
class VersionCommand(operator.CommandOperator):
    """!version 命令:显示当前版本,并提示可用的新版本。"""

    async def execute(
        self,
        context: entities.ExecuteContext
    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
        """执行 !version:版本检查失败时静默忽略,只显示当前版本。"""
        reply_str = f"当前版本: \n{self.ap.ver_mgr.get_current_version()}"

        try:
            if await self.ap.ver_mgr.is_new_version_available():
                reply_str += "\n\n有新版本可用, 使用 !update 更新"
        except:
            pass

        yield entities.CommandReturn(text=reply_str.strip())
import os
import shutil
import json

from .. import model as file_model


class JSONConfigFile(file_model.ConfigFile):
    """JSON配置文件"""

    # 配置文件路径
    config_file_name: str = None

    # 模板文件路径,配置缺失或缺键时以此为准
    template_file_name: str = None

    def __init__(self, config_file_name: str, template_file_name: str) -> None:
        self.config_file_name = config_file_name
        self.template_file_name = template_file_name

    def exists(self) -> bool:
        """配置文件是否已存在。"""
        return os.path.exists(self.config_file_name)

    async def create(self):
        """从模板复制生成配置文件。"""
        shutil.copyfile(self.template_file_name, self.config_file_name)

    async def load(self) -> dict:
        """读取配置;文件不存在时先创建,并用模板补全顶层缺失键。"""
        if not self.exists():
            await self.create()

        with open(self.config_file_name, 'r', encoding='utf-8') as f:
            cfg = json.load(f)

        # 从模板文件中进行补全(仅顶层键)
        with open(self.template_file_name, 'r', encoding='utf-8') as f:
            template_cfg = json.load(f)

        for key, value in template_cfg.items():
            cfg.setdefault(key, value)

        return cfg

    async def save(self, cfg: dict):
        """写回磁盘:UTF-8、缩进 4、保留非 ASCII 字符。"""
        with open(self.config_file_name, 'w', encoding='utf-8') as f:
            json.dump(cfg, f, indent=4, ensure_ascii=False)
import model as file_model -from ..utils import context +from .impls import pymodule, json as json_file + + +managers: ConfigManager = [] class ConfigManager: @@ -14,10 +19,35 @@ class ConfigManager: def __init__(self, cfg_file: file_model.ConfigFile) -> None: self.file = cfg_file self.data = {} - context.set_config_manager(self) async def load_config(self): self.data = await self.file.load() async def dump_config(self): await self.file.save(self.data) + + +async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager: + """加载Python模块配置文件""" + cfg_inst = pymodule.PythonModuleConfigFile( + config_name, + template_name + ) + + cfg_mgr = ConfigManager(cfg_inst) + await cfg_mgr.load_config() + + return cfg_mgr + + +async def load_json_config(config_name: str, template_name: str) -> ConfigManager: + """加载JSON配置文件""" + cfg_inst = json_file.JSONConfigFile( + config_name, + template_name + ) + + cfg_mgr = ConfigManager(cfg_inst) + await cfg_mgr.load_config() + + return cfg_mgr \ No newline at end of file diff --git a/pkg/qqbot/cmds/session/__init__.py b/pkg/core/__init__.py similarity index 100% rename from pkg/qqbot/cmds/session/__init__.py rename to pkg/core/__init__.py diff --git a/pkg/core/app.py b/pkg/core/app.py new file mode 100644 index 00000000..ab483448 --- /dev/null +++ b/pkg/core/app.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import logging +import asyncio +import traceback + +import aioconsole + +from ..platform import manager as im_mgr +from ..provider.session import sessionmgr as llm_session_mgr +from ..provider.requester import modelmgr as llm_model_mgr +from ..provider.sysprompt import sysprompt as llm_prompt_mgr +from ..provider.tools import toolmgr as llm_tool_mgr +from ..config import manager as config_mgr +from ..audit.center import v2 as center_mgr +from ..command import cmdmgr +from ..plugin import manager as plugin_mgr +from . 
from . import pool, controller
from ..pipeline import stagemgr
from ..utils import version as version_mgr, proxy as proxy_mgr


class Application:
    """应用全局容器:集中持有各子系统管理器,由 boot 阶段完成装配。"""

    # 消息平台管理器
    im_mgr: im_mgr.PlatformManager = None

    # 命令管理器
    cmd_mgr: cmdmgr.CommandManager = None

    sess_mgr: llm_session_mgr.SessionManager = None

    model_mgr: llm_model_mgr.ModelManager = None

    prompt_mgr: llm_prompt_mgr.PromptManager = None

    tool_mgr: llm_tool_mgr.ToolManager = None

    # 各配置文件管理器
    command_cfg: config_mgr.ConfigManager = None

    pipeline_cfg: config_mgr.ConfigManager = None

    platform_cfg: config_mgr.ConfigManager = None

    provider_cfg: config_mgr.ConfigManager = None

    system_cfg: config_mgr.ConfigManager = None

    ctr_mgr: center_mgr.V2CenterAPI = None

    plugin_mgr: plugin_mgr.PluginManager = None

    query_pool: pool.QueryPool = None

    ctrl: controller.Controller = None

    stage_mgr: stagemgr.StageManager = None

    ver_mgr: version_mgr.VersionManager = None

    proxy_mgr: proxy_mgr.ProxyManager = None

    logger: logging.Logger = None

    def __init__(self):
        pass

    async def initialize(self):
        pass

    async def run(self):
        """运行应用:加载插件后并发运行消息平台与控制器,直到退出。"""
        await self.plugin_mgr.load_plugins()
        await self.plugin_mgr.initialize_plugins()

        tasks = []

        try:
            tasks = [
                asyncio.create_task(self.im_mgr.run()),
                asyncio.create_task(self.ctrl.run())
            ]

            import signal

            # Ctrl+C 时取消全部任务并退出进程
            def _on_sigint(sig, frame):
                for task in tasks:
                    task.cancel()
                self.logger.info("程序退出.")
                exit(0)

            signal.signal(signal.SIGINT, _on_sigint)

            await asyncio.gather(*tasks, return_exceptions=True)

        except asyncio.CancelledError:
            pass
        except Exception as e:
            self.logger.error(f"应用运行致命异常: {e}")
            self.logger.debug(f"Traceback: {traceback.format_exc()}")
100644 index 00000000..3ecfa8ae --- /dev/null +++ b/pkg/core/boot.py @@ -0,0 +1,149 @@ +from __future__ import print_function + +import os +import sys + +from .bootutils import files +from .bootutils import deps +from .bootutils import log +from .bootutils import config + +from . import app +from . import pool +from . import controller +from ..pipeline import stagemgr +from ..audit import identifier +from ..provider.session import sessionmgr as llm_session_mgr +from ..provider.requester import modelmgr as llm_model_mgr +from ..provider.sysprompt import sysprompt as llm_prompt_mgr +from ..provider.tools import toolmgr as llm_tool_mgr +from ..platform import manager as im_mgr +from ..command import cmdmgr +from ..plugin import manager as plugin_mgr +from ..audit.center import v2 as center_v2 +from ..utils import version, proxy, announce + +use_override = False + + +async def make_app() -> app.Application: + global use_override + + generated_files = await files.generate_files() + + if generated_files: + print("以下文件不存在,已自动生成,请按需修改配置文件后重启:") + for file in generated_files: + print("-", file) + + sys.exit(0) + + missing_deps = await deps.check_deps() + + if missing_deps: + print("以下依赖包未安装,将自动安装,请完成后重启程序:") + for dep in missing_deps: + print("-", dep) + await deps.install_deps(missing_deps) + sys.exit(0) + + qcg_logger = await log.init_logging() + + # 生成标识符 + identifier.init() + + # ========== 加载配置文件 ========== + + command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json") + pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json") + platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json") + provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json") + system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json") + + # ========== 构建应用实例 ========== + ap = app.Application() + 
ap.logger = qcg_logger + + ap.command_cfg = command_cfg + ap.pipeline_cfg = pipeline_cfg + ap.platform_cfg = platform_cfg + ap.provider_cfg = provider_cfg + ap.system_cfg = system_cfg + + proxy_mgr = proxy.ProxyManager(ap) + await proxy_mgr.initialize() + ap.proxy_mgr = proxy_mgr + + ver_mgr = version.VersionManager(ap) + await ver_mgr.initialize() + ap.ver_mgr = ver_mgr + + center_v2_api = center_v2.V2CenterAPI( + ap, + basic_info={ + "host_id": identifier.identifier["host_id"], + "instance_id": identifier.identifier["instance_id"], + "semantic_version": ver_mgr.get_current_version(), + "platform": sys.platform, + }, + runtime_info={ + "admin_id": "{}".format(system_cfg.data["admin-sessions"]), + "msg_source": str([ + adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown' + for adapter_cfg in platform_cfg.data['platform-adapters'] if adapter_cfg['enable'] + ]), + }, + ) + ap.ctr_mgr = center_v2_api + + # 发送公告 + ann_mgr = announce.AnnouncementManager(ap) + await ann_mgr.show_announcements() + + ap.query_pool = pool.QueryPool() + + await ap.ver_mgr.show_version_update() + + plugin_mgr_inst = plugin_mgr.PluginManager(ap) + await plugin_mgr_inst.initialize() + ap.plugin_mgr = plugin_mgr_inst + + cmd_mgr_inst = cmdmgr.CommandManager(ap) + await cmd_mgr_inst.initialize() + ap.cmd_mgr = cmd_mgr_inst + + llm_model_mgr_inst = llm_model_mgr.ModelManager(ap) + await llm_model_mgr_inst.initialize() + ap.model_mgr = llm_model_mgr_inst + + llm_session_mgr_inst = llm_session_mgr.SessionManager(ap) + await llm_session_mgr_inst.initialize() + ap.sess_mgr = llm_session_mgr_inst + + llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap) + await llm_prompt_mgr_inst.initialize() + ap.prompt_mgr = llm_prompt_mgr_inst + + llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap) + await llm_tool_mgr_inst.initialize() + ap.tool_mgr = llm_tool_mgr_inst + + im_mgr_inst = im_mgr.PlatformManager(ap=ap) + await im_mgr_inst.initialize() + ap.im_mgr = im_mgr_inst + + stage_mgr = 
stagemgr.StageManager(ap) + await stage_mgr.initialize() + ap.stage_mgr = stage_mgr + + ctrl = controller.Controller(ap) + ap.ctrl = ctrl + + await ap.initialize() + + return ap + + +async def main(): + app_inst = await make_app() + await app_inst.run() diff --git a/pkg/qqbot/cmds/system/__init__.py b/pkg/core/bootutils/__init__.py similarity index 100% rename from pkg/qqbot/cmds/system/__init__.py rename to pkg/core/bootutils/__init__.py diff --git a/pkg/core/bootutils/config.py b/pkg/core/bootutils/config.py new file mode 100644 index 00000000..0addff08 --- /dev/null +++ b/pkg/core/bootutils/config.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import json + +from ...config import manager as config_mgr +from ...config.impls import pymodule + + +load_python_module_config = config_mgr.load_python_module_config +load_json_config = config_mgr.load_json_config + + +async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str]: + override_json = json.load(open("override.json", "r", encoding="utf-8")) + overrided = [] + + config = cfg_mgr.data + for key in override_json: + if key in config: + config[key] = override_json[key] + overrided.append(key) + + return overrided diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py new file mode 100644 index 00000000..c4988657 --- /dev/null +++ b/pkg/core/bootutils/deps.py @@ -0,0 +1,32 @@ +import pip + +required_deps = { + "requests": "requests", + "openai": "openai", + "colorlog": "colorlog", + "mirai": "yiri-mirai-rc", + "PIL": "pillow", + "nakuru": "nakuru-project-idk", + "CallingGPT": "CallingGPT", + "tiktoken": "tiktoken", + "yaml": "pyyaml", + "aiohttp": "aiohttp", +} + + +async def check_deps() -> list[str]: + global required_deps + + missing_deps = [] + for dep in required_deps: + try: + __import__(dep) + except ImportError: + missing_deps.append(dep) + return missing_deps + +async def install_deps(deps: list[str]): + global required_deps + + for dep in deps: + 
pip.main(["install", required_deps[dep]]) diff --git a/pkg/core/bootutils/files.py b/pkg/core/bootutils/files.py new file mode 100644 index 00000000..975e3aad --- /dev/null +++ b/pkg/core/bootutils/files.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import os +import shutil +import sys + + +required_files = { + "plugins/__init__.py": "templates/__init__.py", + "plugins/plugins.json": "templates/plugin-settings.json", + "data/config/command.json": "templates/command.json", + "data/config/pipeline.json": "templates/pipeline.json", + "data/config/platform.json": "templates/platform.json", + "data/config/provider.json": "templates/provider.json", + "data/config/system.json": "templates/system.json", + "data/config/sensitive-words.json": "templates/sensitive-words.json", + "data/scenario/default.json": "templates/scenario-template.json", +} + +required_paths = [ + "temp", + "data", + "data/prompts", + "data/scenario", + "data/logs", + "data/config", + "plugins" +] + +async def generate_files() -> list[str]: + global required_files, required_paths + + for required_paths in required_paths: + if not os.path.exists(required_paths): + os.mkdir(required_paths) + + generated_files = [] + for file in required_files: + if not os.path.exists(file): + shutil.copyfile(required_files[file], file) + generated_files.append(file) + + return generated_files diff --git a/pkg/core/bootutils/log.py b/pkg/core/bootutils/log.py new file mode 100644 index 00000000..308ca8c4 --- /dev/null +++ b/pkg/core/bootutils/log.py @@ -0,0 +1,57 @@ +import logging +import os +import sys +import time + +import colorlog + + +log_colors_config = { + "DEBUG": "green", # cyan white + "INFO": "white", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "cyan", +} + + +async def init_logging() -> logging.Logger: + level = logging.INFO + + if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]: + level = logging.DEBUG + + log_file_name = "data/logs/qcg-%s.log" % time.strftime( + 
"%Y-%m-%d-%H-%M-%S", time.localtime() + ) + + qcg_logger = logging.getLogger("qcg") + + qcg_logger.setLevel(level) + + color_formatter = colorlog.ColoredFormatter( + fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + log_colors=log_colors_config, + ) + + stream_handler = logging.StreamHandler(sys.stdout) + + log_handlers: logging.Handler = [stream_handler, logging.FileHandler(log_file_name)] + + for handler in log_handlers: + handler.setLevel(level) + handler.setFormatter(color_formatter) + qcg_logger.addHandler(handler) + + qcg_logger.debug("日志初始化完成,日志级别:%s" % level) + logging.basicConfig( + level=logging.INFO, # 设置日志输出格式 + format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s", + # 日志输出的格式 + # -8表示占位符,让输出左对齐,输出长度都为8位 + datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式 + handlers=[logging.NullHandler()], + ) + + return qcg_logger diff --git a/pkg/qqbot/sources/__init__.py b/pkg/core/bootutils/misc.py similarity index 100% rename from pkg/qqbot/sources/__init__.py rename to pkg/core/bootutils/misc.py diff --git a/pkg/core/controller.py b/pkg/core/controller.py new file mode 100644 index 00000000..42ef435c --- /dev/null +++ b/pkg/core/controller.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import asyncio +import typing +import traceback + +from . 
import app, entities +from ..pipeline import entities as pipeline_entities +from ..plugin import events + + +class Controller: + """总控制器 + """ + ap: app.Application + + semaphore: asyncio.Semaphore = None + """请求并发控制信号量""" + + def __init__(self, ap: app.Application): + self.ap = ap + self.semaphore = asyncio.Semaphore(self.ap.system_cfg.data['pipeline-concurrency']) + + async def consumer(self): + """事件处理循环 + """ + try: + while True: + selected_query: entities.Query = None + + # 取请求 + async with self.ap.query_pool: + queries: list[entities.Query] = self.ap.query_pool.queries + + for query in queries: + session = await self.ap.sess_mgr.get_session(query) + self.ap.logger.debug(f"Checking query {query} session {session}") + + if not session.semaphore.locked(): + selected_query = query + await session.semaphore.acquire() + + break + + if selected_query: # 找到了 + queries.remove(selected_query) + else: # 没找到 说明:没有请求 或者 所有query对应的session都已达到并发上限 + await self.ap.query_pool.condition.wait() + continue + + if selected_query: + async def _process_query(selected_query): + async with self.semaphore: # 总并发上限 + await self.process_query(selected_query) + + async with self.ap.query_pool: + (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() + # 通知其他协程,有新的请求可以处理了 + self.ap.query_pool.condition.notify_all() + + asyncio.create_task(_process_query(selected_query)) + except Exception as e: + # traceback.print_exc() + self.ap.logger.error(f"控制器循环出错: {e}") + self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + + async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): + """检查输出 + """ + if result.user_notice: + await self.ap.im_mgr.send( + query.message_event, + result.user_notice, + query.adapter + ) + if result.debug_notice: + self.ap.logger.debug(result.debug_notice) + if result.console_notice: + self.ap.logger.info(result.console_notice) + if result.error_notice: + self.ap.logger.error(result.error_notice) + + 
async def _execute_from_stage( + self, + stage_index: int, + query: entities.Query, + ): + """从指定阶段开始执行 + + 如何看懂这里为什么这么写? + 去问 GPT-4: + Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None], + 如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result, + 调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器 + Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage: + + A B C D E F G + + 如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是: + + A B C D E F G + + 现在假设C返回的是AsyncGenerator,那么执行顺序是: + + A B C D E F G C D E F G C D E F G ... + Q3: 但是如果不止一个stage会返回生成器呢? + """ + i = stage_index + + while i < len(self.ap.stage_mgr.stage_containers): + stage_container = self.ap.stage_mgr.stage_containers[i] + + result = stage_container.inst.process(query, stage_container.inst_name) + + if isinstance(result, typing.Coroutine): + result = await result + + if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果 + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}") + await self._check_output(query, result) + + if result.result_type == pipeline_entities.ResultType.INTERRUPT: + self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + break + elif result.result_type == pipeline_entities.ResultType.CONTINUE: + query = result.new_query + elif isinstance(result, typing.AsyncGenerator): # 生成器 + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen") + + async for sub_result in result: + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}") + await self._check_output(query, sub_result) + + if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT: + self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + break + elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE: + query = sub_result.new_query + await 
self._execute_from_stage(i + 1, query) + break + + i += 1 + + async def process_query(self, query: entities.Query): + """处理请求 + """ + self.ap.logger.debug(f"Processing query {query}") + + try: + await self._execute_from_stage(0, query) + except Exception as e: + self.ap.logger.error(f"处理请求时出错 query_id={query.query_id}: {e}") + self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + # traceback.print_exc() + finally: + self.ap.logger.debug(f"Query {query} processed") + + async def run(self): + """运行控制器 + """ + await self.consumer() diff --git a/pkg/core/entities.py b/pkg/core/entities.py new file mode 100644 index 00000000..dacb64e0 --- /dev/null +++ b/pkg/core/entities.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import enum +import typing +import datetime +import asyncio + +import pydantic +import mirai + +from ..provider import entities as llm_entities +from ..provider.requester import entities +from ..provider.sysprompt import entities as sysprompt_entities +from ..provider.tools import entities as tools_entities +from ..platform import adapter as msadapter + + +class LauncherTypes(enum.Enum): + + PERSON = 'person' + """私聊""" + + GROUP = 'group' + """群聊""" + + +class Query(pydantic.BaseModel): + """一次请求的信息封装""" + + query_id: int + """请求ID,添加进请求池时生成""" + + launcher_type: LauncherTypes + """会话类型,platform设置""" + + launcher_id: int + """会话ID,platform设置""" + + sender_id: int + """发送者ID,platform设置""" + + message_event: mirai.MessageEvent + """事件,platform收到的事件""" + + message_chain: mirai.MessageChain + """消息链,platform收到的消息链""" + + adapter: msadapter.MessageSourceAdapter + """适配器对象""" + + session: typing.Optional[Session] = None + """会话对象,由前置处理器设置""" + + messages: typing.Optional[list[llm_entities.Message]] = [] + """历史消息列表,由前置处理器设置""" + + prompt: typing.Optional[sysprompt_entities.Prompt] = None + """情景预设内容,由前置处理器设置""" + + user_message: typing.Optional[llm_entities.Message] = None + """此次请求的用户消息对象,由前置处理器设置""" + + use_model: 
typing.Optional[entities.LLMModelInfo] = None + """使用的模型,由前置处理器设置""" + + use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None + """使用的函数,由前置处理器设置""" + + resp_messages: typing.Optional[list[llm_entities.Message]] = [] + """由provider生成的回复消息对象列表""" + + resp_message_chain: typing.Optional[mirai.MessageChain] = None + """回复消息链,从resp_messages包装而得""" + + class Config: + arbitrary_types_allowed = True + + +class Conversation(pydantic.BaseModel): + """对话""" + + prompt: sysprompt_entities.Prompt + + messages: list[llm_entities.Message] + + create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + use_model: entities.LLMModelInfo + + use_funcs: typing.Optional[list[tools_entities.LLMFunction]] + + +class Session(pydantic.BaseModel): + """会话""" + launcher_type: LauncherTypes + + launcher_id: int + + sender_id: typing.Optional[int] = 0 + + use_prompt_name: typing.Optional[str] = 'default' + + using_conversation: typing.Optional[Conversation] = None + + conversations: typing.Optional[list[Conversation]] = [] + + create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + semaphore: typing.Optional[asyncio.Semaphore] = None + + class Config: + arbitrary_types_allowed = True diff --git a/pkg/core/pool.py b/pkg/core/pool.py new file mode 100644 index 00000000..5c8000dd --- /dev/null +++ b/pkg/core/pool.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import asyncio + +import mirai + +from . 
import entities +from ..platform import adapter as msadapter + + +class QueryPool: + + query_id_counter: int = 0 + + pool_lock: asyncio.Lock + + queries: list[entities.Query] + + condition: asyncio.Condition + + def __init__(self): + self.query_id_counter = 0 + self.pool_lock = asyncio.Lock() + self.queries = [] + self.condition = asyncio.Condition(self.pool_lock) + + async def add_query( + self, + launcher_type: entities.LauncherTypes, + launcher_id: int, + sender_id: int, + message_event: mirai.MessageEvent, + message_chain: mirai.MessageChain, + adapter: msadapter.MessageSourceAdapter + ) -> entities.Query: + async with self.condition: + query = entities.Query( + query_id=self.query_id_counter, + launcher_type=launcher_type, + launcher_id=launcher_id, + sender_id=sender_id, + message_event=message_event, + message_chain=message_chain, + resp_messages=[], + resp_message_chain=None, + adapter=adapter + ) + self.queries.append(query) + self.query_id_counter += 1 + self.condition.notify_all() + + async def __aenter__(self): + await self.pool_lock.acquire() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.pool_lock.release() diff --git a/pkg/database/__init__.py b/pkg/database/__init__.py deleted file mode 100644 index c40dc210..00000000 --- a/pkg/database/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -数据库操作封装 -""" \ No newline at end of file diff --git a/pkg/database/manager.py b/pkg/database/manager.py deleted file mode 100644 index ad44d512..00000000 --- a/pkg/database/manager.py +++ /dev/null @@ -1,365 +0,0 @@ -""" -数据库管理模块 -""" -import hashlib -import json -import logging -import time - -import sqlite3 - -from ..utils import context - - -class DatabaseManager: - """封装数据库底层操作,并提供方法给上层使用""" - - conn = None - cursor = None - - def __init__(self): - - self.reconnect() - - context.set_database_manager(self) - - # 连接到数据库文件 - def reconnect(self): - """连接到数据库""" - self.conn = sqlite3.connect('database.db', check_same_thread=False) - 
self.cursor = self.conn.cursor() - - def close(self): - self.conn.close() - - def __execute__(self, *args, **kwargs) -> sqlite3.Cursor: - # logging.debug('SQL: {}'.format(sql)) - logging.debug('SQL: {}'.format(args)) - c = self.cursor.execute(*args, **kwargs) - self.conn.commit() - return c - - # 初始化数据库的函数 - def initialize_database(self): - """创建数据表""" - - self.__execute__(""" - create table if not exists `sessions` ( - `id` INTEGER PRIMARY KEY AUTOINCREMENT, - `name` varchar(255) not null, - `type` varchar(255) not null, - `number` bigint not null, - `create_timestamp` bigint not null, - `last_interact_timestamp` bigint not null, - `status` varchar(255) not null default 'on_going', - `default_prompt` text not null default '', - `prompt` text not null, - `token_counts` text not null default '[]' - ) - """) - - # 检查sessions表是否存在`default_prompt`字段, 检查是否存在`token_counts`字段 - self.__execute__("PRAGMA table_info('sessions')") - columns = self.cursor.fetchall() - has_default_prompt = False - has_token_counts = False - for field in columns: - if field[1] == 'default_prompt': - has_default_prompt = True - if field[1] == 'token_counts': - has_token_counts = True - if has_default_prompt and has_token_counts: - break - if not has_default_prompt: - self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''") - if not has_token_counts: - self.__execute__("alter table `sessions` add column `token_counts` text not null default '[]'") - - - self.__execute__(""" - create table if not exists `account_fee`( - `id` INTEGER PRIMARY KEY AUTOINCREMENT, - `key_md5` varchar(255) not null, - `timestamp` bigint not null, - `fee` DECIMAL(12,6) not null - ) - """) - - self.__execute__(""" - create table if not exists `account_usage`( - `id` INTEGER PRIMARY KEY AUTOINCREMENT, - `json` text not null - ) - """) - # print('Database initialized.') - - # session持久化 - def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, - 
last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: str = ''): - """持久化指定session""" - - # 检查是否已经有了此name和create_timestamp的session - # 如果有,就更新prompt和last_interact_timestamp - # 如果没有,就插入一条新的记录 - self.__execute__(""" - select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {} - """.format(subject_type, subject_number, create_timestamp)) - count = self.cursor.fetchone()[0] - if count == 0: - - sql = """ - insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`, `token_counts`) - values (?, ?, ?, ?, ?, ?, ?, ?) - """ - - self.__execute__(sql, - ("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, - last_interact_timestamp, prompt, default_prompt, token_counts)) - else: - sql = """ - update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ? - where `type` = ? and `number` = ? and `create_timestamp` = ? 
- """ - - self.__execute__(sql, (last_interact_timestamp, prompt, token_counts, subject_type, - subject_number, create_timestamp)) - - # 显式关闭一个session - def explicit_close_session(self, session_name: str, create_timestamp: int): - self.__execute__(""" - update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {} - """.format(session_name, create_timestamp)) - - def set_session_ongoing(self, session_name: str, create_timestamp: int): - self.__execute__(""" - update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {} - """.format(session_name, create_timestamp)) - - # 设置session为过期 - def set_session_expired(self, session_name: str, create_timestamp: int): - self.__execute__(""" - update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {} - """.format(session_name, create_timestamp)) - - # 从数据库加载还没过期的session数据 - def load_valid_sessions(self) -> dict: - # 从数据库中加载所有还没过期的session - config = context.get_config_manager().data - self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` - from `sessions` where `last_interact_timestamp` > {} - """.format(int(time.time()) - config['session_expire_time'])) - results = self.cursor.fetchall() - sessions = {} - for result in results: - session_name = result[0] - subject_type = result[1] - subject_number = result[2] - create_timestamp = result[3] - last_interact_timestamp = result[4] - prompt = result[5] - status = result[6] - default_prompt = result[7] - token_counts = result[8] - - # 当且仅当最后一个该对象的会话是on_going状态时,才会被加载 - if status == 'on_going': - sessions[session_name] = { - 'subject_type': subject_type, - 'subject_number': subject_number, - 'create_timestamp': create_timestamp, - 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt, - 'default_prompt': default_prompt, - 'token_counts': token_counts - } - else: - if 
session_name in sessions: - del sessions[session_name] - - return sessions - - # 获取此session_name前一个session的数据 - def last_session(self, session_name: str, cursor_timestamp: int): - - self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` - from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc - limit 1 - """.format(session_name, cursor_timestamp)) - results = self.cursor.fetchall() - if len(results) == 0: - return None - result = results[0] - - session_name = result[0] - subject_type = result[1] - subject_number = result[2] - create_timestamp = result[3] - last_interact_timestamp = result[4] - prompt = result[5] - status = result[6] - default_prompt = result[7] - token_counts = result[8] - - return { - 'subject_type': subject_type, - 'subject_number': subject_number, - 'create_timestamp': create_timestamp, - 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt, - 'default_prompt': default_prompt, - 'token_counts': token_counts - } - - # 获取此session_name后一个session的数据 - def next_session(self, session_name: str, cursor_timestamp: int): - - self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` - from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc - limit 1 - """.format(session_name, cursor_timestamp)) - results = self.cursor.fetchall() - if len(results) == 0: - return None - result = results[0] - - session_name = result[0] - subject_type = result[1] - subject_number = result[2] - create_timestamp = result[3] - last_interact_timestamp = result[4] - prompt = result[5] - status = result[6] - default_prompt = result[7] - token_counts = result[8] - - return { - 'subject_type': subject_type, - 'subject_number': subject_number, - 
'create_timestamp': create_timestamp, - 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt, - 'default_prompt': default_prompt, - 'token_counts': token_counts - } - - # 列出与某个对象的所有对话session - def list_history(self, session_name: str, capacity: int, page: int): - self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` - from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {} - """.format(session_name, capacity, capacity * page)) - results = self.cursor.fetchall() - sessions = [] - for result in results: - session_name = result[0] - subject_type = result[1] - subject_number = result[2] - create_timestamp = result[3] - last_interact_timestamp = result[4] - prompt = result[5] - status = result[6] - default_prompt = result[7] - token_counts = result[8] - - sessions.append({ - 'subject_type': subject_type, - 'subject_number': subject_number, - 'create_timestamp': create_timestamp, - 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt, - 'default_prompt': default_prompt, - 'token_counts': token_counts - }) - - return sessions - - def delete_history(self, session_name: str, index: int) -> bool: - # 删除倒序第index个session - # 查找其id再删除 - self.__execute__(""" - delete from `sessions` where `id` in (select `id` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit 1 offset {}) - """.format(session_name, index)) - - return self.cursor.rowcount == 1 - - def delete_all_history(self, session_name: str) -> bool: - self.__execute__(""" - delete from `sessions` where `name` = '{}' - """.format(session_name)) - return self.cursor.rowcount > 0 - - def delete_all_session_history(self) -> bool: - self.__execute__(""" - delete from `sessions` - """) - return self.cursor.rowcount > 0 - - # 将apikey的使用量存进数据库 - def dump_api_key_usage(self, api_keys: dict, usage: dict): - logging.debug('dumping 
api key usage...') - logging.debug(api_keys) - logging.debug(usage) - for api_key in api_keys: - # 计算key的md5值 - key_md5 = hashlib.md5(api_keys[api_key].encode('utf-8')).hexdigest() - # 获取使用量 - usage_count = 0 - if key_md5 in usage: - usage_count = usage[key_md5] - # 将使用量存进数据库 - # 先检查是否已存在 - self.__execute__(""" - select count(*) from `api_key_usage` where `key_md5` = '{}'""".format(key_md5)) - result = self.cursor.fetchone() - if result[0] == 0: - # 不存在则插入 - self.__execute__(""" - insert into `api_key_usage` (`key_md5`, `usage`,`timestamp`) values ('{}', {}, {}) - """.format(key_md5, usage_count, int(time.time()))) - else: - # 存在则更新,timestamp设置为当前 - self.__execute__(""" - update `api_key_usage` set `usage` = {}, `timestamp` = {} where `key_md5` = '{}' - """.format(usage_count, int(time.time()), key_md5)) - - def load_api_key_usage(self): - self.__execute__(""" - select `key_md5`, `usage` from `api_key_usage` - """) - results = self.cursor.fetchall() - usage = {} - for result in results: - key_md5 = result[0] - usage_count = result[1] - usage[key_md5] = usage_count - return usage - - def dump_usage_json(self, usage: dict): - - json_str = json.dumps(usage) - self.__execute__(""" - select count(*) from `account_usage`""") - result = self.cursor.fetchone() - if result[0] == 0: - # 不存在则插入 - self.__execute__(""" - insert into `account_usage` (`json`) values ('{}') - """.format(json_str)) - else: - # 存在则更新 - self.__execute__(""" - update `account_usage` set `json` = '{}' where `id` = 1 - """.format(json_str)) - - def load_usage_json(self): - self.__execute__(""" - select `json` from `account_usage` order by id desc limit 1 - """) - result = self.cursor.fetchone() - if result is None: - return None - else: - return result[0] diff --git a/pkg/openai/api/chat_completion.py b/pkg/openai/api/chat_completion.py deleted file mode 100644 index 1e0e1bc5..00000000 --- a/pkg/openai/api/chat_completion.py +++ /dev/null @@ -1,232 +0,0 @@ -import json -import logging - -import openai 
-from openai.types.chat import chat_completion_message - -from .model import RequestBase -from .. import funcmgr -from ...plugin import host -from ...utils import context - - -class ChatCompletionRequest(RequestBase): - """调用ChatCompletion接口的请求类。 - - 此类保证每一次返回的角色为assistant的信息的finish_reason一定为stop。 - 若有函数调用响应,本类的返回瀑布是:函数调用请求->函数调用结果->...->assistant的信息->stop。 - """ - - model: str - messages: list[dict[str, str]] - kwargs: dict - - stopped: bool = False - - pending_func_call: chat_completion_message.FunctionCall = None - - pending_msg: str - - def flush_pending_msg(self): - self.append_message( - role="assistant", - content=self.pending_msg - ) - self.pending_msg = "" - - def append_message(self, role: str, content: str, name: str=None, function_call: dict=None): - msg = { - "role": role, - "content": content - } - - if name is not None: - msg['name'] = name - - if function_call is not None: - msg['function_call'] = function_call - - self.messages.append(msg) - - def __init__( - self, - client: openai.Client, - model: str, - messages: list[dict[str, str]], - **kwargs - ): - self.client = client - self.model = model - self.messages = messages.copy() - - self.kwargs = kwargs - - self.req_func = self.client.chat.completions.create - - self.pending_func_call = None - - self.stopped = False - - self.pending_msg = "" - - def __iter__(self): - return self - - def __next__(self) -> dict: - if self.stopped: - raise StopIteration() - - if self.pending_func_call is None: # 没有待处理的函数调用请求 - - args = { - "model": self.model, - "messages": self.messages, - } - - funcs = funcmgr.get_func_schema_list() - - if len(funcs) > 0: - args['functions'] = funcs - - # 拼接kwargs - args = {**args, **self.kwargs} - - from openai.types.chat import chat_completion - - resp: chat_completion.ChatCompletion = self._req(**args) - - choice0 = resp.choices[0] - - # 如果不是函数调用,且finish_reason为stop,则停止迭代 - if choice0.finish_reason == 'stop': # and choice0["finish_reason"] == "stop" - self.stopped = True - - if 
hasattr(choice0.message, 'function_call') and choice0.message.function_call is not None: - self.pending_func_call = choice0.message.function_call - - self.append_message( - role="assistant", - content=choice0.message.content, - function_call=choice0.message.function_call - ) - - return { - "id": resp.id, - "choices": [ - { - "index": choice0.index, - "message": { - "role": "assistant", - "type": "function_call", - "content": choice0.message.content, - "function_call": { - "name": choice0.message.function_call.name, - "arguments": choice0.message.function_call.arguments - } - }, - "finish_reason": "function_call" - } - ], - "usage": { - "prompt_tokens": resp.usage.prompt_tokens, - "completion_tokens": resp.usage.completion_tokens, - "total_tokens": resp.usage.total_tokens - } - } - else: - - # self.pending_msg += choice0['message']['content'] - # 普通回复一定处于最后方,故不用再追加进内部messages - - return { - "id": resp.id, - "choices": [ - { - "index": choice0.index, - "message": { - "role": "assistant", - "type": "text", - "content": choice0.message.content - }, - "finish_reason": choice0.finish_reason - } - ], - "usage": { - "prompt_tokens": resp.usage.prompt_tokens, - "completion_tokens": resp.usage.completion_tokens, - "total_tokens": resp.usage.total_tokens - } - } - else: # 处理函数调用请求 - - cp_pending_func_call = self.pending_func_call.copy() - - self.pending_func_call = None - - func_name = cp_pending_func_call.name - arguments = {} - - try: - - try: - arguments = json.loads(cp_pending_func_call.arguments) - # 若不是json格式的异常处理 - except json.decoder.JSONDecodeError: - # 获取函数的参数列表 - func_schema = funcmgr.get_func_schema(func_name) - - arguments = { - func_schema['parameters']['required'][0]: cp_pending_func_call.arguments - } - - logging.info("执行函数调用: name={}, arguments={}".format(func_name, arguments)) - - # 执行函数调用 - ret = "" - try: - ret = funcmgr.execute_function(func_name, arguments) - - logging.info("函数执行完成。") - except Exception as e: - ret = "error: execute function failed: 
{}".format(str(e)) - logging.error("函数执行失败: {}".format(str(e))) - - # 上报数据 - plugin_info = host.get_plugin_info_for_audit(func_name.split('-')[0]) - audit_func_name = func_name.split('-')[1] - audit_func_desc = funcmgr.get_func_schema(func_name)['description'] - context.get_center_v2_api().usage.post_function_record( - plugin=plugin_info, - function_name=audit_func_name, - function_description=audit_func_desc, - ) - - self.append_message( - role="function", - content=json.dumps(ret, ensure_ascii=False), - name=func_name - ) - - return { - "id": -1, - "choices": [ - { - "index": -1, - "message": { - "role": "function", - "type": "function_return", - "function_name": func_name, - "content": json.dumps(ret, ensure_ascii=False) - }, - "finish_reason": "function_return" - } - ], - "usage": { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - } - } - - except funcmgr.ContentFunctionNotFoundError: - raise Exception("没有找到函数: {}".format(func_name)) diff --git a/pkg/openai/api/completion.py b/pkg/openai/api/completion.py deleted file mode 100644 index d14e91f4..00000000 --- a/pkg/openai/api/completion.py +++ /dev/null @@ -1,100 +0,0 @@ -import openai -from openai.types import completion, completion_choice - -from . 
import model - - -class CompletionRequest(model.RequestBase): - """调用Completion接口的请求类。 - - 调用方可以一直next completion直到finish_reason为stop。 - """ - - model: str - prompt: str - kwargs: dict - - stopped: bool = False - - def __init__( - self, - client: openai.Client, - model: str, - messages: list[dict[str, str]], - **kwargs - ): - self.client = client - self.model = model - self.prompt = "" - - for message in messages: - self.prompt += message["role"] + ": " + message["content"] + "\n" - - self.prompt += "assistant: " - - self.kwargs = kwargs - - self.req_func = self.client.completions.create - - def __iter__(self): - return self - - def __next__(self) -> dict: - """调用Completion接口,返回生成的文本 - - { - "id": "id", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "type": "text", - "content": "message" - }, - "finish_reason": "reason" - } - ], - "usage": { - "prompt_tokens": 10, - "completion_tokens": 20, - "total_tokens": 30 - } - } - """ - - if self.stopped: - raise StopIteration() - - resp: completion.Completion = self._req( - model=self.model, - prompt=self.prompt, - **self.kwargs - ) - - if resp.choices[0].finish_reason == "stop": - self.stopped = True - - choice0: completion_choice.CompletionChoice = resp.choices[0] - - self.prompt += choice0.text - - return { - "id": resp.id, - "choices": [ - { - "index": choice0.index, - "message": { - "role": "assistant", - "type": "text", - "content": choice0.text - }, - "finish_reason": choice0.finish_reason - } - ], - "usage": { - "prompt_tokens": resp.usage.prompt_tokens, - "completion_tokens": resp.usage.completion_tokens, - "total_tokens": resp.usage.total_tokens - } - } diff --git a/pkg/openai/api/model.py b/pkg/openai/api/model.py deleted file mode 100644 index 0a1f6a3a..00000000 --- a/pkg/openai/api/model.py +++ /dev/null @@ -1,40 +0,0 @@ -# 定义不同接口请求的模型 -import logging - -import openai - -from ...utils import context - - -class RequestBase: - - client: openai.Client - - req_func: callable - - def 
__init__(self, *args, **kwargs): - raise NotImplementedError - - def _next_key(self): - switched, name = context.get_openai_manager().key_mgr.auto_switch() - logging.debug("切换api-key: switched={}, name={}".format(switched, name)) - self.client.api_key = context.get_openai_manager().key_mgr.get_using_key() - - def _req(self, **kwargs): - """处理代理问题""" - logging.debug("请求接口参数: %s", str(kwargs)) - config = context.get_config_manager().data - - ret = self.req_func(**kwargs) - logging.debug("接口请求返回:%s", str(ret)) - - if config['switch_strategy'] == 'active': - self._next_key() - - return ret - - def __iter__(self): - raise self - - def __next__(self): - raise NotImplementedError diff --git a/pkg/openai/dprompt.py b/pkg/openai/dprompt.py deleted file mode 100644 index 247fb158..00000000 --- a/pkg/openai/dprompt.py +++ /dev/null @@ -1,134 +0,0 @@ -# 多情景预设值管理 -import json -import logging -import os - -from ..utils import context - -# __current__ = "default" -# """当前默认使用的情景预设的名称 - -# 由管理员使用`!default <名称>`命令切换 -# """ - -# __prompts_from_files__ = {} -# """从文件中读取的情景预设值""" - -# __scenario_from_files__ = {} - - -class ScenarioMode: - """情景预设模式抽象类""" - - using_prompt_name = "default" - """新session创建时使用的prompt名称""" - - prompts: dict[str, list] = {} - - def __init__(self): - logging.debug("prompts: {}".format(self.prompts)) - - def list(self) -> dict[str, list]: - """获取所有情景预设的名称及内容""" - return self.prompts - - def get_prompt(self, name: str) -> tuple[list, str]: - """获取指定情景预设的名称及内容""" - for key in self.prompts: - if key.startswith(name): - return self.prompts[key], key - raise Exception("没有找到情景预设: {}".format(name)) - - def set_using_name(self, name: str) -> str: - """设置默认情景预设""" - for key in self.prompts: - if key.startswith(name): - self.using_prompt_name = key - return key - raise Exception("没有找到情景预设: {}".format(name)) - - def get_full_name(self, name: str) -> str: - """获取完整的情景预设名称""" - for key in self.prompts: - if key.startswith(name): - return key - raise Exception("没有找到情景预设: 
{}".format(name)) - - def get_using_name(self) -> str: - """获取默认情景预设""" - return self.using_prompt_name - - -class NormalScenarioMode(ScenarioMode): - """普通情景预设模式""" - - def __init__(self): - config = context.get_config_manager().data - - # 加载config中的default_prompt值 - if type(config['default_prompt']) == str: - self.using_prompt_name = "default" - self.prompts = {"default": [ - { - "role": "system", - "content": config['default_prompt'] - } - ]} - - elif type(config['default_prompt']) == dict: - for key in config['default_prompt']: - self.prompts[key] = [ - { - "role": "system", - "content": config['default_prompt'][key] - } - ] - - # 从prompts/目录下的文件中载入 - # 遍历文件 - for file in os.listdir("prompts"): - with open(os.path.join("prompts", file), encoding="utf-8") as f: - self.prompts[file] = [ - { - "role": "system", - "content": f.read() - } - ] - - -class FullScenarioMode(ScenarioMode): - """完整情景预设模式""" - - def __init__(self): - """从json读取所有""" - # 遍历scenario/目录下的所有文件,以文件名为键,文件内容中的prompt为值 - for file in os.listdir("scenario"): - if file == "default-template.json": - continue - with open(os.path.join("scenario", file), encoding="utf-8") as f: - self.prompts[file] = json.load(f)["prompt"] - - super().__init__() - - -scenario_mode_mapping = {} -"""情景预设模式名称与对象的映射""" - - -def register_all(): - """注册所有情景预设模式,不使用装饰器,因为装饰器的方式不支持热重载""" - global scenario_mode_mapping - scenario_mode_mapping = { - "normal": NormalScenarioMode(), - "full_scenario": FullScenarioMode() - } - - -def mode_inst() -> ScenarioMode: - """获取指定名称的情景预设模式对象""" - config = context.get_config_manager().data - - if config['preset_mode'] == "default": - config['preset_mode'] = "normal" - - return scenario_mode_mapping[config['preset_mode']] diff --git a/pkg/openai/funcmgr.py b/pkg/openai/funcmgr.py deleted file mode 100644 index 50932917..00000000 --- a/pkg/openai/funcmgr.py +++ /dev/null @@ -1,46 +0,0 @@ -# 封装了function calling的一些支持函数 -import logging - -from ..plugin import host - - -class 
ContentFunctionNotFoundError(Exception): - pass - - -def get_func_schema_list() -> list: - """从plugin包中的函数结构中获取并处理成受GPT支持的格式""" - if not host.__enable_content_functions__: - return [] - - schemas = [] - - for func in host.__callable_functions__: - if func['enabled']: - fun_cp = func.copy() - - del fun_cp['enabled'] - - schemas.append(fun_cp) - - return schemas - -def get_func(name: str) -> callable: - if name not in host.__function_inst_map__: - raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name)) - - return host.__function_inst_map__[name] - -def get_func_schema(name: str) -> dict: - for func in host.__callable_functions__: - if func['name'] == name: - return func - raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name)) - -def execute_function(name: str, kwargs: dict) -> any: - """执行函数调用""" - - logging.debug("executing function: name='{}', kwargs={}".format(name, kwargs)) - - func = get_func(name) - return func(**kwargs) diff --git a/pkg/openai/keymgr.py b/pkg/openai/keymgr.py deleted file mode 100644 index af560b29..00000000 --- a/pkg/openai/keymgr.py +++ /dev/null @@ -1,103 +0,0 @@ -# 此模块提供了维护api-key的各种功能 -import hashlib -import logging - -from ..plugin import host as plugin_host -from ..plugin import models as plugin_models - - -class KeysManager: - api_key = {} - """所有api-key""" - - using_key = "" - """当前使用的api-key""" - - alerted = [] - """已提示过超额的key - - 记录在此以避免重复提示 - """ - - exceeded = [] - """已超额的key - - 供自动切换功能识别 - """ - - def get_using_key(self): - return self.using_key - - def get_using_key_md5(self): - return hashlib.md5(self.using_key.encode('utf-8')).hexdigest() - - def __init__(self, api_key): - - assert type(api_key) == dict - self.api_key = api_key - # 从usage中删除未加载的api-key的记录 - # 不删了,也许会运行时添加曾经有记录的api-key - - self.auto_switch() - - def auto_switch(self) -> tuple[bool, str]: - """尝试切换api-key - - Returns: - 是否切换成功, 切换后的api-key的别名 - """ - - index = 0 - - for key_name in self.api_key: - if self.api_key[key_name] == self.using_key: - 
break - - index += 1 - - # 从当前key开始向后轮询 - start_index = index - index += 1 - if index >= len(self.api_key): - index = 0 - - while index != start_index: - - key_name = list(self.api_key.keys())[index] - - if self.api_key[key_name] not in self.exceeded: - self.using_key = self.api_key[key_name] - - logging.debug("使用api-key:" + key_name) - - # 触发插件事件 - args = { - "key_name": key_name, - "key_list": self.api_key.keys() - } - _ = plugin_host.emit(plugin_models.KeySwitched, **args) - - return True, key_name - - index += 1 - if index >= len(self.api_key): - index = 0 - - self.using_key = list(self.api_key.values())[start_index] - logging.debug("使用api-key:" + list(self.api_key.keys())[start_index]) - - return False, list(self.api_key.keys())[start_index] - - def add(self, key_name, key): - self.api_key[key_name] = key - - def set_current_exceeded(self): - """设置当前使用的api-key使用量超限""" - self.exceeded.append(self.using_key) - - def get_key_name(self, api_key): - """根据api-key获取其别名""" - for key_name in self.api_key: - if self.api_key[key_name] == api_key: - return key_name - return "" diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py deleted file mode 100644 index 8b2b6145..00000000 --- a/pkg/openai/manager.py +++ /dev/null @@ -1,90 +0,0 @@ -import logging - -import openai -from openai.types import images_response - -from ..openai import keymgr -from ..utils import context -from ..audit import gatherer -from ..openai import modelmgr -from ..openai.api import model as api_model - - -class OpenAIInteract: - """OpenAI 接口封装 - - 将文字接口和图片接口封装供调用方使用 - """ - - key_mgr: keymgr.KeysManager = None - - audit_mgr: gatherer.DataGatherer = None - - default_image_api_params = { - "size": "256x256", - } - - client: openai.Client = None - - def __init__(self, api_key: str): - - self.key_mgr = keymgr.KeysManager(api_key) - self.audit_mgr = gatherer.DataGatherer() - - # logging.info("文字总使用量:%d", self.audit_mgr.get_total_text_length()) - - self.client = openai.Client( - 
api_key=self.key_mgr.get_using_key(), - base_url=openai.base_url - ) - - context.set_openai_manager(self) - - def request_completion(self, messages: list): - """请求补全接口回复= - """ - # 选择接口请求类 - config = context.get_config_manager().data - - request: api_model.RequestBase - - model: str = config['completion_api_params']['model'] - - cp_parmas = config['completion_api_params'].copy() - del cp_parmas['model'] - - request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas) - - # 请求接口 - for resp in request: - - if resp['usage']['total_tokens'] > 0: - self.audit_mgr.report_text_model_usage( - model, - resp['usage']['total_tokens'] - ) - - yield resp - - def request_image(self, prompt) -> images_response.ImagesResponse: - """请求图片接口回复 - - Parameters: - prompt (str): 提示语 - - Returns: - dict: 响应 - """ - config = context.get_config_manager().data - params = config['image_api_params'] - - response = self.client.images.generate( - prompt=prompt, - n=1, - **params - ) - - self.audit_mgr.report_image_model_usage(params['size']) - - return response - diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py deleted file mode 100644 index 0abd2d16..00000000 --- a/pkg/openai/modelmgr.py +++ /dev/null @@ -1,139 +0,0 @@ -"""OpenAI 接口底层封装 - -目前使用的对话接口有: -ChatCompletion - gpt-3.5-turbo 等模型 -Completion - text-davinci-003 等模型 -此模块封装此两个接口的请求实现,为上层提供统一的调用方式 -""" -import tiktoken -import openai - -from ..openai.api import model as api_model -from ..openai.api import completion as api_completion -from ..openai.api import chat_completion as api_chat_completion - -COMPLETION_MODELS = { - "gpt-3.5-turbo-instruct", -} - -CHAT_COMPLETION_MODELS = { - # GPT 4 系列 - "gpt-4-1106-preview", - "gpt-4-vision-preview", - "gpt-4", - "gpt-4-32k", - "gpt-4-0613", - "gpt-4-32k-0613", - "gpt-4-0314", # legacy - "gpt-4-32k-0314", # legacy - # GPT 3.5 系列 - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-0613", # legacy - "gpt-3.5-turbo-16k-0613", # legacy - 
"gpt-3.5-turbo-0301", # legacy - # One-API 接入 - "SparkDesk", - "chatglm_pro", - "chatglm_std", - "chatglm_lite", - "qwen-v1", - "qwen-plus-v1", - "ERNIE-Bot", - "ERNIE-Bot-turbo", - "gemini-pro", -} - -EDIT_MODELS = { - -} - -IMAGE_MODELS = { - -} - - -def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> api_model.RequestBase: - if model_name in CHAT_COMPLETION_MODELS: - return api_chat_completion.ChatCompletionRequest(client, model_name, messages, **args) - elif model_name in COMPLETION_MODELS: - return api_completion.CompletionRequest(client, model_name, messages, **args) - raise ValueError("不支持模型[{}],请检查配置文件".format(model_name)) - - -def count_chat_completion_tokens(messages: list, model: str) -> int: - """Return the number of tokens used by a list of messages.""" - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - print("Warning: model not found. Using cl100k_base encoding.") - encoding = tiktoken.get_encoding("cl100k_base") - if model in { - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k-0613", - "gpt-4-0314", - "gpt-4-32k-0314", - "gpt-4-0613", - "gpt-4-32k-0613", - "SparkDesk", - "chatglm_pro", - "chatglm_std", - "chatglm_lite", - "qwen-v1", - "qwen-plus-v1", - "ERNIE-Bot", - "ERNIE-Bot-turbo", - "gemini-pro", - }: - tokens_per_message = 3 - tokens_per_name = 1 - elif model == "gpt-3.5-turbo-0301": - tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n - tokens_per_name = -1 # if there's a name, the role is omitted - elif "gpt-3.5-turbo" in model: - # print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") - return count_chat_completion_tokens(messages, model="gpt-3.5-turbo-0613") - elif "gpt-4" in model: - # print("Warning: gpt-4 may update over time. 
Returning num tokens assuming gpt-4-0613.") - return count_chat_completion_tokens(messages, model="gpt-4-0613") - else: - raise NotImplementedError( - f"""count_chat_completion_tokens() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" - ) - num_tokens = 0 - for message in messages: - num_tokens += tokens_per_message - for key, value in message.items(): - num_tokens += len(encoding.encode(value)) - if key == "name": - num_tokens += tokens_per_name - num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> - return num_tokens - - -def count_completion_tokens(messages: list, model: str) -> int: - - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - print("Warning: model not found. Using cl100k_base encoding.") - encoding = tiktoken.get_encoding("cl100k_base") - - text = "" - - for message in messages: - text += message['role'] + message['content'] + "\n" - - text += "assistant: " - - return len(encoding.encode(text)) - - -def count_tokens(messages: list, model: str): - - if model in CHAT_COMPLETION_MODELS: - return count_chat_completion_tokens(messages, model) - elif model in COMPLETION_MODELS: - return count_completion_tokens(messages, model) - raise ValueError("不支持模型[{}],请检查配置文件".format(model)) diff --git a/pkg/openai/session.py b/pkg/openai/session.py deleted file mode 100644 index 19a69ea2..00000000 --- a/pkg/openai/session.py +++ /dev/null @@ -1,504 +0,0 @@ -"""主线使用的会话管理模块 - -每个人、每个群单独一个session,session内部保留了对话的上下文, -""" - -import logging -import threading -import time -import json - -from ..openai import manager as openai_manager -from ..openai import modelmgr as openai_modelmgr -from ..database import manager as database_manager -from ..utils import context as context - -from ..plugin import host as plugin_host -from ..plugin import models as plugin_models - -# 运行时保存的所有session -sessions = {} - - -class 
SessionOfflineStatus: - ON_GOING = 'on_going' - EXPLICITLY_CLOSED = 'explicitly_closed' - - -# 从数据加载session -def load_sessions(): - """从数据库加载sessions""" - - global sessions - - db_inst = context.get_database_manager() - - session_data = db_inst.load_valid_sessions() - - for session_name in session_data: - logging.debug('加载session: {}'.format(session_name)) - - temp_session = Session(session_name) - temp_session.name = session_name - temp_session.create_timestamp = session_data[session_name]['create_timestamp'] - temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp'] - - temp_session.prompt = json.loads(session_data[session_name]['prompt']) - temp_session.token_counts = json.loads(session_data[session_name]['token_counts']) - - temp_session.default_prompt = json.loads(session_data[session_name]['default_prompt']) if \ - session_data[session_name]['default_prompt'] else [] - - sessions[session_name] = temp_session - - -# 获取指定名称的session,如果不存在则创建一个新的 -def get_session(session_name: str) -> 'Session': - global sessions - if session_name not in sessions: - sessions[session_name] = Session(session_name) - return sessions[session_name] - - -def dump_session(session_name: str): - global sessions - if session_name in sessions: - assert isinstance(sessions[session_name], Session) - sessions[session_name].persistence() - del sessions[session_name] - - -# 通用的OpenAI API交互session -# session内部保留了对话的上下文, -# 收到用户消息后,将上下文提交给OpenAI API生成回复 -class Session: - name = '' - - prompt = [] - """使用list来保存会话中的回合""" - - default_prompt = [] - """本session的默认prompt""" - - create_timestamp = 0 - """会话创建时间""" - - last_interact_timestamp = 0 - """上次交互(产生回复)时间""" - - just_switched_to_exist_session = False - - response_lock = None - - # 加锁 - def acquire_response_lock(self): - logging.debug('{},lock acquire,{}'.format(self.name, self.response_lock)) - self.response_lock.acquire() - logging.debug('{},lock acquire successfully,{}'.format(self.name, self.response_lock)) 
- - # 释放锁 - def release_response_lock(self): - if self.response_lock.locked(): - logging.debug('{},lock release,{}'.format(self.name, self.response_lock)) - self.response_lock.release() - logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock)) - - # 从配置文件获取会话预设信息 - def get_default_prompt(self, use_default: str = None): - import pkg.openai.dprompt as dprompt - - if use_default is None: - use_default = dprompt.mode_inst().get_using_name() - - current_default_prompt, _ = dprompt.mode_inst().get_prompt(use_default) - return current_default_prompt - - def __init__(self, name: str): - self.name = name - self.create_timestamp = int(time.time()) - self.last_interact_timestamp = int(time.time()) - self.prompt = [] - self.token_counts = [] - self.schedule() - - self.response_lock = threading.Lock() - - self.default_prompt = self.get_default_prompt() - logging.debug("prompt is: {}".format(self.default_prompt)) - - # 设定检查session最后一次对话是否超过过期时间的计时器 - def schedule(self): - threading.Thread(target=self.expire_check_timer_loop, args=(self.create_timestamp,)).start() - - # 检查session是否已经过期 - def expire_check_timer_loop(self, create_timestamp: int): - global sessions - while True: - time.sleep(60) - - # 不是此session已更换,退出 - if self.create_timestamp != create_timestamp or self not in sessions.values(): - return - - config = context.get_config_manager().data - if int(time.time()) - self.last_interact_timestamp > config['session_expire_time']: - logging.info('session {} 已过期'.format(self.name)) - - # 触发插件事件 - args = { - 'session_name': self.name, - 'session': self, - 'session_expire_time': config['session_expire_time'] - } - event = plugin_host.emit(plugin_models.SessionExpired, **args) - if event.is_prevented_default(): - return - - self.reset(expired=True, schedule_new=False) - - # 删除此session - del sessions[self.name] - return - - # 请求回复 - # 这个函数是阻塞的 - def query(self, text: str=None) -> tuple[str, str, list[str]]: - """向session中添加一条消息,返回接口回复 - - Args: - text 
(str): 用户消息 - - Returns: - tuple[str, str]: (接口回复, finish_reason, 已调用的函数列表) - """ - - self.last_interact_timestamp = int(time.time()) - - # 触发插件事件 - if not self.prompt: - args = { - 'session_name': self.name, - 'session': self, - 'default_prompt': self.default_prompt, - } - - event = plugin_host.emit(plugin_models.SessionFirstMessageReceived, **args) - if event.is_prevented_default(): - return None, None, None - - config = context.get_config_manager().data - max_length = config['prompt_submit_length'] - - local_default_prompt = self.default_prompt.copy() - local_prompt = self.prompt.copy() - - # 触发PromptPreProcessing事件 - args = { - 'session_name': self.name, - 'default_prompt': self.default_prompt, - 'prompt': self.prompt, - 'text_message': text, - } - - event = plugin_host.emit(plugin_models.PromptPreProcessing, **args) - - if event.get_return_value('default_prompt') is not None: - local_default_prompt = event.get_return_value('default_prompt') - - if event.get_return_value('prompt') is not None: - local_prompt = event.get_return_value('prompt') - - if event.get_return_value('text_message') is not None: - text = event.get_return_value('text_message') - - # 裁剪messages到合适长度 - prompts, _ = self.cut_out(text, max_length, local_default_prompt, local_prompt) - - res_text = "" - - pending_msgs = [] - - total_tokens = 0 - - finish_reason: str = "" - - funcs = [] - - trace_func_calls = config['trace_function_calls'] - botmgr = context.get_qqbot_manager() - - session_name_spt: list[str] = self.name.split("_") - - pending_res_text = "" - - start_time = time.time() - - # TODO 对不起,我知道这样非常非常屎山,但我之后会重构的 - for resp in context.get_openai_manager().request_completion(prompts): - - if pending_res_text != "": - botmgr.adapter.send_message( - session_name_spt[0], - session_name_spt[1], - pending_res_text - ) - pending_res_text = "" - - finish_reason = resp['choices'][0]['finish_reason'] - - if resp['choices'][0]['message']['role'] == "assistant" and 
resp['choices'][0]['message']['content'] != None: # 包含纯文本响应 - - if not trace_func_calls: - res_text += resp['choices'][0]['message']['content'] - else: - res_text = resp['choices'][0]['message']['content'] - pending_res_text = resp['choices'][0]['message']['content'] - - total_tokens += resp['usage']['total_tokens'] - - msg = { - "role": "assistant", - "content": resp['choices'][0]['message']['content'] - } - - if 'function_call' in resp['choices'][0]['message']: - msg['function_call'] = json.dumps(resp['choices'][0]['message']['function_call']) - - pending_msgs.append(msg) - - if resp['choices'][0]['message']['type'] == 'function_call': - # self.prompt.append( - # { - # "role": "assistant", - # "content": "function call: "+json.dumps(resp['choices'][0]['message']['function_call']) - # } - # ) - if trace_func_calls: - botmgr.adapter.send_message( - session_name_spt[0], - session_name_spt[1], - "调用函数 "+resp['choices'][0]['message']['function_call']['name'] + "..." - ) - - total_tokens += resp['usage']['total_tokens'] - elif resp['choices'][0]['message']['type'] == 'function_return': - # self.prompt.append( - # { - # "role": "function", - # "name": resp['choices'][0]['message']['function_name'], - # "content": json.dumps(resp['choices'][0]['message']['content']) - # } - # ) - - # total_tokens += resp['usage']['total_tokens'] - funcs.append( - resp['choices'][0]['message']['function_name'] - ) - pass - - # 向API请求补全 - # message, total_token = pkg.utils.context.get_openai_manager().request_completion( - # prompts, - # ) - - # 成功获取,处理回复 - # res_test = message - res_ans = res_text.strip() - - # 将此次对话的双方内容加入到prompt中 - # self.prompt.append({'role': 'user', 'content': text}) - # self.prompt.append({'role': 'assistant', 'content': res_ans}) - if text: - self.prompt.append({'role': 'user', 'content': text}) - # 添加pending_msgs - self.prompt += pending_msgs - - # 向token_counts中添加本回合的token数量 - # self.token_counts.append(total_tokens-total_token_before_query) - # 
logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts)) - - if self.just_switched_to_exist_session: - self.just_switched_to_exist_session = False - self.set_ongoing() - - # 上报使用量数据 - session_type = session_name_spt[0] - session_id = session_name_spt[1] - - ability_provider = "QChatGPT.Text" - usage = total_tokens - model_name = context.get_config_manager().data['completion_api_params']['model'] - response_seconds = int(time.time() - start_time) - retry_times = -1 # 暂不记录 - - context.get_center_v2_api().usage.post_query_record( - session_type=session_type, - session_id=session_id, - query_ability_provider=ability_provider, - usage=usage, - model_name=model_name, - response_seconds=response_seconds, - retry_times=retry_times - ) - - return res_ans if res_ans[0] != '\n' else res_ans[1:], finish_reason, funcs - - # 删除上一回合并返回上一回合的问题 - def undo(self) -> str: - self.last_interact_timestamp = int(time.time()) - - # 删除最后两个消息 - if len(self.prompt) < 2: - raise Exception('之前无对话,无法撤销') - - question = self.prompt[-2]['content'] - self.prompt = self.prompt[:-2] - self.token_counts = self.token_counts[:-1] - - # 返回上一回合的问题 - return question - - # 构建对话体 - def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) -> tuple[list, list]: - """将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens - - :return: (新的prompt, 新的token_counts) - """ - - # 最终由三个部分组成 - # - default_prompt 情景预设固定值 - # - changable_prompts 可变部分, 此会话中的历史对话回合 - # - current_question 当前问题 - - # 包装目前的对话回合内容 - changable_prompts = [] - - use_model = context.get_config_manager().data['completion_api_params']['model'] - - ptr = len(prompt) - 1 - - # 直接从后向前扫描拼接,不管是否是整回合 - while ptr >= 0: - if openai_modelmgr.count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens: - break - - changable_prompts.insert(0, prompt[ptr]) - - ptr -= 1 - - # 将default_prompt和changable_prompts合并 - result_prompt = default_prompt + changable_prompts - - # 添加当前问题 - if 
msg: - result_prompt.append( - { - 'role': 'user', - 'content': msg - } - ) - - logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4))) - - return result_prompt, openai_modelmgr.count_tokens(changable_prompts, use_model) - - # 持久化session - def persistence(self): - if self.prompt == self.get_default_prompt(): - return - - db_inst = context.get_database_manager() - - name_spt = self.name.split('_') - - subject_type = name_spt[0] - subject_number = int(name_spt[1]) - - db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, - json.dumps(self.prompt), json.dumps(self.default_prompt), json.dumps(self.token_counts)) - - # 重置session - def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None, persist: bool = False): - if self.prompt: - self.persistence() - if explicit: - # 触发插件事件 - args = { - 'session_name': self.name, - 'session': self - } - - # 此事件不支持阻止默认行为 - _ = plugin_host.emit(plugin_models.SessionExplicitReset, **args) - - context.get_database_manager().explicit_close_session(self.name, self.create_timestamp) - - if expired: - context.get_database_manager().set_session_expired(self.name, self.create_timestamp) - - if not persist: # 不要求保持default prompt - self.default_prompt = self.get_default_prompt(use_prompt) - self.prompt = [] - self.token_counts = [] - self.create_timestamp = int(time.time()) - self.last_interact_timestamp = int(time.time()) - self.just_switched_to_exist_session = False - - # self.response_lock = threading.Lock() - - if schedule_new: - self.schedule() - - # 将本session的数据库状态设置为on_going - def set_ongoing(self): - context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp) - - # 切换到上一个session - def last_session(self): - last_one = context.get_database_manager().last_session(self.name, self.last_interact_timestamp) - if last_one is None: - return None - else: - self.persistence() - 
- self.create_timestamp = last_one['create_timestamp'] - self.last_interact_timestamp = last_one['last_interact_timestamp'] - - self.prompt = json.loads(last_one['prompt']) - self.token_counts = json.loads(last_one['token_counts']) - - self.default_prompt = json.loads(last_one['default_prompt']) if last_one['default_prompt'] else [] - - self.just_switched_to_exist_session = True - return self - - # 切换到下一个session - def next_session(self): - next_one = context.get_database_manager().next_session(self.name, self.last_interact_timestamp) - if next_one is None: - return None - else: - self.persistence() - - self.create_timestamp = next_one['create_timestamp'] - self.last_interact_timestamp = next_one['last_interact_timestamp'] - - self.prompt = json.loads(next_one['prompt']) - self.token_counts = json.loads(next_one['token_counts']) - - self.default_prompt = json.loads(next_one['default_prompt']) if next_one['default_prompt'] else [] - - self.just_switched_to_exist_session = True - return self - - def list_history(self, capacity: int = 10, page: int = 0): - return context.get_database_manager().list_history(self.name, capacity, page) - - def delete_history(self, index: int) -> bool: - return context.get_database_manager().delete_history(self.name, index) - - def delete_all_history(self) -> bool: - return context.get_database_manager().delete_all_history(self.name) - - def draw_image(self, prompt: str): - return context.get_openai_manager().request_image(prompt) diff --git a/pkg/utils/center/__init__.py b/pkg/pipeline/__init__.py similarity index 100% rename from pkg/utils/center/__init__.py rename to pkg/pipeline/__init__.py diff --git a/pkg/utils/center/groups/__init__.py b/pkg/pipeline/bansess/__init__.py similarity index 100% rename from pkg/utils/center/groups/__init__.py rename to pkg/pipeline/bansess/__init__.py diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py new file mode 100644 index 00000000..f56babe2 --- /dev/null +++ 
b/pkg/pipeline/bansess/bansess.py @@ -0,0 +1,45 @@ +from __future__ import annotations +import re + +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class('BanSessionCheckStage') +class BanSessionCheckStage(stage.PipelineStage): + + async def initialize(self): + pass + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str + ) -> entities.StageProcessResult: + + found = False + + mode = self.ap.pipeline_cfg.data['access-control']['mode'] + + sess_list = self.ap.pipeline_cfg.data['access-control'][mode] + + if (query.launcher_type == 'group' and 'group_*' in sess_list) \ + or (query.launcher_type == 'person' and 'person_*' in sess_list): + found = True + else: + for sess in sess_list: + if sess == f"{query.launcher_type}_{query.launcher_id}": + found = True + break + + result = False + + if mode == 'blacklist': + result = found + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT, + new_query=query, + debug_notice=f'根据访问控制忽略消息: {query.launcher_type}_{query.launcher_id}' if result else '' + ) diff --git a/plugins/__init__.py b/pkg/pipeline/cntfilter/__init__.py similarity index 100% rename from plugins/__init__.py rename to pkg/pipeline/cntfilter/__init__.py diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py new file mode 100644 index 00000000..9982a51e --- /dev/null +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import mirai + +from ...core import app + +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr +from . 
import filter, entities as filter_entities +from .filters import cntignore, banwords, baiduexamine + + +@stage.stage_class('PostContentFilterStage') +@stage.stage_class('PreContentFilterStage') +class ContentFilterStage(stage.PipelineStage): + + filter_chain: list[filter.ContentFilter] + + def __init__(self, ap: app.Application): + self.filter_chain = [] + super().__init__(ap) + + async def initialize(self): + self.filter_chain.append(cntignore.ContentIgnore(self.ap)) + + if self.ap.pipeline_cfg.data['check-sensitive-words']: + self.filter_chain.append(banwords.BanWordFilter(self.ap)) + + if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']: + self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap)) + + for filter in self.filter_chain: + await filter.initialize() + + async def _pre_process( + self, + message: str, + query: core_entities.Query, + ) -> entities.StageProcessResult: + """请求llm前处理消息 + 只要有一个不通过就不放行,只放行 PASS 的消息 + """ + if not self.ap.pipeline_cfg.data['income-msg-check']: + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + for filter in self.filter_chain: + if filter_entities.EnableStage.PRE in filter.enable_stages: + result = await filter.process(message) + + if result.level in [ + filter_entities.ResultLevel.BLOCK, + filter_entities.ResultLevel.MASKED + ]: + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query, + user_notice=result.user_notice, + console_notice=result.console_notice + ) + elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个 + message = result.replacement + + query.message_chain = mirai.MessageChain( + mirai.Plain(message) + ) + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + async def _post_process( + self, + message: str, + query: core_entities.Query, + ) -> entities.StageProcessResult: + """请求llm后处理响应 + 只要是 PASS 或者 MASKED 的就通过此 filter,将其 
replacement 设置为message,进入下一个 filter + """ + if message is None: + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + message = message.strip() + for filter in self.filter_chain: + if filter_entities.EnableStage.POST in filter.enable_stages: + result = await filter.process(message) + + if result.level == filter_entities.ResultLevel.BLOCK: + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query, + user_notice=result.user_notice, + console_notice=result.console_notice + ) + elif result.level in [ + filter_entities.ResultLevel.PASS, + filter_entities.ResultLevel.MASKED + ]: + message = result.replacement + + query.resp_messages[-1].content = message + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str + ) -> entities.StageProcessResult: + """处理 + """ + if stage_inst_name == 'PreContentFilterStage': + return await self._pre_process( + str(query.message_chain).strip(), + query + ) + elif stage_inst_name == 'PostContentFilterStage': + return await self._post_process( + query.resp_messages[-1].content, + query + ) + else: + raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}') diff --git a/pkg/pipeline/cntfilter/entities.py b/pkg/pipeline/cntfilter/entities.py new file mode 100644 index 00000000..7ab05675 --- /dev/null +++ b/pkg/pipeline/cntfilter/entities.py @@ -0,0 +1,64 @@ + +import typing +import enum + +import pydantic + + +class ResultLevel(enum.Enum): + """结果等级""" + PASS = enum.auto() + """通过""" + + WARN = enum.auto() + """警告""" + + MASKED = enum.auto() + """已掩去""" + + BLOCK = enum.auto() + """阻止""" + + +class EnableStage(enum.Enum): + """启用阶段""" + PRE = enum.auto() + """预处理""" + + POST = enum.auto() + """后处理""" + + +class FilterResult(pydantic.BaseModel): + level: ResultLevel + + replacement: str + """替换后的消息""" + + 
user_notice: str + """不通过时,用户提示消息""" + + console_notice: str + """不通过时,控制台提示消息""" + + +class ManagerResultLevel(enum.Enum): + """处理器结果等级""" + CONTINUE = enum.auto() + """继续""" + + INTERRUPT = enum.auto() + """中断""" + +class FilterManagerResult(pydantic.BaseModel): + + level: ManagerResultLevel + + replacement: str + """替换后的消息""" + + user_notice: str + """用户提示消息""" + + console_notice: str + """控制台提示消息""" diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py new file mode 100644 index 00000000..57792145 --- /dev/null +++ b/pkg/pipeline/cntfilter/filter.py @@ -0,0 +1,34 @@ +# 内容过滤器的抽象类 +from __future__ import annotations +import abc + +from ...core import app +from . import entities + + +class ContentFilter(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + @property + def enable_stages(self): + """启用的阶段 + """ + return [ + entities.EnableStage.PRE, + entities.EnableStage.POST + ] + + async def initialize(self): + """初始化过滤器 + """ + pass + + @abc.abstractmethod + async def process(self, message: str) -> entities.FilterResult: + """处理消息 + """ + raise NotImplementedError diff --git a/tests/__init__.py b/pkg/pipeline/cntfilter/filters/__init__.py similarity index 100% rename from tests/__init__.py rename to pkg/pipeline/cntfilter/filters/__init__.py diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py new file mode 100644 index 00000000..f72fe960 --- /dev/null +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import aiohttp + +from .. import entities +from .. 
import filter as filter_model + + +BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}" +BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" + + +class BaiduCloudExamine(filter_model.ContentFilter): + """百度云内容审核""" + + async def _get_token(self) -> str: + async with aiohttp.ClientSession() as session: + async with session.post( + BAIDU_EXAMINE_TOKEN_URL, + params={ + "grant_type": "client_credentials", + "client_id": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'], + "client_secret": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'] + } + ) as resp: + return (await resp.json())['access_token'] + + async def process(self, message: str) -> entities.FilterResult: + + async with aiohttp.ClientSession() as session: + async with session.post( + BAIDU_EXAMINE_URL.format(await self._get_token()), + headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'}, + data=f"text={message}".encode('utf-8') + ) as resp: + result = await resp.json() + + if "error_code" in result: + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement=message, + user_notice='', + console_notice=f"百度云判定出错,错误信息:{result['error_msg']}" + ) + else: + conclusion = result["conclusion"] + + if conclusion in ("合规"): + return entities.FilterResult( + level=entities.ResultLevel.PASS, + replacement=message, + user_notice='', + console_notice=f"百度云判定结果:{conclusion}" + ) + else: + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement=message, + user_notice="消息中存在不合适的内容, 请修改", + console_notice=f"百度云判定结果:{conclusion}" + ) diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py new file mode 100644 index 00000000..587f81c3 --- /dev/null +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -0,0 +1,44 @@ +from __future__ import annotations +import re + +from .. 
import filter as filter_model +from .. import entities +from ....config import manager as cfg_mgr + + +class BanWordFilter(filter_model.ContentFilter): + """根据内容禁言""" + + sensitive: cfg_mgr.ConfigManager + + async def initialize(self): + self.sensitive = await cfg_mgr.load_json_config( + "data/config/sensitive-words.json", + "templates/sensitive-words.json" + ) + + async def process(self, message: str) -> entities.FilterResult: + found = False + + for word in self.sensitive.data['words']: + match = re.findall(word, message) + + if len(match) > 0: + found = True + + for i in range(len(match)): + if self.sensitive.data['mask_word'] == "": + message = message.replace( + match[i], self.sensitive.data['mask'] * len(match[i]) + ) + else: + message = message.replace( + match[i], self.sensitive.data['mask_word'] + ) + + return entities.FilterResult( + level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS, + replacement=message, + user_notice='消息中存在不合适的内容, 请修改' if found else '', + console_notice='' + ) \ No newline at end of file diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py new file mode 100644 index 00000000..92fe94e8 --- /dev/null +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -0,0 +1,43 @@ +from __future__ import annotations +import re + +from .. import entities +from .. 
import filter as filter_model + + +class ContentIgnore(filter_model.ContentFilter): + """根据内容忽略消息""" + + @property + def enable_stages(self): + return [ + entities.EnableStage.PRE, + ] + + async def process(self, message: str) -> entities.FilterResult: + if 'prefix' in self.ap.pipeline_cfg.data['ignore-rules']: + for rule in self.ap.pipeline_cfg.data['ignore-rules']['prefix']: + if message.startswith(rule): + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement='', + user_notice='', + console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息' + ) + + if 'regexp' in self.ap.pipeline_cfg.data['ignore-rules']: + for rule in self.ap.pipeline_cfg.data['ignore-rules']['regexp']: + if re.search(rule, message): + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement='', + user_notice='', + console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息' + ) + + return entities.FilterResult( + level=entities.ResultLevel.PASS, + replacement=message, + user_notice='', + console_notice='' + ) \ No newline at end of file diff --git a/pkg/pipeline/entities.py b/pkg/pipeline/entities.py new file mode 100644 index 00000000..e8cfc427 --- /dev/null +++ b/pkg/pipeline/entities.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import enum +import typing + +import pydantic +import mirai +import mirai.models.message as mirai_message + +from ..core import entities + + +class ResultType(enum.Enum): + + CONTINUE = enum.auto() + """继续流水线""" + + INTERRUPT = enum.auto() + """中断流水线""" + + +class StageProcessResult(pydantic.BaseModel): + + result_type: ResultType + + new_query: entities.Query + + user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = [] + """只要设置了就会发送给用户""" + + # TODO delete + # admin_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = [] + """只要设置了就会发送给管理员""" + + console_notice: typing.Optional[str] = '' + 
"""只要设置了就会输出到控制台""" + + debug_notice: typing.Optional[str] = '' + + error_notice: typing.Optional[str] = '' diff --git a/tests/plugin_examples/auto_approval/__init__.py b/pkg/pipeline/longtext/__init__.py similarity index 100% rename from tests/plugin_examples/auto_approval/__init__.py rename to pkg/pipeline/longtext/__init__.py diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py new file mode 100644 index 00000000..ab70732b --- /dev/null +++ b/pkg/pipeline/longtext/longtext.py @@ -0,0 +1,59 @@ +from __future__ import annotations +import os +import traceback + +from PIL import Image, ImageDraw, ImageFont +from mirai.models.message import MessageComponent, Plain, MessageChain + +from ...core import app +from . import strategy +from .strategies import image, forward +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class("LongTextProcessStage") +class LongTextProcessStage(stage.PipelineStage): + + strategy_impl: strategy.LongTextStrategy + + async def initialize(self): + config = self.ap.platform_cfg.data['long-text-process'] + if config['strategy'] == 'image': + use_font = config['font-path'] + try: + # 检查是否存在 + if not os.path.exists(use_font): + # 若是windows系统,使用微软雅黑 + if os.name == "nt": + use_font = "C:/Windows/Fonts/msyh.ttc" + if not os.path.exists(use_font): + self.ap.logger.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") + config['blob_message_strategy'] = "forward" + else: + self.ap.logger.info("使用Windows自带字体:" + use_font) + config['font-path'] = use_font + else: + self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") + + self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" + except: + traceback.print_exc() + self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) + + self.ap.platform_cfg.data['long-text-process']['strategy'] = 
"forward" + + if config['strategy'] == 'image': + self.strategy_impl = image.Text2ImageStrategy(self.ap) + elif config['strategy'] == 'forward': + self.strategy_impl = forward.ForwardComponentStrategy(self.ap) + await self.strategy_impl.initialize() + + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']: + query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query)) + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) diff --git a/tests/plugin_examples/cmdcn/__init__.py b/pkg/pipeline/longtext/strategies/__init__.py similarity index 100% rename from tests/plugin_examples/cmdcn/__init__.py rename to pkg/pipeline/longtext/strategies/__init__.py diff --git a/pkg/pipeline/longtext/strategies/forward.py b/pkg/pipeline/longtext/strategies/forward.py new file mode 100644 index 00000000..cfab49d9 --- /dev/null +++ b/pkg/pipeline/longtext/strategies/forward.py @@ -0,0 +1,63 @@ +# 转发消息组件 +from __future__ import annotations +import typing + +from mirai.models import MessageChain +from mirai.models.message import MessageComponent, ForwardMessageNode +from mirai.models.base import MiraiBaseModel + +from .. 
import strategy as strategy_model +from ....core import entities as core_entities + + +class ForwardMessageDiaplay(MiraiBaseModel): + title: str = "群聊的聊天记录" + brief: str = "[聊天记录]" + source: str = "聊天记录" + preview: typing.List[str] = [] + summary: str = "查看x条转发消息" + + +class Forward(MessageComponent): + """合并转发。""" + type: str = "Forward" + """消息组件类型。""" + display: ForwardMessageDiaplay + """显示信息""" + node_list: typing.List[ForwardMessageNode] + """转发消息节点列表。""" + def __init__(self, *args, **kwargs): + if len(args) == 1: + self.node_list = args[0] + super().__init__(**kwargs) + super().__init__(*args, **kwargs) + + def __str__(self): + return '[聊天记录]' + + +class ForwardComponentStrategy(strategy_model.LongTextStrategy): + + async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: + display = ForwardMessageDiaplay( + title="群聊的聊天记录", + brief="[聊天记录]", + source="聊天记录", + preview=["QQ用户: "+message], + summary="查看1条转发消息" + ) + + node_list = [ + ForwardMessageNode( + sender_id=query.adapter.bot_account_id, + sender_name='QQ用户', + message_chain=MessageChain([message]) + ) + ] + + forward = Forward( + display=display, + node_list=node_list + ) + + return [forward] diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py new file mode 100644 index 00000000..af34f4e6 --- /dev/null +++ b/pkg/pipeline/longtext/strategies/image.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import typing +import os +import base64 +import time +import re + +from PIL import Image, ImageDraw, ImageFont + +from mirai.models import MessageChain, Image as ImageComponent +from mirai.models.message import MessageComponent + +from .. 
import strategy as strategy_model +from ....core import entities as core_entities + + +class Text2ImageStrategy(strategy_model.LongTextStrategy): + + text_render_font: ImageFont.FreeTypeFont + + async def initialize(self): + self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8") + + async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: + img_path = self.text_to_image( + text_str=message, + save_as='temp/{}.png'.format(int(time.time())) + ) + + compressed_path, size = self.compress_image( + img_path, + outfile="temp/{}_compressed.png".format(int(time.time())) + ) + + with open(compressed_path, 'rb') as f: + img = f.read() + + b64 = base64.b64encode(img) + + # 删除图片 + os.remove(img_path) + + if os.path.exists(compressed_path): + os.remove(compressed_path) + + return [ + ImageComponent( + base64=b64.decode('utf-8'), + ) + ] + + def indexNumber(self, path=''): + """ + 查找字符串中数字所在串中的位置 + :param path:目标字符串 + :return:: : [['1', 16], ['2', 35], ['1', 51]] + """ + kv = [] + nums = [] + beforeDatas = re.findall('[\d]+', path) + for num in beforeDatas: + indexV = [] + times = path.count(num) + if times > 1: + if num not in nums: + indexs = re.finditer(num, path) + for index in indexs: + iV = [] + i = index.span()[0] + iV.append(num) + iV.append(i) + kv.append(iV) + nums.append(num) + else: + index = path.find(num) + indexV.append(num) + indexV.append(index) + kv.append(indexV) + # 根据数字位置排序 + indexSort = [] + resultIndex = [] + for vi in kv: + indexSort.append(vi[1]) + indexSort.sort() + for i in indexSort: + for v in kv: + if i == v[1]: + resultIndex.append(v) + return resultIndex + + + def get_size(self, file): + # 获取文件大小:KB + size = os.path.getsize(file) + return size / 1024 + + + def get_outfile(self, infile, outfile): + if outfile: + return outfile + dir, suffix = os.path.splitext(infile) + outfile = '{}-out{}'.format(dir, suffix) + return outfile + + + def 
compress_image(self, infile, outfile='', kb=100, step=20, quality=90): + """不改变图片尺寸压缩到指定大小 + :param infile: 压缩源文件 + :param outfile: 压缩文件保存地址 + :param mb: 压缩目标,KB + :param step: 每次调整的压缩比率 + :param quality: 初始压缩比率 + :return: 压缩文件地址,压缩文件大小 + """ + o_size = self.get_size(infile) + if o_size <= kb: + return infile, o_size + outfile = self.get_outfile(infile, outfile) + while o_size > kb: + im = Image.open(infile) + im.save(outfile, quality=quality) + if quality - step < 0: + break + quality -= step + o_size = self.get_size(outfile) + return outfile, self.get_size(outfile) + + + def text_to_image(self, text_str: str, save_as="temp.png", width=800): + + text_str = text_str.replace("\t", " ") + + # 分行 + lines = text_str.split('\n') + + # 计算并分割 + final_lines = [] + + text_width = width-80 + + self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width)) + for line in lines: + # 如果长了就分割 + line_width = self.text_render_font.getlength(line) + self.ap.logger.debug("line_width: {}".format(line_width)) + if line_width < text_width: + final_lines.append(line) + continue + else: + rest_text = line + while True: + # 分割最前面的一行 + point = int(len(rest_text) * (text_width / line_width)) + + # 检查断点是否在数字中间 + numbers = self.indexNumber(rest_text) + + for number in numbers: + if number[1] < point < number[1] + len(number[0]) and number[1] != 0: + point = number[1] + break + + final_lines.append(rest_text[:point]) + rest_text = rest_text[point:] + line_width = self.text_render_font.getlength(rest_text) + if line_width < text_width: + final_lines.append(rest_text) + break + else: + continue + # 准备画布 + img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)) + draw = ImageDraw.Draw(img, mode='RGBA') + + self.ap.logger.debug("正在绘制图片...") + # 绘制正文 + line_number = 0 + offset_x = 20 + offset_y = 30 + for final_line in final_lines: + draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=self.text_render_font) + # 
遍历此行,检查是否有emoji + idx_in_line = 0 + for ch in final_line: + # 检查字符占位宽 + char_code = ord(ch) + if char_code >= 127: + idx_in_line += 1 + else: + idx_in_line += 0.5 + + line_number += 1 + + self.ap.logger.debug("正在保存图片...") + img.save(save_as) + + return save_as diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py new file mode 100644 index 00000000..a1f8a94f --- /dev/null +++ b/pkg/pipeline/longtext/strategy.py @@ -0,0 +1,23 @@ +from __future__ import annotations +import abc +import typing + +import mirai +from mirai.models.message import MessageComponent + +from ...core import app +from ...core import entities as core_entities + + +class LongTextStrategy(metaclass=abc.ABCMeta): + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: + return [] diff --git a/tests/plugin_examples/hello_plugin/__init__.py b/pkg/pipeline/preproc/__init__.py similarity index 100% rename from tests/plugin_examples/hello_plugin/__init__.py rename to pkg/pipeline/preproc/__init__.py diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py new file mode 100644 index 00000000..ad0d7ff6 --- /dev/null +++ b/pkg/pipeline/preproc/preproc.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from .. 
import stage, entities, stagemgr +from ...core import entities as core_entities +from ...provider import entities as llm_entities +from ...plugin import events + + +@stage.stage_class("PreProcessor") +class PreProcessor(stage.PipelineStage): + """预处理器 + """ + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str, + ) -> entities.StageProcessResult: + """处理 + """ + session = await self.ap.sess_mgr.get_session(query) + + conversation = await self.ap.sess_mgr.get_conversation(session) + + # 从会话取出消息和情景预设到query + query.session = session + query.prompt = conversation.prompt.copy() + query.messages = conversation.messages.copy() + + query.user_message = llm_entities.Message( + role='user', + content=str(query.message_chain).strip() + ) + + query.use_model = conversation.use_model + + query.use_funcs = conversation.use_funcs + + # =========== 触发事件 PromptPreProcessing + session = query.session + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.PromptPreProcessing( + session_name=f'{session.launcher_type.value}_{session.launcher_id}', + default_prompt=query.prompt.messages, + prompt=query.messages, + query=query + ) + ) + + query.prompt.messages = event_ctx.event.default_prompt + query.messages = event_ctx.event.prompt + + # 根据模型max_tokens剪裁 + max_tokens = min(query.use_model.max_tokens, self.ap.pipeline_cfg.data['submit-messages-tokens']) + + test_messages = query.prompt.messages + query.messages + [query.user_message] + + while await query.use_model.tokenizer.count_token(test_messages, query.use_model) > max_tokens: + # 前文都pop完了,还是大于max_tokens,由于prompt和user_messages不能删减,报错 + if len(query.prompt.messages) == 0: + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query, + user_notice='输入内容过长,请减少情景预设或者输入内容长度', + console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 submit-messages-tokens 项(但不能超过所用模型最大tokens数)' + ) + + query.messages.pop(0) # pop第一个肯定是role=user的 + # 继续pop到第二个role=user前一个 + 
while len(query.messages) > 0 and query.messages[0].role != 'user': + query.messages.pop(0) + + test_messages = query.prompt.messages + query.messages + [query.user_message] + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) diff --git a/tests/plugin_examples/urlikethisijustsix/__init__.py b/pkg/pipeline/process/__init__.py similarity index 100% rename from tests/plugin_examples/urlikethisijustsix/__init__.py rename to pkg/pipeline/process/__init__.py diff --git a/pkg/pipeline/process/handler.py b/pkg/pipeline/process/handler.py new file mode 100644 index 00000000..6d19e039 --- /dev/null +++ b/pkg/pipeline/process/handler.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import abc + +from ...core import app +from ...core import entities as core_entities +from .. import entities + + +class MessageHandler(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def handle( + self, + query: core_entities.Query, + ) -> entities.StageProcessResult: + raise NotImplementedError diff --git a/tests/token_test/__init__.py b/pkg/pipeline/process/handlers/__init__.py similarity index 100% rename from tests/token_test/__init__.py rename to pkg/pipeline/process/handlers/__init__.py diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py new file mode 100644 index 00000000..26c99a2e --- /dev/null +++ b/pkg/pipeline/process/handlers/chat.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import typing +import time +import traceback + +import mirai + +from .. import handler +from ... 
import entities +from ....core import entities as core_entities +from ....provider import entities as llm_entities +from ....plugin import events + + +class ChatMessageHandler(handler.MessageHandler): + + async def handle( + self, + query: core_entities.Query, + ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: + """处理 + """ + # 取session + # 取conversation + # 调API + # 生成器 + + # 触发插件事件 + event_class = events.PersonNormalMessageReceived if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupNormalMessageReceived + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=event_class( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + text_message=str(query.message_chain), + query=query + ) + ) + + if event_ctx.is_prevented_default(): + if event_ctx.event.reply is not None: + query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) + else: + + if not self.ap.provider_cfg.data['enable-chat']: + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query, + ) + + if event_ctx.event.alter is not None: + query.message_chain = mirai.MessageChain([ + mirai.Plain(event_ctx.event.alter) + ]) + + query.messages.append( + query.user_message + ) + + text_length = 0 + + start_time = time.time() + + try: + + async for result in query.use_model.requester.request(query): + query.resp_messages.append(result) + + if result.content is not None: + text_length += len(result.content) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + except Exception as e: + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query, + user_notice='请求失败' if 
self.ap.platform_cfg.data['hide-exception-info'] else f'{e}', + error_notice=f'{e}', + debug_notice=traceback.format_exc() + ) + finally: + query.session.using_conversation.messages.append(query.user_message) + query.session.using_conversation.messages.extend(query.resp_messages) + + await self.ap.ctr_mgr.usage.post_query_record( + session_type=query.session.launcher_type.value, + session_id=str(query.session.launcher_id), + query_ability_provider="QChatGPT.Chat", + usage=text_length, + model_name=query.use_model.name, + response_seconds=int(time.time() - start_time), + retry_times=-1, + ) \ No newline at end of file diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py new file mode 100644 index 00000000..d5873e38 --- /dev/null +++ b/pkg/pipeline/process/handlers/command.py @@ -0,0 +1,117 @@ +from __future__ import annotations +import typing + +import mirai + +from .. import handler +from ... import entities +from ....core import entities as core_entities +from ....provider import entities as llm_entities +from ....plugin import events + + +class CommandHandler(handler.MessageHandler): + + async def handle( + self, + query: core_entities.Query, + ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: + """处理 + """ + + event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent + + + privilege = 1 + + if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['admin-sessions']: + privilege = 2 + + spt = str(query.message_chain).strip().split(' ') + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=event_class( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + command=spt[0], + params=spt[1:] if len(spt) > 1 else [], + text_message=str(query.message_chain), + is_admin=(privilege==2), + query=query + ) + ) + + if event_ctx.is_prevented_default(): + + if 
event_ctx.event.reply is not None: + mc = mirai.MessageChain(event_ctx.event.reply) + + query.resp_messages.append( + llm_entities.Message( + role='command', + content=str(mc), + ) + ) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) + + else: + + if event_ctx.event.alter is not None: + query.message_chain = mirai.MessageChain([ + mirai.Plain(event_ctx.event.alter) + ]) + + session = await self.ap.sess_mgr.get_session(query) + + command_text = str(query.message_chain).strip()[1:] + + async for ret in self.ap.cmd_mgr.execute( + command_text=command_text, + query=query, + session=session + ): + if ret.error is not None: + # query.resp_message_chain = mirai.MessageChain([ + # mirai.Plain(str(ret.error)) + # ]) + query.resp_messages.append( + llm_entities.Message( + role='command', + content=str(ret.error), + ) + ) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + elif ret.text is not None: + # query.resp_message_chain = mirai.MessageChain([ + # mirai.Plain(ret.text) + # ]) + query.resp_messages.append( + llm_entities.Message( + role='command', + content=ret.text, + ) + ) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py new file mode 100644 index 00000000..c24fdac2 --- /dev/null +++ b/pkg/pipeline/process/process.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from ...core import app, entities as core_entities +from . import handler +from .handlers import chat, command +from .. import entities +from .. 
import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class("MessageProcessor") +class Processor(stage.PipelineStage): + + cmd_handler: handler.MessageHandler + + chat_handler: handler.MessageHandler + + async def initialize(self): + self.cmd_handler = command.CommandHandler(self.ap) + self.chat_handler = chat.ChatMessageHandler(self.ap) + + await self.cmd_handler.initialize() + await self.chat_handler.initialize() + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str, + ) -> entities.StageProcessResult: + """处理 + """ + message_text = str(query.message_chain).strip() + + self.ap.logger.info(f"处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}") + + if message_text.startswith('!') or message_text.startswith('!'): + return self.cmd_handler.handle(query) + else: + return self.chat_handler.handle(query) diff --git a/tests/token_test/token_count.py b/pkg/pipeline/ratelimit/__init__.py similarity index 100% rename from tests/token_test/token_count.py rename to pkg/pipeline/ratelimit/__init__.py diff --git a/pkg/pipeline/ratelimit/algo.py b/pkg/pipeline/ratelimit/algo.py new file mode 100644 index 00000000..b6d9ba7b --- /dev/null +++ b/pkg/pipeline/ratelimit/algo.py @@ -0,0 +1,24 @@ +from __future__ import annotations +import abc + +from ...core import app + + +class ReteLimitAlgo(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def require_access(self, launcher_type: str, launcher_id: int) -> bool: + raise NotImplementedError + + @abc.abstractmethod + async def release_access(self, launcher_type: str, launcher_id: int): + raise NotImplementedError + \ No newline at end of file diff --git a/pkg/pipeline/ratelimit/algos/__init__.py b/pkg/pipeline/ratelimit/algos/__init__.py new file mode 100644 
index 00000000..e69de29b diff --git a/pkg/pipeline/ratelimit/algos/fixedwin.py b/pkg/pipeline/ratelimit/algos/fixedwin.py new file mode 100644 index 00000000..bb69b0dd --- /dev/null +++ b/pkg/pipeline/ratelimit/algos/fixedwin.py @@ -0,0 +1,85 @@ +# 固定窗口算法 +from __future__ import annotations + +import asyncio +import time + +from .. import algo + + +class SessionContainer: + + wait_lock: asyncio.Lock + + records: dict[int, int] + """访问记录,key为每分钟的起始时间戳,value为访问次数""" + + def __init__(self): + self.wait_lock = asyncio.Lock() + self.records = {} + + +class FixedWindowAlgo(algo.ReteLimitAlgo): + + containers_lock: asyncio.Lock + """访问记录容器锁""" + + containers: dict[str, SessionContainer] + """访问记录容器,key为launcher_type launcher_id""" + + async def initialize(self): + self.containers_lock = asyncio.Lock() + self.containers = {} + + async def require_access(self, launcher_type: str, launcher_id: int) -> bool: + # 加锁,找容器 + container: SessionContainer = None + + session_name = f'{launcher_type}_{launcher_id}' + + async with self.containers_lock: + container = self.containers.get(session_name) + + if container is None: + container = SessionContainer() + self.containers[session_name] = container + + # 等待锁 + async with container.wait_lock: + # 获取当前时间戳 + now = int(time.time()) + + # 获取当前分钟的起始时间戳 + now = now - now % 60 + + # 获取当前分钟的访问次数 + count = container.records.get(now, 0) + + limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'] + + if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']: + limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name] + + # 如果访问次数超过了限制 + if count >= limitation: + if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop': + return False + elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait': + # 等待下一分钟 + await asyncio.sleep(60 - time.time() % 60) + + now = int(time.time()) + now = now - now % 60 + + if now not in container.records: + container.records = {} + 
container.records[now] = 1 + else: + # 访问次数加一 + container.records[now] = count + 1 + + # 返回True + return True + + async def release_access(self, launcher_type: str, launcher_id: int): + pass diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py new file mode 100644 index 00000000..cc8e4ac1 --- /dev/null +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import typing + +from .. import entities, stagemgr, stage +from . import algo +from .algos import fixedwin +from ...core import entities as core_entities + + +@stage.stage_class("RequireRateLimitOccupancy") +@stage.stage_class("ReleaseRateLimitOccupancy") +class RateLimit(stage.PipelineStage): + + algo: algo.ReteLimitAlgo + + async def initialize(self): + self.algo = fixedwin.FixedWindowAlgo(self.ap) + await self.algo.initialize() + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str, + ) -> typing.Union[ + entities.StageProcessResult, + typing.AsyncGenerator[entities.StageProcessResult, None], + ]: + """处理 + """ + if stage_inst_name == "RequireRateLimitOccupancy": + if await self.algo.require_access( + query.launcher_type.value, + query.launcher_id, + ): + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query, + ) + else: + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query, + console_notice=f"根据限速规则忽略 {query.launcher_type.value}:{query.launcher_id} 消息", + user_notice=f"请求数超过限速器设定值,已丢弃本消息。" + ) + elif stage_inst_name == "ReleaseRateLimitOccupancy": + await self.algo.release_access( + query.launcher_type, + query.launcher_id, + ) + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query, + ) diff --git a/pkg/pipeline/respback/__init__.py b/pkg/pipeline/respback/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/respback/respback.py 
b/pkg/pipeline/respback/respback.py new file mode 100644 index 00000000..10b2cbac --- /dev/null +++ b/pkg/pipeline/respback/respback.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import random +import asyncio + +import mirai + +from ...core import app + +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class("SendResponseBackStage") +class SendResponseBackStage(stage.PipelineStage): + """发送响应消息 + """ + + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + """处理 + """ + random_delay = random.uniform(*self.ap.platform_cfg.data['force-delay']) + + self.ap.logger.debug( + "根据规则强制延迟回复: %s s", + random_delay + ) + + await asyncio.sleep(random_delay) + + await self.ap.im_mgr.send( + query.message_event, + query.resp_message_chain, + adapter=query.adapter + ) + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) \ No newline at end of file diff --git a/pkg/pipeline/resprule/__init__.py b/pkg/pipeline/resprule/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/resprule/entities.py b/pkg/pipeline/resprule/entities.py new file mode 100644 index 00000000..ffee3081 --- /dev/null +++ b/pkg/pipeline/resprule/entities.py @@ -0,0 +1,9 @@ +import pydantic +import mirai + + +class RuleJudgeResult(pydantic.BaseModel): + + matching: bool = False + + replacement: mirai.MessageChain = None diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py new file mode 100644 index 00000000..8f418729 --- /dev/null +++ b/pkg/pipeline/resprule/resprule.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import mirai + +from ...core import app +from . import entities as rule_entities, rule +from .rules import atbot, prefix, regexp, random + +from .. 
import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class("GroupRespondRuleCheckStage") +class GroupRespondRuleCheckStage(stage.PipelineStage): + """群组响应规则检查器 + """ + + rule_matchers: list[rule.GroupRespondRule] + + async def initialize(self): + """初始化检查器 + """ + self.rule_matchers = [ + atbot.AtBotRule(self.ap), + prefix.PrefixRule(self.ap), + regexp.RegExpRule(self.ap), + random.RandomRespRule(self.ap), + ] + + for rule_matcher in self.rule_matchers: + await rule_matcher.initialize() + + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + + if query.launcher_type.value != 'group': + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + rules = self.ap.pipeline_cfg.data['respond-rules'] + + use_rule = rules['default'] + + if str(query.launcher_id) in use_rule: + use_rule = use_rule[str(query.launcher_id)] + + for rule_matcher in self.rule_matchers: # 任意一个匹配就放行 + res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query) + if res.matching: + query.message_chain = res.replacement + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query, + ) + + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) diff --git a/pkg/pipeline/resprule/rule.py b/pkg/pipeline/resprule/rule.py new file mode 100644 index 00000000..cde9ec3d --- /dev/null +++ b/pkg/pipeline/resprule/rule.py @@ -0,0 +1,32 @@ +from __future__ import annotations +import abc + +import mirai + +from ...core import app, entities as core_entities +from . 
import entities + + +class GroupRespondRule(metaclass=abc.ABCMeta): + """群组响应规则的抽象类 + """ + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict, + query: core_entities.Query + ) -> entities.RuleJudgeResult: + """判断消息是否匹配规则 + """ + raise NotImplementedError diff --git a/pkg/pipeline/resprule/rules/__init__.py b/pkg/pipeline/resprule/rules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/resprule/rules/atbot.py b/pkg/pipeline/resprule/rules/atbot.py new file mode 100644 index 00000000..692bee72 --- /dev/null +++ b/pkg/pipeline/resprule/rules/atbot.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import mirai + +from .. import rule as rule_model +from .. import entities +from ....core import entities as core_entities + + +class AtBotRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict, + query: core_entities.Query + ) -> entities.RuleJudgeResult: + + if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']: + message_chain.remove(mirai.At(query.adapter.bot_account_id)) + return entities.RuleJudgeResult( + matching=True, + replacement=message_chain, + ) + + return entities.RuleJudgeResult( + matching=False, + replacement = message_chain + ) diff --git a/pkg/pipeline/resprule/rules/prefix.py b/pkg/pipeline/resprule/rules/prefix.py new file mode 100644 index 00000000..1b61c138 --- /dev/null +++ b/pkg/pipeline/resprule/rules/prefix.py @@ -0,0 +1,32 @@ +import mirai + +from .. import rule as rule_model +from .. 
import entities +from ....core import entities as core_entities + + +class PrefixRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict, + query: core_entities.Query + ) -> entities.RuleJudgeResult: + prefixes = rule_dict['prefix'] + + for prefix in prefixes: + if message_text.startswith(prefix): + + return entities.RuleJudgeResult( + matching=True, + replacement=mirai.MessageChain([ + mirai.Plain(message_text[len(prefix):]) + ]), + ) + + return entities.RuleJudgeResult( + matching=False, + replacement=message_chain + ) diff --git a/pkg/pipeline/resprule/rules/random.py b/pkg/pipeline/resprule/rules/random.py new file mode 100644 index 00000000..185e03ec --- /dev/null +++ b/pkg/pipeline/resprule/rules/random.py @@ -0,0 +1,24 @@ +import random + +import mirai + +from .. import rule as rule_model +from .. import entities +from ....core import entities as core_entities + + +class RandomRespRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict, + query: core_entities.Query + ) -> entities.RuleJudgeResult: + random_rate = rule_dict['random'] + + return entities.RuleJudgeResult( + matching=random.random() < random_rate, + replacement=message_chain + ) \ No newline at end of file diff --git a/pkg/pipeline/resprule/rules/regexp.py b/pkg/pipeline/resprule/rules/regexp.py new file mode 100644 index 00000000..4e39d432 --- /dev/null +++ b/pkg/pipeline/resprule/rules/regexp.py @@ -0,0 +1,33 @@ +import re + +import mirai + +from .. import rule as rule_model +from .. 
import entities +from ....core import entities as core_entities + + +class RegExpRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict, + query: core_entities.Query + ) -> entities.RuleJudgeResult: + regexps = rule_dict['regexp'] + + for regexp in regexps: + match = re.match(regexp, message_text) + + if match: + return entities.RuleJudgeResult( + matching=True, + replacement=message_chain, + ) + + return entities.RuleJudgeResult( + matching=False, + replacement=message_chain + ) diff --git a/pkg/pipeline/stage.py b/pkg/pipeline/stage.py new file mode 100644 index 00000000..56c092b5 --- /dev/null +++ b/pkg/pipeline/stage.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import abc +import typing + +from ..core import app, entities as core_entities +from . import entities + + +_stage_classes: dict[str, PipelineStage] = {} + + +def stage_class(name: str): + + def decorator(cls): + _stage_classes[name] = cls + return cls + + return decorator + + +class PipelineStage(metaclass=abc.ABCMeta): + """流水线阶段 + """ + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + """初始化 + """ + pass + + @abc.abstractmethod + async def process( + self, + query: core_entities.Query, + stage_inst_name: str, + ) -> typing.Union[ + entities.StageProcessResult, + typing.AsyncGenerator[entities.StageProcessResult, None], + ]: + """处理 + """ + raise NotImplementedError diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py new file mode 100644 index 00000000..c855d816 --- /dev/null +++ b/pkg/pipeline/stagemgr.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import pydantic + +from ..core import app +from . 
import stage +from .resprule import resprule +from .bansess import bansess +from .cntfilter import cntfilter +from .process import process +from .longtext import longtext +from .respback import respback +from .wrapper import wrapper +from .preproc import preproc +from .ratelimit import ratelimit + + +stage_order = [ + "GroupRespondRuleCheckStage", + "BanSessionCheckStage", + "PreContentFilterStage", + "PreProcessor", + "RequireRateLimitOccupancy", + "MessageProcessor", + "ReleaseRateLimitOccupancy", + "PostContentFilterStage", + "ResponseWrapper", + "LongTextProcessStage", + "SendResponseBackStage", +] + + +class StageInstContainer(): + """阶段实例容器 + """ + + inst_name: str + + inst: stage.PipelineStage + + def __init__(self, inst_name: str, inst: stage.PipelineStage): + self.inst_name = inst_name + self.inst = inst + + +class StageManager: + ap: app.Application + + stage_containers: list[StageInstContainer] + + def __init__(self, ap: app.Application): + self.ap = ap + + self.stage_containers = [] + + async def initialize(self): + """初始化 + """ + + for name, cls in stage._stage_classes.items(): + self.stage_containers.append(StageInstContainer( + inst_name=name, + inst=cls(self.ap) + )) + + for stage_containers in self.stage_containers: + await stage_containers.inst.initialize() + + # 按照 stage_order 排序 + self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name)) diff --git a/pkg/pipeline/wrapper/__init__.py b/pkg/pipeline/wrapper/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py new file mode 100644 index 00000000..0278b603 --- /dev/null +++ b/pkg/pipeline/wrapper/wrapper.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import typing + +import mirai + +from ...core import app, entities as core_entities +from .. import entities +from .. 
import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr +from ...plugin import events + + +@stage.stage_class("ResponseWrapper") +class ResponseWrapper(stage.PipelineStage): + + async def initialize(self): + pass + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str, + ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: + """处理 + """ + + if query.resp_messages[-1].role == 'command': + query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + + if query.resp_messages[-1].role == 'assistant': + result = query.resp_messages[-1] + session = await self.ap.sess_mgr.get_session(query) + + reply_text = '' + + if result.content is not None: # 有内容 + reply_text = result.content + + # ============= 触发插件事件 =============== + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.NormalMessageResponded( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + session=session, + prefix='', + response_text=reply_text, + finish_reason='stop', + funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [], + query=query + ) + ) + if event_ctx.is_prevented_default(): + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) + else: + if event_ctx.event.reply is not None: + + query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) + + else: + + query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + if result.tool_calls is not None: # 有函数调用 + + function_names = [tc.function.name for tc in result.tool_calls] + + reply_text = f'调用函数 
{".".join(function_names)}...' + + query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) + + if self.ap.platform_cfg.data['track-function-calls']: + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.NormalMessageResponded( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + session=session, + prefix='', + response_text=reply_text, + finish_reason='stop', + funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [], + query=query + ) + ) + + if event_ctx.is_prevented_default(): + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) + else: + if event_ctx.event.reply is not None: + + query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) + + else: + + query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) \ No newline at end of file diff --git a/pkg/platform/__init__.py b/pkg/platform/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/adapter.py b/pkg/platform/adapter.py similarity index 78% rename from pkg/qqbot/adapter.py rename to pkg/platform/adapter.py index 784d8ae3..38c31fe2 100644 --- a/pkg/qqbot/adapter.py +++ b/pkg/platform/adapter.py @@ -1,15 +1,40 @@ +from __future__ import annotations + # MessageSource的适配器 import typing +import abc import mirai +from ..core import app + + +preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = [] + +def adapter_class( + name: str +): + def decorator(cls: typing.Type[MessageSourceAdapter]) -> typing.Type[MessageSourceAdapter]: + cls.name = name + preregistered_adapters.append(cls) + return cls + return decorator + + +class MessageSourceAdapter(metaclass=abc.ABCMeta): + name: str -class MessageSourceAdapter: bot_account_id: int - def __init__(self, config: dict): - pass 
+ + config: dict + + ap: app.Application + + def __init__(self, config: dict, ap: app.Application): + self.config = config + self.ap = ap - def send_message( + async def send_message( self, target_type: str, target_id: str, @@ -24,7 +49,7 @@ def send_message( """ raise NotImplementedError - def reply_message( + async def reply_message( self, message_source: mirai.MessageEvent, message: mirai.MessageChain, @@ -39,14 +64,14 @@ def reply_message( """ raise NotImplementedError - def is_muted(self, group_id: int) -> bool: + async def is_muted(self, group_id: int) -> bool: """获取账号是否在指定群被禁言""" raise NotImplementedError def register_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None] ): """注册事件监听器 @@ -59,7 +84,7 @@ def register_listener( def unregister_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None] ): """注销事件监听器 @@ -69,11 +94,11 @@ def unregister_listener( """ raise NotImplementedError - def run_sync(self): - """以阻塞的方式运行适配器""" + async def run_async(self): + """异步运行""" raise NotImplementedError - def kill(self) -> bool: + async def kill(self) -> bool: """关闭适配器 Returns: diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py new file mode 100644 index 00000000..3d73c198 --- /dev/null +++ b/pkg/platform/manager.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +import json +import os +import logging +import asyncio +import traceback + +from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \ + FriendMessage, Image, MessageChain, Plain +import mirai +from ..platform import adapter as msadapter + +from ..core import app, entities as core_entities +from ..plugin import events + +# 控制QQ消息输入输出的类 +class PlatformManager: + + # adapter: msadapter.MessageSourceAdapter = None + adapters: 
list[msadapter.MessageSourceAdapter] = [] + + # modern + ap: app.Application = None + + def __init__(self, ap: app.Application = None): + + self.ap = ap + self.adapters = [] + + async def initialize(self): + + from .sources import yirimirai, nakuru, aiocqhttp, qqbotpy + + async def on_friend_message(event: FriendMessage, adapter: msadapter.MessageSourceAdapter): + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.PersonMessageReceived( + launcher_type='person', + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_chain=event.message_chain, + query=None + ) + ) + + if not event_ctx.is_prevented_default(): + + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain, + adapter=adapter + ) + + async def on_stranger_message(event: StrangerMessage, adapter: msadapter.MessageSourceAdapter): + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.PersonMessageReceived( + launcher_type='person', + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_chain=event.message_chain, + query=None + ) + ) + + if not event_ctx.is_prevented_default(): + + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain, + adapter=adapter + ) + + async def on_group_message(event: GroupMessage, adapter: msadapter.MessageSourceAdapter): + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.GroupMessageReceived( + launcher_type='person', + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_chain=event.message_chain, + query=None + ) + ) + + if not event_ctx.is_prevented_default(): + + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.GROUP, + launcher_id=event.group.id, + 
sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain, + adapter=adapter + ) + + index = 0 + + for adap_cfg in self.ap.platform_cfg.data['platform-adapters']: + if adap_cfg['enable']: + self.ap.logger.info(f'初始化平台适配器 {index}: {adap_cfg["adapter"]}') + index += 1 + cfg_copy = adap_cfg.copy() + del cfg_copy['enable'] + adapter_name = cfg_copy['adapter'] + del cfg_copy['adapter'] + + found = False + + for adapter in msadapter.preregistered_adapters: + if adapter.name == adapter_name: + found = True + adapter_cls = adapter + + adapter_inst = adapter_cls( + cfg_copy, + self.ap + ) + self.adapters.append(adapter_inst) + + if adapter_name == 'yiri-mirai': + adapter_inst.register_listener( + StrangerMessage, + on_stranger_message + ) + + adapter_inst.register_listener( + FriendMessage, + on_friend_message + ) + adapter_inst.register_listener( + GroupMessage, + on_group_message + ) + + if not found: + raise Exception('platform.json 中启用了未知的平台适配器: ' + adapter_name) + + if len(self.adapters) == 0: + self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。') + + async def send(self, event, msg, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True): + + if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage): + + msg.insert( + 0, + At( + event.sender.id + ) + ) + + await adapter.reply_message( + event, + msg, + quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False + ) + + # 通知系统管理员 + # TODO delete + # async def notify_admin(self, message: str): + # await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))])) + + # async def notify_admin_message_chain(self, message: mirai.MessageChain): + # if self.ap.system_cfg.data['admin-sessions'] != []: + + # admin_list = [] + # for admin in self.ap.system_cfg.data['admin-sessions']: + # admin_list.append(admin) + + # for adm in admin_list: + # self.adapter.send_message( + # 
adm.split("_")[0], + # adm.split("_")[1], + # message + # ) + + async def run(self): + try: + tasks = [] + for adapter in self.adapters: + async def exception_wrapper(adapter): + try: + await adapter.run_async() + except Exception as e: + self.ap.logger.error('平台适配器运行出错: ' + str(e)) + self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + + tasks.append(exception_wrapper(adapter)) + + for task in tasks: + asyncio.create_task(task) + + except Exception as e: + self.ap.logger.error('平台适配器运行出错: ' + str(e)) + self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + diff --git a/pkg/platform/sources/__init__.py b/pkg/platform/sources/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py new file mode 100644 index 00000000..fc6b4fbe --- /dev/null +++ b/pkg/platform/sources/aiocqhttp.py @@ -0,0 +1,275 @@ +from __future__ import annotations +import typing +import asyncio +import traceback +import time +import datetime + +import mirai +import mirai.models.message as yiri_message +import aiocqhttp + +from .. 
import adapter +from ...pipeline.longtext.strategies import forward +from ...core import app + + +class AiocqhttpMessageConverter(adapter.MessageConverter): + + @staticmethod + def yiri2target(message_chain: mirai.MessageChain) -> typing.Tuple[list, int, datetime.datetime]: + msg_list = aiocqhttp.Message() + + msg_id = 0 + msg_time = None + + for msg in message_chain: + if type(msg) is mirai.Plain: + msg_list.append(aiocqhttp.MessageSegment.text(msg.text)) + elif type(msg) is yiri_message.Source: + msg_id = msg.id + msg_time = msg.time + elif type(msg) is mirai.Image: + msg_list.append(aiocqhttp.MessageSegment.image(msg.path)) + elif type(msg) is mirai.At: + msg_list.append(aiocqhttp.MessageSegment.at(msg.target)) + elif type(msg) is mirai.AtAll: + msg_list.append(aiocqhttp.MessageSegment.at("all")) + elif type(msg) is mirai.Face: + msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id)) + elif type(msg) is mirai.Voice: + msg_list.append(aiocqhttp.MessageSegment.record(msg.path)) + elif type(msg) is forward.Forward: + # print("aiocqhttp 暂不支持转发消息组件的转换,使用普通消息链发送") + + for node in msg.node_list: + msg_list.extend(AiocqhttpMessageConverter.yiri2target(node.message_chain)[0]) + + else: + msg_list.append(aiocqhttp.MessageSegment.text(str(msg))) + + return msg_list, msg_id, msg_time + + @staticmethod + def target2yiri(message: str, message_id: int = -1): + message = aiocqhttp.Message(message) + + yiri_msg_list = [] + + yiri_msg_list.append( + yiri_message.Source(id=message_id, time=datetime.datetime.now()) + ) + + for msg in message: + if msg.type == "at": + if msg.data["qq"] == "all": + yiri_msg_list.append(yiri_message.AtAll()) + else: + yiri_msg_list.append( + yiri_message.At( + target=msg.data["qq"], + ) + ) + elif msg.type == "text": + yiri_msg_list.append(yiri_message.Plain(text=msg.data["text"])) + elif msg.type == "image": + yiri_msg_list.append(yiri_message.Image(url=msg.data["url"])) + + chain = mirai.MessageChain(yiri_msg_list) + + return chain + + +class 
AiocqhttpEventConverter(adapter.EventConverter): + + @staticmethod + def yiri2target(event: mirai.Event, bot_account_id: int): + + msg, msg_id, msg_time = AiocqhttpMessageConverter.yiri2target(event.message_chain) + + if type(event) is mirai.GroupMessage: + role = "member" + + if event.sender.permission == "ADMINISTRATOR": + role = "admin" + elif event.sender.permission == "OWNER": + role = "owner" + + payload = { + "post_type": "message", + "message_type": "group", + "time": int(msg_time.timestamp()), + "self_id": bot_account_id, + "sub_type": "normal", + "anonymous": None, + "font": 0, + "message": str(msg), + "raw_message": str(msg), + "sender": { + "age": 0, + "area": "", + "card": "", + "level": "", + "nickname": event.sender.member_name, + "role": role, + "sex": "unknown", + "title": "", + "user_id": event.sender.id, + }, + "user_id": event.sender.id, + "message_id": msg_id, + "group_id": event.group.id, + "message_seq": 0, + } + + return aiocqhttp.Event.from_payload(payload) + elif type(event) is mirai.FriendMessage: + + payload = { + "post_type": "message", + "message_type": "private", + "time": int(msg_time.timestamp()), + "self_id": bot_account_id, + "sub_type": "friend", + "target_id": bot_account_id, + "message": str(msg), + "raw_message": str(msg), + "font": 0, + "sender": { + "age": 0, + "nickname": event.sender.nickname, + "sex": "unknown", + "user_id": event.sender.id, + }, + "message_id": msg_id, + "user_id": event.sender.id, + } + + return aiocqhttp.Event.from_payload(payload) + + @staticmethod + def target2yiri(event: aiocqhttp.Event): + yiri_chain = AiocqhttpMessageConverter.target2yiri( + event.message, event.message_id + ) + + if event.message_type == "group": + permission = "MEMBER" + + if event.sender["role"] == "admin": + permission = "ADMINISTRATOR" + elif event.sender["role"] == "owner": + permission = "OWNER" + converted_event = mirai.GroupMessage( + sender=mirai.models.entities.GroupMember( + id=event.sender["user_id"], # message_seq 
放哪? + member_name=event.sender["nickname"], + permission=permission, + group=mirai.models.entities.Group( + id=event.group_id, + name=event.sender["nickname"], + permission=mirai.models.entities.Permission.Member, + ), + special_title=event.sender["title"], + join_timestamp=0, + last_speak_timestamp=0, + mute_time_remaining=0, + ), + message_chain=yiri_chain, + time=event.time, + ) + return converted_event + elif event.message_type == "private": + return mirai.FriendMessage( + sender=mirai.models.entities.Friend( + id=event.sender["user_id"], + nickname=event.sender["nickname"], + remark="", + ), + message_chain=yiri_chain, + time=event.time, + ) + + +@adapter.adapter_class("aiocqhttp") +class AiocqhttpAdapter(adapter.MessageSourceAdapter): + + bot: aiocqhttp.CQHttp + + bot_account_id: int + + message_converter: AiocqhttpMessageConverter = AiocqhttpMessageConverter() + event_converter: AiocqhttpEventConverter = AiocqhttpEventConverter() + + config: dict + + ap: app.Application + + def __init__(self, config: dict, ap: app.Application): + self.config = config + + async def shutdown_trigger_placeholder(): + while True: + await asyncio.sleep(1) + + self.config['shutdown_trigger'] = shutdown_trigger_placeholder + + self.ap = ap + + self.bot = aiocqhttp.CQHttp() + + async def send_message( + self, target_type: str, target_id: str, message: mirai.MessageChain + ): + # TODO 实现发送消息 + return super().send_message(target_type, target_id, message) + + async def reply_message( + self, + message_source: mirai.MessageEvent, + message: mirai.MessageChain, + quote_origin: bool = False, + ): + + aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id) + aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0] + if quote_origin: + aiocq_msg = aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg + + return await self.bot.send( + aiocq_event, + aiocq_msg + ) + + async def is_muted(self, group_id: int) -> bool: + return False + + def 
register_listener( + self, + event_type: typing.Type[mirai.Event], + callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None], + ): + async def on_message(event: aiocqhttp.Event): + self.bot_account_id = event.self_id + try: + return await callback(self.event_converter.target2yiri(event), self) + except: + traceback.print_exc() + + if event_type == mirai.GroupMessage: + self.bot.on_message("group")(on_message) + elif event_type == mirai.FriendMessage: + self.bot.on_message("private")(on_message) + + def unregister_listener( + self, + event_type: typing.Type[mirai.Event], + callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None], + ): + return super().unregister_listener(event_type, callback) + + async def run_async(self): + await self.bot._server_app.run_task(**self.config) + + async def kill(self) -> bool: + return False diff --git a/pkg/qqbot/sources/nakuru.py b/pkg/platform/sources/nakuru.py similarity index 83% rename from pkg/qqbot/sources/nakuru.py rename to pkg/platform/sources/nakuru.py index 7196fe6f..0a419a06 100644 --- a/pkg/qqbot/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -1,3 +1,6 @@ +# 加了之后会导致:https://github.com/Lxns-Network/nakuru-project/issues/25 +# from __future__ import annotations + import asyncio import typing import traceback @@ -9,8 +12,7 @@ import nakuru.entities.components as nkc from .. 
import adapter as adapter_model -from ...qqbot import blob -from ...utils import context +from ...pipeline.longtext.strategies import forward class NakuruProjectMessageConverter(adapter_model.MessageConverter): @@ -49,7 +51,7 @@ def yiri2target(message_chain: mirai.MessageChain) -> list: nakuru_msg_list.append(nkc.Record.fromURL(component.url)) elif component.path is not None: nakuru_msg_list.append(nkc.Record.fromFileSystem(component.path)) - elif type(component) is blob.Forward: + elif type(component) is forward.Forward: # 转发消息 yiri_forward_node_list = component.node_list nakuru_forward_node_list = [] @@ -97,7 +99,7 @@ def target2yiri(message_chain: typing.Any, message_id: int = -1) -> mirai.Messag yiri_msg_list.append(mirai.AtAll()) else: pass - logging.debug("转换后的消息链: " + str(yiri_msg_list)) + # logging.debug("转换后的消息链: " + str(yiri_msg_list)) chain = mirai.MessageChain(yiri_msg_list) return chain @@ -157,6 +159,7 @@ def target2yiri(event: typing.Any) -> mirai.Event: raise Exception("未支持转换的事件类型: " + str(event)) +@adapter_model.adapter_class("nakuru") class NakuruProjectAdapter(adapter_model.MessageSourceAdapter): """nakuru-project适配器""" bot: nakuru.CQHTTP @@ -167,34 +170,20 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter): listener_list: list[dict] - def __init__(self, cfg: dict): + # ap: app.Application + + cfg: dict + + def __init__(self, cfg: dict, ap): """初始化nakuru-project的对象""" - self.bot = nakuru.CQHTTP(**cfg) + cfg['port'] = cfg['ws_port'] + del cfg['ws_port'] + self.cfg = cfg + self.ap = ap self.listener_list = [] - # nakuru库有bug,这个接口没法带access_token,会失败 - # 所以目前自行发请求 - - config = context.get_config_manager().data - - import requests - resp = requests.get( - url="http://{}:{}/get_login_info".format(config['nakuru_config']['host'], config['nakuru_config']['http_port']), - headers={ - 'Authorization': "Bearer " + config['nakuru_config']['token'] if 'token' in config['nakuru_config']else "" - }, - timeout=5, - proxies=None - ) - if 
resp.status_code == 403: - logging.error("go-cqhttp拒绝访问,请检查config.py中nakuru_config的token是否与go-cqhttp设置的access-token匹配") - raise Exception("go-cqhttp拒绝访问,请检查config.py中nakuru_config的token是否与go-cqhttp设置的access-token匹配") - try: - self.bot_account_id = int(resp.json()['data']['user_id']) - except Exception as e: - logging.error("获取go-cqhttp账号信息失败: {}, 请检查是否已启动go-cqhttp并配置正确".format(e)) - raise Exception("获取go-cqhttp账号信息失败: {}, 请检查是否已启动go-cqhttp并配置正确".format(e)) + self.bot = nakuru.CQHTTP(**self.cfg) - def send_message( + async def send_message( self, target_type: str, target_id: str, @@ -227,9 +216,9 @@ def send_message( else: raise Exception("Unknown target type: " + target_type) - asyncio.run(task) + await task - def reply_message( + async def reply_message( self, message_source: mirai.MessageEvent, message: mirai.MessageChain, @@ -243,14 +232,14 @@ def reply_message( ) ) if type(message_source) is mirai.GroupMessage: - self.send_message( + await self.send_message( "group", message_source.sender.group.id, message, converted=True ) elif type(message_source) is mirai.FriendMessage: - self.send_message( + await self.send_message( "person", message_source.sender.id, message, @@ -268,14 +257,15 @@ def is_muted(self, group_id: int) -> bool: def register_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] ): try: - logging.debug("注册监听器: " + str(event_type) + " -> " + str(callback)) + + source_cls = NakuruProjectEventConverter.yiri2target(event_type) # 包装函数 - async def listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.event_converter.yiri2target(event_type)): - callback(self.event_converter.target2yiri(source)) + async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): + await callback(self.event_converter.target2yiri(source), self) # 将包装函数和原函数的对应关系存入列表 self.listener_list.append( @@ -287,8 +277,7 @@ async def 
listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.even ) # 注册监听器 - self.bot.receiver(self.event_converter.yiri2target(event_type).__name__)(listener_wrapper) - logging.debug("注册完成") + self.bot.receiver(source_cls.__name__)(listener_wrapper) except Exception as e: traceback.print_exc() raise e @@ -296,7 +285,7 @@ async def listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.even def unregister_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] ): nakuru_event_name = self.event_converter.yiri2target(event_type).__name__ @@ -319,10 +308,26 @@ def unregister_listener( self.bot.event[nakuru_event_name] = new_event_list - def run_sync(self): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - self.bot.run() + async def run_async(self): + try: + import requests + resp = requests.get( + url="http://{}:{}/get_login_info".format(self.cfg['host'], self.cfg['http_port']), + headers={ + 'Authorization': "Bearer " + self.cfg['token'] if 'token' in self.cfg else "" + }, + timeout=5, + proxies=None + ) + if resp.status_code == 403: + raise Exception("go-cqhttp拒绝访问,请检查config.py中nakuru_config的token是否与go-cqhttp设置的access-token匹配") + self.bot_account_id = int(resp.json()['data']['user_id']) + except Exception as e: + raise Exception("获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确") + await self.bot._run() + self.ap.logger.info("运行 Nakuru 适配器") + while True: + await asyncio.sleep(1) def kill(self) -> bool: - return False + return False \ No newline at end of file diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/qqbotpy.py new file mode 100644 index 00000000..6d74d0ea --- /dev/null +++ b/pkg/platform/sources/qqbotpy.py @@ -0,0 +1,394 @@ +from __future__ import annotations + +import logging +import typing +import datetime +import asyncio +import re +import traceback +import json +import threading 
+ +import mirai +import botpy +import botpy.message as botpy_message +import botpy.types.message as botpy_message_type + +from .. import adapter as adapter_model +from ...pipeline.longtext.strategies import forward +from ...core import app + + +class OfficialGroupMessage(mirai.GroupMessage): + pass + + +event_handler_mapping = { + mirai.GroupMessage: ["on_at_message_create", "on_group_at_message_create"], + mirai.FriendMessage: ["on_direct_message_create"], +} + + +cached_message_ids = {} +"""由于QQ官方的消息id是字符串,而YiriMirai的消息id是整数,所以需要一个索引来进行转换""" + +id_index = 0 + +def save_msg_id(message_id: str) -> int: + """保存消息id""" + global id_index, cached_message_ids + + crt_index = id_index + id_index += 1 + cached_message_ids[str(crt_index)] = message_id + return crt_index + +cached_member_openids = {} +"""QQ官方 用户的id是字符串,而YiriMirai的用户id是整数,所以需要一个索引来进行转换""" + +member_openid_index = 100 + +def save_member_openid(member_openid: str) -> int: + """保存用户id""" + global member_openid_index, cached_member_openids + + if member_openid in cached_member_openids.values(): + return list(cached_member_openids.keys())[list(cached_member_openids.values()).index(member_openid)] + + crt_index = member_openid_index + member_openid_index += 1 + cached_member_openids[str(crt_index)] = member_openid + return crt_index + +cached_group_openids = {} +"""QQ官方 群组的id是字符串,而YiriMirai的群组id是整数,所以需要一个索引来进行转换""" + +group_openid_index = 1000 + +def save_group_openid(group_openid: str) -> int: + """保存群组id""" + global group_openid_index, cached_group_openids + + if group_openid in cached_group_openids.values(): + return list(cached_group_openids.keys())[list(cached_group_openids.values()).index(group_openid)] + + crt_index = group_openid_index + group_openid_index += 1 + cached_group_openids[str(crt_index)] = group_openid + return crt_index + + +class OfficialMessageConverter(adapter_model.MessageConverter): + """QQ 官方消息转换器""" + @staticmethod + def yiri2target(message_chain: mirai.MessageChain): + """将 YiriMirai 
的消息链转换为 QQ 官方消息""" + + msg_list = [] + if type(message_chain) is mirai.MessageChain: + msg_list = message_chain.__root__ + elif type(message_chain) is list: + msg_list = message_chain + else: + raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain))) + + offcial_messages: list[dict] = [] + """ + { + "type": "text", + "content": "Hello World!" + } + + { + "type": "image", + "content": "https://example.com/example.jpg" + } + """ + + # 遍历并转换 + for component in msg_list: + if type(component) is mirai.Plain: + offcial_messages.append({ + "type": "text", + "content": component.text + }) + elif type(component) is mirai.Image: + if component.url is not None: + offcial_messages.append( + { + "type": "image", + "content": component.url + } + ) + elif component.path is not None: + offcial_messages.append( + { + "type": "file_image", + "content": component.path + } + ) + elif type(component) is mirai.At: + offcial_messages.append( + { + "type": "at", + "content": "" + } + ) + elif type(component) is mirai.AtAll: + print("上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。") + elif type(component) is mirai.Voice: + print("上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。") + elif type(component) is forward.Forward: + # 转发消息 + yiri_forward_node_list = component.node_list + + # 遍历并转换 + for yiri_forward_node in yiri_forward_node_list: + try: + message_chain = yiri_forward_node.message_chain + + # 平铺 + offcial_messages.extend(OfficialMessageConverter.yiri2target(message_chain)) + except Exception as e: + import traceback + traceback.print_exc() + + return offcial_messages + + @staticmethod + def extract_message_chain_from_obj(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage], message_id: str = None, bot_account_id: int = 0) -> mirai.MessageChain: + yiri_msg_list = [] + + # 存id + + yiri_msg_list.append(mirai.models.message.Source(id=save_msg_id(message_id), time=datetime.datetime.now())) + + if type(message) is not 
botpy_message.DirectMessage: + yiri_msg_list.append(mirai.At(target=bot_account_id)) + + if hasattr(message, "mentions"): + for mention in message.mentions: + if mention.bot: + continue + + yiri_msg_list.append(mirai.At(target=mention.id)) + + for attachment in message.attachments: + if attachment.content_type == "image": + yiri_msg_list.append(mirai.Image(url=attachment.url)) + else: + logging.warning("不支持的附件类型:" + attachment.content_type + ",忽略此附件。") + + content = re.sub(r"<@!\d+>", "", str(message.content)) + if content.strip() != "": + yiri_msg_list.append(mirai.Plain(text=content)) + + chain = mirai.MessageChain(yiri_msg_list) + + return chain + + +class OfficialEventConverter(adapter_model.EventConverter): + """事件转换器""" + @staticmethod + def yiri2target(event: typing.Type[mirai.Event]): + if event == mirai.GroupMessage: + return botpy_message.Message + elif event == mirai.FriendMessage: + return botpy_message.DirectMessage + else: + raise Exception("未支持转换的事件类型(YiriMirai -> Official): " + str(event)) + + @staticmethod + def target2yiri(event: typing.Union[botpy_message.Message, botpy_message.DirectMessage]) -> mirai.Event: + import mirai.models.entities as mirai_entities + + if type(event) == botpy_message.Message: # 频道内,转群聊事件 + permission = "MEMBER" + + if '2' in event.member.roles: + permission = "ADMINISTRATOR" + elif '4' in event.member.roles: + permission = "OWNER" + + return mirai.GroupMessage( + sender=mirai_entities.GroupMember( + id=event.author.id, + member_name=event.author.username, + permission=permission, + group=mirai_entities.Group( + id=event.channel_id, + name=event.author.username, + permission=mirai_entities.Permission.Member + ), + special_title='', + join_timestamp=int(datetime.datetime.strptime(event.member.joined_at, "%Y-%m-%dT%H:%M:%S%z").timestamp()), + last_speak_timestamp=datetime.datetime.now().timestamp(), + mute_time_remaining=0, + ), + message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id), + 
time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()), + ) + elif type(event) == botpy_message.DirectMessage: # 私聊,转私聊事件 + return mirai.FriendMessage( + sender=mirai_entities.Friend( + id=event.guild_id, + nickname=event.author.username, + remark=event.author.username, + ), + message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id), + time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()), + ) + elif type(event) == botpy_message.GroupMessage: + + replacing_member_id = save_member_openid(event.author.member_openid) + + return OfficialGroupMessage( + sender=mirai_entities.GroupMember( + id=replacing_member_id, + member_name=replacing_member_id, + permission="MEMBER", + group=mirai_entities.Group( + id=save_group_openid(event.group_openid), + name=replacing_member_id, + permission=mirai_entities.Permission.Member + ), + special_title='', + join_timestamp=int(0), + last_speak_timestamp=datetime.datetime.now().timestamp(), + mute_time_remaining=0, + ), + message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id), + time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()), + ) + + +@adapter_model.adapter_class("qq-botpy") +class OfficialAdapter(adapter_model.MessageSourceAdapter): + """QQ 官方消息适配器""" + bot: botpy.Client = None + + bot_account_id: int = 0 + + message_converter: OfficialMessageConverter = OfficialMessageConverter() + # event_handler: adapter_model.EventHandler = adapter_model.EventHandler() + + cfg: dict = None + + cached_official_messages: dict = {} + """缓存的 qq-botpy 框架消息对象 + + message_id: botpy_message.Message | botpy_message.DirectMessage + """ + + ap: app.Application + + def __init__(self, cfg: dict, ap: app.Application): + """初始化适配器""" + self.cfg = cfg + self.ap = ap + + switchs = {} + + for intent in cfg['intents']: + switchs[intent] = True + + del cfg['intents'] + + intents = 
botpy.Intents(**switchs) + + self.bot = botpy.Client(intents=intents) + + async def send_message( + self, + target_type: str, + target_id: str, + message: mirai.MessageChain + ): + pass + + async def reply_message( + self, + message_source: mirai.MessageEvent, + message: mirai.MessageChain, + quote_origin: bool = False + ): + message_list = self.message_converter.yiri2target(message) + tasks = [] + + msg_seq = 1 + + for msg in message_list: + args = {} + + if msg['type'] == 'text': + args['content'] = msg['content'] + elif msg['type'] == 'image': + args['image'] = msg['content'] + elif msg['type'] == 'file_image': + args['file_image'] = msg["content"] + else: + continue + + if quote_origin: + args['message_reference'] = botpy_message_type.Reference(message_id=cached_message_ids[str(message_source.message_chain.message_id)]) + + if type(message_source) == mirai.GroupMessage: + args['channel_id'] = str(message_source.sender.group.id) + args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)] + await self.bot.api.post_message(**args) + elif type(message_source) == mirai.FriendMessage: + args['guild_id'] = str(message_source.sender.id) + args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)] + await self.bot.api.post_dms(**args) + elif type(message_source) == OfficialGroupMessage: + # args['guild_id'] = str(message_source.sender.group.id) + # args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)] + # await self.bot.api.post_message(**args) + if 'image' in args or 'file_image' in args: + continue + args['group_openid'] = cached_group_openids[str(message_source.sender.group.id)] + args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)] + args['msg_seq'] = msg_seq + msg_seq += 1 + await self.bot.api.post_group_message( + **args + ) + + + async def is_muted(self, group_id: int) -> bool: + return False + + def register_listener( + self, + event_type: 
typing.Type[mirai.Event], + callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] + ): + + try: + + async def wrapper(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage]): + self.cached_official_messages[str(message.id)] = message + await callback(OfficialEventConverter.target2yiri(message), self) + + for event_handler in event_handler_mapping[event_type]: + setattr(self.bot, event_handler, wrapper) + except Exception as e: + traceback.print_exc() + raise e + + def unregister_listener( + self, + event_type: typing.Type[mirai.Event], + callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] + ): + delattr(self.bot, event_handler_mapping[event_type]) + + async def run_async(self): + self.ap.logger.info("运行 QQ 官方适配器") + await self.bot.start( + **self.cfg + ) + + def kill(self) -> bool: + return False diff --git a/pkg/qqbot/sources/yirimirai.py b/pkg/platform/sources/yirimirai.py similarity index 77% rename from pkg/qqbot/sources/yirimirai.py rename to pkg/platform/sources/yirimirai.py index 7828be18..7768dcf0 100644 --- a/pkg/qqbot/sources/yirimirai.py +++ b/pkg/platform/sources/yirimirai.py @@ -6,14 +6,18 @@ from mirai.bot import MiraiRunner from .. 
import adapter as adapter_model +from ...core import app +@adapter_model.adapter_class("yiri-mirai") class YiriMiraiAdapter(adapter_model.MessageSourceAdapter): """YiriMirai适配器""" bot: mirai.Mirai - def __init__(self, config: dict): + def __init__(self, config: dict, ap: app.Application): """初始化YiriMirai的对象""" + self.ap = ap + self.config = config if 'adapter' not in config or \ config['adapter'] == 'WebSocketAdapter': self.bot = mirai.Mirai( @@ -36,7 +40,7 @@ def __init__(self, config: dict): else: raise Exception('Unknown adapter for YiriMirai: ' + config['adapter']) - def send_message( + async def send_message( self, target_type: str, target_id: str, @@ -57,9 +61,9 @@ def send_message( else: raise Exception('Unknown target type: ' + target_type) - asyncio.run(task) + await task - def reply_message( + async def reply_message( self, message_source: mirai.MessageEvent, message: mirai.MessageChain, @@ -72,11 +76,10 @@ def reply_message( message (mirai.MessageChain): YiriMirai库的消息链 quote_origin (bool, optional): 是否引用原消息. Defaults to False. 
""" - asyncio.run(self.bot.send(message_source, message, quote_origin)) + await self.bot.send(message_source, message, quote_origin) - def is_muted(self, group_id: int) -> bool: - result = self.bot.member_info(target=group_id, member_id=self.bot.qq).get() - result = asyncio.run(result) + async def is_muted(self, group_id: int) -> bool: + result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get() if result.mute_time_remaining > 0: return True return False @@ -84,7 +87,7 @@ def is_muted(self, group_id: int) -> bool: def register_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] ): """注册事件监听器 @@ -92,12 +95,14 @@ def register_listener( event_type (typing.Type[mirai.Event]): YiriMirai事件类型 callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件 """ - self.bot.on(event_type)(callback) + async def wrapper(event: mirai.Event): + await callback(event, self) + self.bot.on(event_type)(wrapper) def unregister_listener( self, event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event], None] + callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] ): """注销事件监听器 @@ -111,13 +116,9 @@ def unregister_listener( bus.unsubscribe(event_type, callback) - def run_sync(self): - """运行YiriMirai""" + async def run_async(self): + self.bot_account_id = self.bot.qq + return await MiraiRunner(self.bot)._run() - # 创建新的 - loop = asyncio.new_event_loop() - - loop.run_until_complete(MiraiRunner(self.bot)._run()) - - def kill(self) -> bool: + async def kill(self) -> bool: return False diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py new file mode 100644 index 00000000..a982232f --- /dev/null +++ b/pkg/plugin/context.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import typing +import abc +import pydantic + +from . 
import events +from ..provider.tools import entities as tools_entities +from ..core import app + + +class BasePlugin(metaclass=abc.ABCMeta): + """插件基类""" + + host: APIHost + + +class APIHost: + """QChatGPT API 宿主""" + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + def require_ver( + self, + ge: str, + le: str='v999.999.999', + ) -> bool: + """插件版本要求装饰器 + + Args: + ge (str): 最低版本要求 + le (str, optional): 最高版本要求 + + Returns: + bool: 是否满足要求, False时为无法获取版本号,True时为满足要求,报错为不满足要求 + """ + qchatgpt_version = "" + + try: + qchatgpt_version = self.ap.ver_mgr.get_current_version() # 从updater模块获取版本号 + except: + return False + + if self.ap.ver_mgr.compare_version_str(qchatgpt_version, ge) < 0 or \ + (self.ap.ver_mgr.compare_version_str(qchatgpt_version, le) > 0): + raise Exception("QChatGPT 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{})".format(ge, le, qchatgpt_version)) + + return True + + +class EventContext: + """事件上下文, 保存此次事件运行的信息""" + + eid = 0 + """事件编号""" + + host: APIHost = None + + event: events.BaseEventModel = None + + __prevent_default__ = False + """是否阻止默认行为""" + + __prevent_postorder__ = False + """是否阻止后续插件的执行""" + + __return_value__ = {} + """ 返回值 + 示例: + { + "example": [ + 'value1', + 'value2', + 3, + 4, + { + 'key1': 'value1', + }, + ['value1', 'value2'] + ] + } + """ + + def add_return(self, key: str, ret): + """添加返回值""" + if key not in self.__return_value__: + self.__return_value__[key] = [] + self.__return_value__[key].append(ret) + + def get_return(self, key: str) -> list: + """获取key的所有返回值""" + if key in self.__return_value__: + return self.__return_value__[key] + return None + + def get_return_value(self, key: str): + """获取key的首个返回值""" + if key in self.__return_value__: + return self.__return_value__[key][0] + return None + + def prevent_default(self): + """阻止默认行为""" + self.__prevent_default__ = True + + def prevent_postorder(self): + """阻止后续插件执行""" + 
self.__prevent_postorder__ = True + + def is_prevented_default(self): + """是否阻止默认行为""" + return self.__prevent_default__ + + def is_prevented_postorder(self): + """是否阻止后序插件执行""" + return self.__prevent_postorder__ + + def __init__(self, host: APIHost, event: events.BaseEventModel): + + self.eid = EventContext.eid + self.host = host + self.event = event + self.__prevent_default__ = False + self.__prevent_postorder__ = False + self.__return_value__ = {} + EventContext.eid += 1 + + +class RuntimeContainer(pydantic.BaseModel): + """运行时的插件容器 + + 运行期间存储单个插件的信息 + """ + + plugin_name: str + """插件名称""" + + plugin_description: str + """插件描述""" + + plugin_version: str + """插件版本""" + + plugin_author: str + """插件作者""" + + plugin_source: str + """插件源码地址""" + + main_file: str + """插件主文件路径""" + + pkg_path: str + """插件包路径""" + + plugin_class: typing.Type[BasePlugin] = None + """插件类""" + + enabled: typing.Optional[bool] = True + """是否启用""" + + priority: typing.Optional[int] = 0 + """优先级""" + + plugin_inst: typing.Optional[BasePlugin] = None + """插件实例""" + + event_handlers: dict[typing.Type[events.BaseEventModel], typing.Callable[ + [BasePlugin, EventContext], typing.Awaitable[None] + ]] = {} + """事件处理器""" + + content_functions: list[tools_entities.LLMFunction] = [] + """内容函数""" + + class Config: + arbitrary_types_allowed = True + + def to_setting_dict(self): + return { + 'name': self.plugin_name, + 'description': self.plugin_description, + 'version': self.plugin_version, + 'author': self.plugin_author, + 'source': self.plugin_source, + 'main_file': self.main_file, + 'pkg_path': self.pkg_path, + 'priority': self.priority, + 'enabled': self.enabled, + } + + def set_from_setting_dict( + self, + setting: dict + ): + self.plugin_source = setting['source'] + self.priority = setting['priority'] + self.enabled = setting['enabled'] + + for function in self.content_functions: + function.enable = self.enabled diff --git a/pkg/plugin/errors.py b/pkg/plugin/errors.py new file mode 100644 index 
00000000..bd6199e3 --- /dev/null +++ b/pkg/plugin/errors.py @@ -0,0 +1,24 @@ +from __future__ import annotations + + +class PluginSystemError(Exception): + + message: str + + def __init__(self, message: str): + self.message = message + + def __str__(self): + return self.message + + +class PluginNotFoundError(PluginSystemError): + + def __init__(self, message: str): + super().__init__(f"未找到插件: {message}") + + +class PluginInstallerError(PluginSystemError): + + def __init__(self, message: str): + super().__init__(f"安装器操作错误: {message}") diff --git a/pkg/plugin/events.py b/pkg/plugin/events.py new file mode 100644 index 00000000..d414dafb --- /dev/null +++ b/pkg/plugin/events.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import typing + +import pydantic +import mirai + +from ..core import entities as core_entities +from ..provider import entities as llm_entities + + +class BaseEventModel(pydantic.BaseModel): + + query: core_entities.Query | None + + class Config: + arbitrary_types_allowed = True + + +class PersonMessageReceived(BaseEventModel): + """收到任何私聊消息时""" + + launcher_type: str + """发起对象类型(group/person)""" + + launcher_id: int + """发起对象ID(群号/QQ号)""" + + sender_id: int + """发送者ID(QQ号)""" + + message_chain: mirai.MessageChain + + +class GroupMessageReceived(BaseEventModel): + """收到任何群聊消息时""" + + launcher_type: str + + launcher_id: int + + sender_id: int + + message_chain: mirai.MessageChain + + +class PersonNormalMessageReceived(BaseEventModel): + """判断为应该处理的私聊普通消息时触发""" + + launcher_type: str + + launcher_id: int + + sender_id: int + + text_message: str + + alter: typing.Optional[str] = None + """修改后的消息文本""" + + reply: typing.Optional[list] = None + """回复消息组件列表""" + + +class PersonCommandSent(BaseEventModel): + """判断为应该处理的私聊命令时触发""" + + launcher_type: str + + launcher_id: int + + sender_id: int + + command: str + + params: list[str] + + text_message: str + + is_admin: bool + + alter: typing.Optional[str] = None + """修改后的完整命令文本""" + + reply: 
typing.Optional[list] = None + """回复消息组件列表""" + + +class GroupNormalMessageReceived(BaseEventModel): + """判断为应该处理的群聊普通消息时触发""" + + launcher_type: str + + launcher_id: int + + sender_id: int + + text_message: str + + alter: typing.Optional[str] = None + """修改后的消息文本""" + + reply: typing.Optional[list] = None + """回复消息组件列表""" + + +class GroupCommandSent(BaseEventModel): + """判断为应该处理的群聊命令时触发""" + + launcher_type: str + + launcher_id: int + + sender_id: int + + command: str + + params: list[str] + + text_message: str + + is_admin: bool + + alter: typing.Optional[str] = None + """修改后的完整命令文本""" + + reply: typing.Optional[list] = None + """回复消息组件列表""" + + +class NormalMessageResponded(BaseEventModel): + """回复普通消息时触发""" + + launcher_type: str + + launcher_id: int + + sender_id: int + + session: core_entities.Session + """会话对象""" + + prefix: str + """回复消息的前缀""" + + response_text: str + """回复消息的文本""" + + finish_reason: str + """响应结束原因""" + + funcs_called: list[str] + """调用的函数列表""" + + reply: typing.Optional[list] = None + """回复消息组件列表""" + + +class PromptPreProcessing(BaseEventModel): + """会话中的Prompt预处理时触发""" + + session_name: str + + default_prompt: list[llm_entities.Message] + """此对话的情景预设,可修改""" + + prompt: list[llm_entities.Message] + """此对话现有消息记录,可修改""" diff --git a/pkg/plugin/host.py b/pkg/plugin/host.py index bf41003a..6149da62 100644 --- a/pkg/plugin/host.py +++ b/pkg/plugin/host.py @@ -1,578 +1,5 @@ -# 插件管理模块 -import asyncio -import logging -import importlib -import os -import pkgutil -import sys -import shutil -import traceback -import time -import re +from . events import * +from . 
context import EventContext, APIHost as PluginHost -from ..utils import updater as updater -from ..utils import network as network -from ..utils import context as context -from ..plugin import switch as switch -from ..plugin import settings as settings -from ..qqbot import adapter as msadapter -from ..plugin import metadata as metadata - -from mirai import Mirai -import requests - -from CallingGPT.session.session import Session - -__plugins__ = {} -"""插件列表 - -示例: -{ - "example": { - "path": "plugins/example/main.py", - "enabled: True, - "name": "example", - "description": "example", - "version": "0.0.1", - "author": "RockChinQ", - "class": , - "hooks": { - "person_message": [ - - ] - }, - "instance": None - } -} -""" - -__plugins_order__ = [] -"""插件顺序""" - -__enable_content_functions__ = True -"""是否启用内容函数""" - -__callable_functions__ = [] -"""供GPT调用的函数结构""" - -__function_inst_map__: dict[str, callable] = {} -"""函数名:实例 映射""" - - -def generate_plugin_order(): - """根据__plugin__生成插件初始顺序,无视是否启用""" - global __plugins_order__ - __plugins_order__ = [] - for plugin_name in __plugins__: - __plugins_order__.append(plugin_name) - - -def iter_plugins(): - """按照顺序迭代插件""" - for plugin_name in __plugins_order__: - if plugin_name not in __plugins__: - continue - yield __plugins__[plugin_name] - - -def iter_plugins_name(): - """迭代插件名""" - for plugin_name in __plugins_order__: - yield plugin_name - - -__current_module_path__ = "" - - -def walk_plugin_path(module, prefix="", path_prefix=""): - global __current_module_path__ - """遍历插件路径""" - for item in pkgutil.iter_modules(module.__path__): - if item.ispkg: - logging.debug("扫描插件包: plugins/{}".format(path_prefix + item.name)) - walk_plugin_path( - __import__(module.__name__ + "." 
+ item.name, fromlist=[""]), - prefix + item.name + ".", - path_prefix + item.name + "/", - ) - else: - try: - logging.debug( - "扫描插件模块: plugins/{}".format(path_prefix + item.name + ".py") - ) - __current_module_path__ = "plugins/" + path_prefix + item.name + ".py" - - importlib.import_module(module.__name__ + "." + item.name) - logging.debug( - "加载模块: plugins/{} 成功".format(path_prefix + item.name + ".py") - ) - except: - logging.error( - "加载模块: plugins/{} 失败: {}".format( - path_prefix + item.name + ".py", sys.exc_info() - ) - ) - traceback.print_exc() - - -def load_plugins(): - """加载插件""" - logging.debug("加载插件") - PluginHost() - walk_plugin_path(__import__("plugins")) - - logging.debug(__plugins__) - - # 加载开关数据 - switch.load_switch() - - # 生成初始顺序 - generate_plugin_order() - # 加载插件顺序 - settings.load_settings() - - logging.debug("registered plugins: {}".format(__plugins__)) - - # 输出已注册的内容函数列表 - logging.debug("registered content functions: {}".format(__callable_functions__)) - logging.debug("function instance map: {}".format(__function_inst_map__)) - - # 迁移插件源地址记录 - metadata.do_plugin_git_repo_migrate() - - -def initialize_plugins(): - """初始化插件""" - logging.debug("初始化插件") - import pkg.plugin.models as models - - successfully_initialized_plugins = [] - - for plugin in iter_plugins(): - # if not plugin['enabled']: - # continue - try: - models.__current_registering_plugin__ = plugin["name"] - plugin["instance"] = plugin["class"](plugin_host=context.get_plugin_host()) - # logging.info("插件 {} 已初始化".format(plugin['name'])) - successfully_initialized_plugins.append(plugin["name"]) - except: - logging.error("插件{}初始化时发生错误: {}".format(plugin["name"], sys.exc_info())) - logging.debug(traceback.format_exc()) - - logging.info("以下插件已初始化: {}".format(", ".join(successfully_initialized_plugins))) - - -def unload_plugins(): - """卸载插件""" - # 不再显式卸载插件,因为当程序结束时,插件的析构函数会被系统执行 - # for plugin in __plugins__.values(): - # if plugin['enabled'] and plugin['instance'] is not None: - # if not 
hasattr(plugin['instance'], '__del__'): - # logging.warning("插件{}没有定义析构函数".format(plugin['name'])) - # else: - # try: - # plugin['instance'].__del__() - # logging.info("卸载插件: {}".format(plugin['name'])) - # plugin['instance'] = None - # except: - # logging.error("插件{}卸载时发生错误: {}".format(plugin['name'], sys.exc_info())) - - -def get_github_plugin_repo_label(repo_url: str) -> list[str]: - """获取username, repo""" - - # 提取 username/repo , 正则表达式 - repo = re.findall( - r"(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)", - repo_url, - ) - - if len(repo) > 0: # github - return repo[0].split("/") - else: - return None - - -def download_plugin_source_code(repo_url: str, target_path: str) -> str: - """下载插件源码""" - # 检查源类型 - - # 提取 username/repo , 正则表达式 - repo = get_github_plugin_repo_label(repo_url) - - target_path += repo[1] - - if repo is not None: # github - logging.info("从 GitHub 下载插件源码...") - - zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD" - - zip_resp = requests.get( - url=zipball_url, proxies=network.wrapper_proxies(), stream=True - ) - - if zip_resp.status_code != 200: - raise Exception("下载源码失败: {}".format(zip_resp.text)) - - if os.path.exists("temp/" + target_path): - shutil.rmtree("temp/" + target_path) - - if os.path.exists(target_path): - shutil.rmtree(target_path) - - os.makedirs("temp/" + target_path) - - with open("temp/" + target_path + "/source.zip", "wb") as f: - for chunk in zip_resp.iter_content(chunk_size=1024): - if chunk: - f.write(chunk) - - logging.info("下载完成, 解压...") - import zipfile - - with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref: - zip_ref.extractall("temp/" + target_path) - os.remove("temp/" + target_path + "/source.zip") - - # 目标是 username-repo-hash , 用正则表达式提取完整的文件夹名,复制到 plugins/repo - import glob - - # 获取解压后的文件夹名 - unzip_dir = glob.glob("temp/" + target_path + "/*")[0] - - # 复制到 plugins/repo - shutil.copytree(unzip_dir, target_path + "/") - - # 删除解压后的文件夹 - 
shutil.rmtree(unzip_dir) - - logging.info("解压完成") - else: - raise Exception("暂不支持的源类型,请使用 GitHub 仓库发行插件。") - - return repo[1] - - -def check_requirements(path: str): - # 检查此目录是否包含requirements.txt - if os.path.exists(path + "/requirements.txt"): - logging.info("检测到requirements.txt,正在安装依赖") - import pkg.utils.pkgmgr - - pkg.utils.pkgmgr.install_requirements(path + "/requirements.txt") - - import pkg.utils.log as log - - log.reset_logging() - - -def install_plugin(repo_url: str): - """安装插件,从git储存库获取并解决依赖""" - - repo_label = download_plugin_source_code(repo_url, "plugins/") - - check_requirements("plugins/" + repo_label) - - metadata.set_plugin_metadata(repo_label, repo_url, int(time.time()), "HEAD") - - # 上报安装记录 - context.get_center_v2_api().plugin.post_install_record( - plugin={ - "name": "unknown", - "remote": repo_url, - "author": "unknown", - "version": "HEAD", - } - ) - - -def uninstall_plugin(plugin_name: str) -> str: - """卸载插件""" - if plugin_name not in __plugins__: - raise Exception("插件不存在") - - plugin_info = get_plugin_info_for_audit(plugin_name) - - # 获取文件夹路径 - plugin_path = __plugins__[plugin_name]["path"].replace("\\", "/") - - # 剪切路径为plugins/插件名 - plugin_path = plugin_path.split("plugins/")[1].split("/")[0] - - # 删除文件夹 - shutil.rmtree("plugins/" + plugin_path) - - # 上报卸载记录 - context.get_center_v2_api().plugin.post_remove_record( - plugin=plugin_info - ) - - return "plugins/" + plugin_path - - -def update_plugin(plugin_name: str): - """更新插件""" - # 检查是否有远程地址记录 - plugin_path_name = get_plugin_path_name_by_plugin_name(plugin_name) - - meta = metadata.get_plugin_metadata(plugin_path_name) - - if meta == {}: - raise Exception("没有此插件元数据信息,无法更新") - - old_plugin_info = get_plugin_info_for_audit(plugin_name) - - context.get_center_v2_api().plugin.post_update_record( - plugin=old_plugin_info, - old_version=old_plugin_info['version'], - new_version='HEAD', - ) - - remote_url = meta["source"] - if ( - remote_url == "https://github.com/RockChinQ/QChatGPT" - or 
remote_url == "https://gitee.com/RockChin/QChatGPT" - or remote_url == "" - or remote_url is None - or remote_url == "http://github.com/RockChinQ/QChatGPT" - or remote_url == "http://gitee.com/RockChin/QChatGPT" - ): - raise Exception("插件没有远程地址记录,无法更新") - - # 重新安装插件 - logging.info("正在重新安装插件以进行更新...") - - install_plugin(remote_url) - - -def get_plugin_name_by_path_name(plugin_path_name: str) -> str: - for k, v in __plugins__.items(): - if v["path"] == "plugins/" + plugin_path_name + "/main.py": - return k - return None - - -def get_plugin_path_name_by_plugin_name(plugin_name: str) -> str: - if plugin_name not in __plugins__: - return None - - plugin_main_module_path = __plugins__[plugin_name]["path"] - - plugin_main_module_path = plugin_main_module_path.replace("\\", "/") - - spt = plugin_main_module_path.split("/") - - return spt[1] - - -def get_plugin_info_for_audit(plugin_name: str) -> dict: - """获取插件信息""" - if plugin_name not in __plugins__: - return {} - plugin = __plugins__[plugin_name] - - name = plugin["name"] - meta = metadata.get_plugin_metadata(get_plugin_path_name_by_plugin_name(name)) - remote = meta["source"] if meta != {} else "" - author = plugin["author"] - version = plugin["version"] - - return { - "name": name, - "remote": remote, - "author": author, - "version": version, - } - - -class EventContext: - """事件上下文""" - - eid = 0 - """事件编号""" - - name = "" - - __prevent_default__ = False - """是否阻止默认行为""" - - __prevent_postorder__ = False - """是否阻止后续插件的执行""" - - __return_value__ = {} - """ 返回值 - 示例: - { - "example": [ - 'value1', - 'value2', - 3, - 4, - { - 'key1': 'value1', - }, - ['value1', 'value2'] - ] - } - """ - - def add_return(self, key: str, ret): - """添加返回值""" - if key not in self.__return_value__: - self.__return_value__[key] = [] - self.__return_value__[key].append(ret) - - def get_return(self, key: str) -> list: - """获取key的所有返回值""" - if key in self.__return_value__: - return self.__return_value__[key] - return None - - def 
get_return_value(self, key: str): - """获取key的首个返回值""" - if key in self.__return_value__: - return self.__return_value__[key][0] - return None - - def prevent_default(self): - """阻止默认行为""" - self.__prevent_default__ = True - - def prevent_postorder(self): - """阻止后续插件执行""" - self.__prevent_postorder__ = True - - def is_prevented_default(self): - """是否阻止默认行为""" - return self.__prevent_default__ - - def is_prevented_postorder(self): - """是否阻止后序插件执行""" - return self.__prevent_postorder__ - - def __init__(self, name: str): - self.name = name - self.eid = EventContext.eid - self.__prevent_default__ = False - self.__prevent_postorder__ = False - self.__return_value__ = {} - EventContext.eid += 1 - - -def emit(event_name: str, **kwargs) -> EventContext: - """触发事件""" - import pkg.utils.context as context - - if context.get_plugin_host() is None: - return None - return context.get_plugin_host().emit(event_name, **kwargs) - - -class PluginHost: - """插件宿主""" - - def __init__(self): - """初始化插件宿主""" - context.set_plugin_host(self) - self.calling_gpt_session = Session([]) - - def get_runtime_context(self) -> context: - """获取运行时上下文(pkg.utils.context模块的对象) - - 此上下文用于和主程序其他模块交互(数据库、QQ机器人、OpenAI接口等) - 详见pkg.utils.context模块 - 其中的context变量保存了其他重要模块的类对象,可以使用这些对象进行交互 - """ - return context - - def get_bot(self) -> Mirai: - """获取机器人对象""" - return context.get_qqbot_manager().bot - - def get_bot_adapter(self) -> msadapter.MessageSourceAdapter: - """获取消息源适配器""" - return context.get_qqbot_manager().adapter - - def send_person_message(self, person, message): - """发送私聊消息""" - self.get_bot_adapter().send_message("person", person, message) - - def send_group_message(self, group, message): - """发送群消息""" - self.get_bot_adapter().send_message("group", group, message) - - def notify_admin(self, message): - """通知管理员""" - context.get_qqbot_manager().notify_admin(message) - - def emit(self, event_name: str, **kwargs) -> EventContext: - """触发事件""" - import json - - event_context = EventContext(event_name) 
- logging.debug("触发事件: {} ({})".format(event_name, event_context.eid)) - - emitted_plugins = [] - for plugin in iter_plugins(): - if not plugin["enabled"]: - continue - - # if plugin['instance'] is None: - # # 从关闭状态切到开启状态之后,重新加载插件 - # try: - # plugin['instance'] = plugin["class"](plugin_host=self) - # logging.info("插件 {} 已初始化".format(plugin['name'])) - # except: - # logging.error("插件 {} 初始化时发生错误: {}".format(plugin['name'], sys.exc_info())) - # continue - - if "hooks" not in plugin or event_name not in plugin["hooks"]: - continue - - emitted_plugins.append(plugin['name']) - - hooks = [] - if event_name in plugin["hooks"]: - hooks = plugin["hooks"][event_name] - for hook in hooks: - try: - already_prevented_default = event_context.is_prevented_default() - - kwargs["host"] = context.get_plugin_host() - kwargs["event"] = event_context - - hook(plugin["instance"], **kwargs) - - if ( - event_context.is_prevented_default() - and not already_prevented_default - ): - logging.debug( - "插件 {} 已要求阻止事件 {} 的默认行为".format(plugin["name"], event_name) - ) - - except Exception as e: - logging.error("插件{}响应事件{}时发生错误".format(plugin["name"], event_name)) - logging.error(traceback.format_exc()) - - # print("done:{}".format(plugin['name'])) - if event_context.is_prevented_postorder(): - logging.debug("插件 {} 阻止了后序插件的执行".format(plugin["name"])) - break - - logging.debug( - "事件 {} ({}) 处理完毕,返回值: {}".format( - event_name, event_context.eid, event_context.__return_value__ - ) - ) - - if len(emitted_plugins) > 0: - plugins_info = [get_plugin_info_for_audit(p) for p in emitted_plugins] - - context.get_center_v2_api().usage.post_event_record( - plugins=plugins_info, - event_name=event_name, - ) - - return event_context +def emit(*args, **kwargs): + print('插件调用了已弃用的函数 pkg.plugin.host.emit()') \ No newline at end of file diff --git a/pkg/plugin/installer.py b/pkg/plugin/installer.py new file mode 100644 index 00000000..6a089438 --- /dev/null +++ b/pkg/plugin/installer.py @@ -0,0 +1,45 @@ +from 
__future__ import annotations + +import typing +import abc + +from ..core import app + + +class PluginInstaller(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def install_plugin( + self, + plugin_source: str, + ): + """安装插件 + """ + raise NotImplementedError + + @abc.abstractmethod + async def uninstall_plugin( + self, + plugin_name: str, + ): + """卸载插件 + """ + raise NotImplementedError + + @abc.abstractmethod + async def update_plugin( + self, + plugin_name: str, + plugin_source: str=None, + ): + """更新插件 + """ + raise NotImplementedError diff --git a/pkg/plugin/installers/__init__.py b/pkg/plugin/installers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/plugin/installers/github.py b/pkg/plugin/installers/github.py new file mode 100644 index 00000000..8908f181 --- /dev/null +++ b/pkg/plugin/installers/github.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import re +import os +import shutil +import zipfile + +import requests + +from .. 
import installer, errors +from ...utils import pkgmgr + + +class GitHubRepoInstaller(installer.PluginInstaller): + + def get_github_plugin_repo_label(self, repo_url: str) -> list[str]: + """获取username, repo""" + + # 提取 username/repo , 正则表达式 + repo = re.findall( + r"(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)", + repo_url, + ) + + if len(repo) > 0: # github + return repo[0].split("/") + else: + return None + + async def download_plugin_source_code(self, repo_url: str, target_path: str) -> str: + """下载插件源码""" + # 检查源类型 + + # 提取 username/repo , 正则表达式 + repo = self.get_github_plugin_repo_label(repo_url) + + target_path += repo[1] + + if repo is not None: # github + self.ap.logger.debug("正在下载源码...") + + zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD" + + zip_resp = requests.get( + url=zipball_url, proxies=self.ap.proxy_mgr.get_forward_proxies(), stream=True + ) + + if zip_resp.status_code != 200: + raise Exception("下载源码失败: {}".format(zip_resp.text)) + + if os.path.exists("temp/" + target_path): + shutil.rmtree("temp/" + target_path) + + if os.path.exists(target_path): + shutil.rmtree(target_path) + + os.makedirs("temp/" + target_path) + + with open("temp/" + target_path + "/source.zip", "wb") as f: + for chunk in zip_resp.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + + self.ap.logger.debug("解压中...") + + with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref: + zip_ref.extractall("temp/" + target_path) + os.remove("temp/" + target_path + "/source.zip") + + # 目标是 username-repo-hash , 用正则表达式提取完整的文件夹名,复制到 plugins/repo + import glob + + # 获取解压后的文件夹名 + unzip_dir = glob.glob("temp/" + target_path + "/*")[0] + + # 复制到 plugins/repo + shutil.copytree(unzip_dir, target_path + "/") + + # 删除解压后的文件夹 + shutil.rmtree(unzip_dir) + + self.ap.logger.debug("源码下载完成。") + else: + raise errors.PluginInstallerError('仅支持GitHub仓库地址') + + return repo[1] + + async def install_requirements(self, path: str): + 
if os.path.exists(path + "/requirements.txt"): + pkgmgr.install_requirements(path + "/requirements.txt") + + async def install_plugin( + self, + plugin_source: str, + ): + """安装插件 + """ + repo_label = await self.download_plugin_source_code(plugin_source, "plugins/") + + await self.install_requirements("plugins/" + repo_label) + + await self.ap.plugin_mgr.setting.record_installed_plugin_source( + "plugins/"+repo_label+'/', plugin_source + ) + + async def uninstall_plugin( + self, + plugin_name: str, + ): + """卸载插件 + """ + plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) + + if plugin_container is None: + raise errors.PluginInstallerError('插件不存在或未成功加载') + else: + shutil.rmtree(plugin_container.pkg_path) + + async def update_plugin( + self, + plugin_name: str, + plugin_source: str=None, + ): + """更新插件 + """ + plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) + + if plugin_container is None: + raise errors.PluginInstallerError('插件不存在或未成功加载') + else: + if plugin_container.plugin_source: + plugin_source = plugin_container.plugin_source + + await self.install_plugin(plugin_source) + + else: + raise errors.PluginInstallerError('插件无源码信息,无法更新') diff --git a/pkg/plugin/loader.py b/pkg/plugin/loader.py new file mode 100644 index 00000000..d74bcde7 --- /dev/null +++ b/pkg/plugin/loader.py @@ -0,0 +1,25 @@ +from __future__ import annotations +from abc import ABCMeta + +import typing +import abc + +from ..core import app +from . 
import context, events + + +class PluginLoader(metaclass=abc.ABCMeta): + """插件加载器""" + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def load_plugins(self) -> list[context.RuntimeContainer]: + pass + diff --git a/pkg/plugin/loaders/__init__.py b/pkg/plugin/loaders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/plugin/loaders/legacy.py b/pkg/plugin/loaders/legacy.py new file mode 100644 index 00000000..9bbee7c0 --- /dev/null +++ b/pkg/plugin/loaders/legacy.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import typing +import pkgutil +import importlib +import traceback + +from CallingGPT.entities.namespace import get_func_schema + +from .. import loader, events, context, models, host +from ...core import entities as core_entities +from ...provider.tools import entities as tools_entities + + +class PluginLoader(loader.PluginLoader): + """加载 plugins/ 目录下的插件""" + + _current_pkg_path = '' + + _current_module_path = '' + + _current_container: context.RuntimeContainer = None + + containers: list[context.RuntimeContainer] = [] + + async def initialize(self): + """初始化""" + setattr(models, 'register', self.register) + setattr(models, 'on', self.on) + setattr(models, 'func', self.func) + + def register( + self, + name: str, + description: str, + version: str, + author: str + ) -> typing.Callable[[typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin]]: + self.ap.logger.debug(f'注册插件 {name} {version} by {author}') + container = context.RuntimeContainer( + plugin_name=name, + plugin_description=description, + plugin_version=version, + plugin_author=author, + plugin_source='', + pkg_path=self._current_pkg_path, + main_file=self._current_module_path, + event_handlers={}, + content_functions=[], + ) + + self._current_container = container + + def wrapper(cls: context.BasePlugin) -> typing.Type[context.BasePlugin]: + 
container.plugin_class = cls + return cls + + return wrapper + + def on( + self, + event: typing.Type[events.BaseEventModel] + ) -> typing.Callable[[typing.Callable], typing.Callable]: + """注册过时的事件处理器""" + self.ap.logger.debug(f'注册事件处理器 {event.__name__}') + def wrapper(func: typing.Callable) -> typing.Callable: + + async def handler(plugin: context.BasePlugin, ctx: context.EventContext) -> None: + args = { + 'host': ctx.host, + 'event': ctx, + } + + # 把 ctx.event 所有的属性都放到 args 里 + for k, v in ctx.event.dict().items(): + args[k] = v + + func(plugin, **args) + + self._current_container.event_handlers[event] = handler + + return func + + return wrapper + + def func( + self, + name: str=None, + ) -> typing.Callable: + """注册过时的内容函数""" + self.ap.logger.debug(f'注册内容函数 {name}') + def wrapper(func: typing.Callable) -> typing.Callable: + + function_schema = get_func_schema(func) + function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) + + async def handler( + query: core_entities.Query, + *args, + **kwargs + ): + return func(*args, **kwargs) + + llm_function = tools_entities.LLMFunction( + name=function_name, + human_desc='', + description=function_schema['description'], + enable=True, + parameters=function_schema['parameters'], + func=handler, + ) + + self._current_container.content_functions.append(llm_function) + + return func + + return wrapper + + async def _walk_plugin_path( + self, + module, + prefix='', + path_prefix='' + ): + """遍历插件路径 + """ + for item in pkgutil.iter_modules(module.__path__): + if item.ispkg: + await self._walk_plugin_path( + __import__(module.__name__ + "." + item.name, fromlist=[""]), + prefix + item.name + ".", + path_prefix + item.name + "/", + ) + else: + try: + self._current_pkg_path = "plugins/" + path_prefix + self._current_module_path = "plugins/" + path_prefix + item.name + ".py" + + self._current_container = None + + importlib.import_module(module.__name__ + "." 
+ item.name) + + if self._current_container is not None: + self.containers.append(self._current_container) + self.ap.logger.debug(f'插件 {self._current_container} 已加载') + except: + self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误') + traceback.print_exc() + + async def load_plugins(self) -> list[context.RuntimeContainer]: + """加载插件 + """ + await self._walk_plugin_path(__import__("plugins", fromlist=[""])) + + return self.containers diff --git a/pkg/plugin/manager.py b/pkg/plugin/manager.py new file mode 100644 index 00000000..243d442e --- /dev/null +++ b/pkg/plugin/manager.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import typing +import traceback + +from ..core import app +from . import context, loader, events, installer, setting, models +from .loaders import legacy +from .installers import github + + +class PluginManager: + + ap: app.Application + + loader: loader.PluginLoader + + installer: installer.PluginInstaller + + setting: setting.SettingManager + + api_host: context.APIHost + + plugins: list[context.RuntimeContainer] + + def __init__(self, ap: app.Application): + self.ap = ap + self.loader = legacy.PluginLoader(ap) + self.installer = github.GitHubRepoInstaller(ap) + self.setting = setting.SettingManager(ap) + self.api_host = context.APIHost(ap) + self.plugins = [] + + async def initialize(self): + await self.loader.initialize() + await self.installer.initialize() + await self.setting.initialize() + await self.api_host.initialize() + + setattr(models, 'require_ver', self.api_host.require_ver) + + async def load_plugins(self): + self.plugins = await self.loader.load_plugins() + + await self.setting.sync_setting(self.plugins) + + # 按优先级倒序 + self.plugins.sort(key=lambda x: x.priority, reverse=True) + + async def initialize_plugins(self): + for plugin in self.plugins: + try: + plugin.plugin_inst = plugin.plugin_class(self.api_host) + except Exception as e: + self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}') + 
self.ap.logger.exception(e) + continue + + async def install_plugin( + self, + plugin_source: str, + ): + """安装插件 + """ + await self.installer.install_plugin(plugin_source) + + await self.ap.ctr_mgr.plugin.post_install_record( + { + "name": "unknown", + "remote": plugin_source, + "author": "unknown", + "version": "HEAD" + } + ) + + async def uninstall_plugin( + self, + plugin_name: str, + ): + """卸载插件 + """ + await self.installer.uninstall_plugin(plugin_name) + + plugin_container = self.get_plugin_by_name(plugin_name) + + await self.ap.ctr_mgr.plugin.post_remove_record( + { + "name": plugin_name, + "remote": plugin_container.plugin_source, + "author": plugin_container.plugin_author, + "version": plugin_container.plugin_version + } + ) + + async def update_plugin( + self, + plugin_name: str, + plugin_source: str=None, + ): + """更新插件 + """ + await self.installer.update_plugin(plugin_name, plugin_source) + + plugin_container = self.get_plugin_by_name(plugin_name) + + await self.ap.ctr_mgr.plugin.post_update_record( + plugin={ + "name": plugin_name, + "remote": plugin_container.plugin_source, + "author": plugin_container.plugin_author, + "version": plugin_container.plugin_version + }, + old_version=plugin_container.plugin_version, + new_version="HEAD" + ) + + def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer: + """通过插件名获取插件 + """ + for plugin in self.plugins: + if plugin.plugin_name == plugin_name: + return plugin + return None + + async def emit_event(self, event: events.BaseEventModel) -> context.EventContext: + """触发事件 + """ + + ctx = context.EventContext( + host=self.api_host, + event=event + ) + + emitted_plugins: list[context.RuntimeContainer] = [] + + for plugin in self.plugins: + if plugin.enabled: + if event.__class__ in plugin.event_handlers: + + emitted_plugins.append(plugin) + + is_prevented_default_before_call = ctx.is_prevented_default() + + try: + await plugin.event_handlers[event.__class__]( + plugin.plugin_inst, + ctx + ) + 
except Exception as e: + self.ap.logger.error(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__} 时发生错误: {e}') + self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + + if not is_prevented_default_before_call and ctx.is_prevented_default(): + self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行') + + if ctx.is_prevented_postorder(): + self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行') + break + + for key in ctx.__return_value__.keys(): + if hasattr(ctx.event, key): + setattr(ctx.event, key, ctx.__return_value__[key][0]) + + self.ap.logger.debug(f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}') + + if emitted_plugins: + plugins_info: list[dict] = [ + { + 'name': plugin.plugin_name, + 'remote': plugin.plugin_source, + 'version': plugin.plugin_version, + 'author': plugin.plugin_author + } for plugin in emitted_plugins + ] + + await self.ap.ctr_mgr.usage.post_event_record( + plugins=plugins_info, + event_name=event.__class__.__name__ + ) + + return ctx diff --git a/pkg/plugin/metadata.py b/pkg/plugin/metadata.py deleted file mode 100644 index 51de742e..00000000 --- a/pkg/plugin/metadata.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -import shutil -import json -import time - -import dulwich.errors as dulwich_err - -from ..utils import updater - - -def read_metadata_file() -> dict: - # 读取 plugins/metadata.json 文件 - if not os.path.exists('plugins/metadata.json'): - return {} - with open('plugins/metadata.json', 'r') as f: - return json.load(f) - - -def write_metadata_file(metadata: dict): - if not os.path.exists('plugins'): - os.mkdir('plugins') - - with open('plugins/metadata.json', 'w') as f: - json.dump(metadata, f, indent=4, ensure_ascii=False) - - -def do_plugin_git_repo_migrate(): - # 仅在 plugins/metadata.json 不存在时执行 - if os.path.exists('plugins/metadata.json'): - return - - metadata = read_metadata_file() - - # 遍历 plugins 下所有目录,获取目录的git远程地址 - for plugin_name in os.listdir('plugins'): - plugin_path = 
os.path.join('plugins', plugin_name) - if not os.path.isdir(plugin_path): - continue - - remote_url = None - try: - remote_url = updater.get_remote_url(plugin_path) - except dulwich_err.NotGitRepository: - continue - if remote_url == "https://github.com/RockChinQ/QChatGPT" or remote_url == "https://gitee.com/RockChin/QChatGPT" \ - or remote_url == "" or remote_url is None or remote_url == "http://github.com/RockChinQ/QChatGPT" or remote_url == "http://gitee.com/RockChin/QChatGPT": - continue - - from . import host - - if plugin_name not in metadata: - metadata[plugin_name] = { - 'source': remote_url, - 'install_timestamp': int(time.time()), - 'ref': 'HEAD', - } - - write_metadata_file(metadata) - - -def set_plugin_metadata( - plugin_name: str, - source: str, - install_timestamp: int, - ref: str, -): - metadata = read_metadata_file() - metadata[plugin_name] = { - 'source': source, - 'install_timestamp': install_timestamp, - 'ref': ref, - } - write_metadata_file(metadata) - - -def remove_plugin_metadata(plugin_name: str): - metadata = read_metadata_file() - if plugin_name in metadata: - del metadata[plugin_name] - write_metadata_file(metadata) - - -def get_plugin_metadata(plugin_name: str) -> dict: - metadata = read_metadata_file() - if plugin_name in metadata: - return metadata[plugin_name] - return {} \ No newline at end of file diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py index a606612d..972eed11 100644 --- a/pkg/plugin/models.py +++ b/pkg/plugin/models.py @@ -1,299 +1,26 @@ -import logging +from __future__ import annotations -from ..plugin import host -from ..utils import context +import typing -PersonMessageReceived = "person_message_received" -"""收到私聊消息时,在判断是否应该响应前触发 - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - message_chain: mirai.models.message.MessageChain 消息链 -""" +from .context import BasePlugin as Plugin +from .events import * -GroupMessageReceived = 
"group_message_received" -"""收到群聊消息时,在判断是否应该响应前触发(所有群消息) - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - message_chain: mirai.models.message.MessageChain 消息链 -""" +def register( + name: str, + description: str, + version: str, + author +) -> typing.Callable[[typing.Type[Plugin]], typing.Type[Plugin]]: + pass -PersonNormalMessageReceived = "person_normal_message_received" -"""判断为应该处理的私聊普通消息时触发 - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - text_message: str 消息文本 - - returns (optional): - alter: str 修改后的消息文本 - reply: list 回复消息组件列表 -""" -PersonCommandSent = "person_command_sent" -"""判断为应该处理的私聊命令时触发 - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - command: str 命令 - params: list[str] 参数列表 - text_message: str 完整命令文本 - is_admin: bool 是否为管理员 - - returns (optional): - alter: str 修改后的完整命令文本 - reply: list 回复消息组件列表 -""" +def on( + event: typing.Type[BaseEventModel] +) -> typing.Callable[[typing.Callable], typing.Callable]: + pass -GroupNormalMessageReceived = "group_normal_message_received" -"""判断为应该处理的群聊普通消息时触发 - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - text_message: str 消息文本 - - returns (optional): - alter: str 修改后的消息文本 - reply: list 回复消息组件列表 -""" -GroupCommandSent = "group_command_sent" -"""判断为应该处理的群聊命令时触发 - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - command: str 命令 - params: list[str] 参数列表 - text_message: str 完整命令文本 - is_admin: bool 是否为管理员 - - returns (optional): - alter: str 修改后的完整命令文本 - reply: list 回复消息组件列表 -""" - -NormalMessageResponded = "normal_message_responded" -"""获取到对普通消息的文字响应时触发 - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - session: 
pkg.openai.session.Session 会话对象 - prefix: str 回复文字消息的前缀 - response_text: str 响应文本 - finish_reason: str 响应结束原因 - funcs_called: list[str] 此次响应中调用的函数列表 - - returns (optional): - prefix: str 修改后的回复文字消息的前缀 - reply: list 替换回复消息组件列表 -""" - -SessionFirstMessageReceived = "session_first_message_received" -"""会话被第一次交互时触发 - kwargs: - session_name: str 会话名称(_) - session: pkg.openai.session.Session 会话对象 - default_prompt: str 预设值 -""" - -SessionExplicitReset = "session_reset" -"""会话被用户手动重置时触发,此事件不支持阻止默认行为 - kwargs: - session_name: str 会话名称(_) - session: pkg.openai.session.Session 会话对象 -""" - -SessionExpired = "session_expired" -"""会话过期时触发 - kwargs: - session_name: str 会话名称(_) - session: pkg.openai.session.Session 会话对象 - session_expire_time: int 已设置的会话过期时间(秒) -""" - -KeyExceeded = "key_exceeded" -"""api-key超额时触发 - kwargs: - key_name: str 超额的api-key名称 - usage: dict 超额的api-key使用情况 - exceeded_keys: list[str] 超额的api-key列表 -""" - -KeySwitched = "key_switched" -"""api-key超额切换成功时触发,此事件不支持阻止默认行为 - kwargs: - key_name: str 切换成功的api-key名称 - key_list: list[str] api-key列表 -""" - -PromptPreProcessing = "prompt_pre_processing" -"""每回合调用接口前对prompt进行预处理时触发,此事件不支持阻止默认行为 - kwargs: - session_name: str 会话名称(_) - default_prompt: list 此session使用的情景预设内容 - prompt: list 此session现有的prompt内容 - text_message: str 用户发送的消息文本 - - returns (optional): - default_prompt: list 修改后的情景预设内容 - prompt: list 修改后的prompt内容 - text_message: str 修改后的消息文本 -""" - - -def on(*args, **kwargs): - """注册事件监听器 - """ - return Plugin.on(*args, **kwargs) - -def func(*args, **kwargs): - """注册内容函数,声明此函数为一个内容函数,在对话中将发送此函数给GPT以供其调用 - 此函数可以具有任意的参数,但必须按照[此文档](https://github.com/RockChinQ/CallingGPT/wiki/1.-Function-Format#function-format) - 所述的格式编写函数的docstring。 - 此功能仅支持在使用gpt-3.5或gpt-4系列模型时使用。 - """ - return Plugin.func(*args, **kwargs) - - -__current_registering_plugin__ = "" - - -def require_ver(ge: str, le: str="v999.9.9") -> bool: - """插件版本要求装饰器 - - Args: - ge (str): 最低版本要求 - le (str, optional): 最高版本要求 - - Returns: - bool: 是否满足要求, 
False时为无法获取版本号,True时为满足要求,报错为不满足要求 - """ - qchatgpt_version = "" - - from pkg.utils.updater import get_current_tag, compare_version_str - - try: - qchatgpt_version = get_current_tag() # 从updater模块获取版本号 - except: - return False - - if compare_version_str(qchatgpt_version, ge) < 0 or \ - (compare_version_str(qchatgpt_version, le) > 0): - raise Exception("QChatGPT 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{})".format(ge, le, qchatgpt_version)) - - return True - - -class Plugin: - """插件基类""" - - host: host.PluginHost - """插件宿主,提供插件的一些基础功能""" - - @classmethod - def on(cls, event): - """事件处理器装饰器 - - :param - event: 事件类型 - :return: - None - """ - global __current_registering_plugin__ - - def wrapper(func): - plugin_hooks = host.__plugins__[__current_registering_plugin__]["hooks"] - - if event not in plugin_hooks: - plugin_hooks[event] = [] - plugin_hooks[event].append(func) - - # print("registering hook: p='{}', e='{}', f={}".format(__current_registering_plugin__, event, func)) - - host.__plugins__[__current_registering_plugin__]["hooks"] = plugin_hooks - - return func - - return wrapper - - @classmethod - def func(cls, name: str=None): - """内容函数装饰器 - """ - global __current_registering_plugin__ - from CallingGPT.entities.namespace import get_func_schema - - def wrapper(func): - - function_schema = get_func_schema(func) - function_schema['name'] = __current_registering_plugin__ + '-' + (func.__name__ if name is None else name) - - function_schema['enabled'] = True - - host.__function_inst_map__[function_schema['name']] = function_schema['function'] - - del function_schema['function'] - - # logging.debug("registering content function: p='{}', f='{}', s={}".format(__current_registering_plugin__, func, function_schema)) - - host.__callable_functions__.append( - function_schema - ) - - return func - - return wrapper - - -def register(name: str, description: str, version: str, author: str): - """注册插件, 此函数作为装饰器使用 - - Args: - name (str): 插件名称 - description (str): 插件描述 - 
version (str): 插件版本 - author (str): 插件作者 - - Returns: - None - """ - global __current_registering_plugin__ - - __current_registering_plugin__ = name - # print("registering plugin: n='{}', d='{}', v={}, a='{}'".format(name, description, version, author)) - host.__plugins__[name] = { - "name": name, - "description": description, - "version": version, - "author": author, - "hooks": {}, - "path": host.__current_module_path__, - "enabled": True, - "instance": None, - } - - def wrapper(cls: Plugin): - cls.name = name - cls.description = description - cls.version = version - cls.author = author - cls.host = context.get_plugin_host() - cls.enabled = True - cls.path = host.__current_module_path__ - - # 存到插件列表 - host.__plugins__[name]["class"] = cls - - logging.info("插件注册完成: n='{}', d='{}', v={}, a='{}' ({})".format(name, description, version, author, cls)) - - return cls - - return wrapper +def func( + name: str=None, +) -> typing.Callable: + pass diff --git a/pkg/plugin/setting.py b/pkg/plugin/setting.py new file mode 100644 index 00000000..1ffc0009 --- /dev/null +++ b/pkg/plugin/setting.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from ..core import app +from ..config import manager as cfg_mgr +from . 
import context + + +class SettingManager: + + ap: app.Application + + settings: cfg_mgr.ConfigManager + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + self.settings = await cfg_mgr.load_json_config( + 'plugins/plugins.json', + 'templates/plugin-settings.json' + ) + + async def sync_setting( + self, + plugin_containers: list[context.RuntimeContainer], + ): + """同步设置 + """ + + not_matched_source_record = [] + + for value in self.settings.data['plugins']: + + if 'name' not in value: # 只有远程地址的,应用到pkg_path相同的插件容器上 + matched = False + + for plugin_container in plugin_containers: + if plugin_container.pkg_path == value['pkg_path']: + matched = True + + plugin_container.plugin_source = value['source'] + break + + if not matched: + not_matched_source_record.append(value) + else: # 正常的插件设置 + for plugin_container in plugin_containers: + if plugin_container.plugin_name == value['name']: + plugin_container.set_from_setting_dict(value) + + self.settings.data = { + 'plugins': [ + p.to_setting_dict() + for p in plugin_containers + ] + } + + self.settings.data['plugins'].extend(not_matched_source_record) + + await self.settings.dump_config() + + async def dump_container_setting( + self, + plugin_containers: list[context.RuntimeContainer] + ): + """保存插件容器设置 + """ + + for plugin in plugin_containers: + for ps in self.settings.data['plugins']: + if ps['name'] == plugin.plugin_name: + plugin_dict = plugin.to_setting_dict() + + for key in plugin_dict: + ps[key] = plugin_dict[key] + + break + + await self.settings.dump_config() + + async def record_installed_plugin_source( + self, + pkg_path: str, + source: str + ): + found = False + + for value in self.settings.data['plugins']: + if value['pkg_path'] == pkg_path: + value['source'] = source + found = True + break + + if not found: + + self.settings.data['plugins'].append( + { + 'pkg_path': pkg_path, + 'source': source + } + ) + await self.settings.dump_config() \ No newline at end of file diff 
--git a/pkg/plugin/settings.py b/pkg/plugin/settings.py deleted file mode 100644 index 6824906a..00000000 --- a/pkg/plugin/settings.py +++ /dev/null @@ -1,103 +0,0 @@ -import json -import os - -import logging - -from ..plugin import host - -def wrapper_dict_from_runtime_context() -> dict: - """从变量中包装settings.json的数据字典""" - settings = { - "order": [], - "functions": { - "enabled": host.__enable_content_functions__ - } - } - - for plugin_name in host.__plugins_order__: - settings["order"].append(plugin_name) - - return settings - - -def apply_settings(settings: dict): - """将settings.json数据应用到变量中""" - if "order" in settings: - host.__plugins_order__ = settings["order"] - - if "functions" in settings: - if "enabled" in settings["functions"]: - host.__enable_content_functions__ = settings["functions"]["enabled"] - # logging.debug("set content function enabled: {}".format(host.__enable_content_functions__)) - - -def dump_settings(): - """保存settings.json数据""" - logging.debug("保存plugins/settings.json数据") - - settings = wrapper_dict_from_runtime_context() - - with open("plugins/settings.json", "w", encoding="utf-8") as f: - json.dump(settings, f, indent=4, ensure_ascii=False) - - -def load_settings(): - """加载settings.json数据""" - logging.debug("加载plugins/settings.json数据") - - # 读取plugins/settings.json - settings = { - } - - # 检查文件是否存在 - if not os.path.exists("plugins/settings.json"): - # 不存在则创建 - with open("plugins/settings.json", "w", encoding="utf-8") as f: - json.dump(wrapper_dict_from_runtime_context(), f, indent=4, ensure_ascii=False) - - with open("plugins/settings.json", "r", encoding="utf-8") as f: - settings = json.load(f) - - if settings is None: - settings = { - } - - # 检查每个设置项 - if "order" not in settings: - settings["order"] = [] - - settings_modified = False - - settings_copy = settings.copy() - - # 检查settings中多余的插件项 - - # order - for plugin_name in settings_copy["order"]: - if plugin_name not in host.__plugins_order__: - settings["order"].remove(plugin_name) - 
settings_modified = True - - # 检查settings中缺少的插件项 - - # order - for plugin_name in host.__plugins_order__: - if plugin_name not in settings_copy["order"]: - settings["order"].append(plugin_name) - settings_modified = True - - if "functions" not in settings: - settings["functions"] = { - "enabled": host.__enable_content_functions__ - } - settings_modified = True - elif "enabled" not in settings["functions"]: - settings["functions"]["enabled"] = host.__enable_content_functions__ - settings_modified = True - - logging.info("已全局{}内容函数。".format("启用" if settings["functions"]["enabled"] else "禁用")) - - apply_settings(settings) - - if settings_modified: - dump_settings() diff --git a/pkg/plugin/switch.py b/pkg/plugin/switch.py deleted file mode 100644 index ccc96c8c..00000000 --- a/pkg/plugin/switch.py +++ /dev/null @@ -1,94 +0,0 @@ -# 控制插件的开关 -import json -import logging -import os - -from ..plugin import host - - -def wrapper_dict_from_plugin_list() -> dict: - """将插件列表转换为开关json""" - switch = {} - - for plugin_name in host.__plugins__: - plugin = host.__plugins__[plugin_name] - - switch[plugin_name] = { - "path": plugin["path"], - "enabled": plugin["enabled"], - } - - return switch - - -def apply_switch(switch: dict): - """将开关数据应用到插件列表中""" - # print("将开关数据应用到插件列表中") - # print(switch) - for plugin_name in switch: - host.__plugins__[plugin_name]["enabled"] = switch[plugin_name]["enabled"] - - # 查找此插件的所有内容函数 - for func in host.__callable_functions__: - if func['name'].startswith(plugin_name + '-'): - func['enabled'] = switch[plugin_name]["enabled"] - - -def dump_switch(): - """保存开关数据""" - logging.debug("保存开关数据") - # 将开关数据写入plugins/switch.json - - switch = wrapper_dict_from_plugin_list() - - with open("plugins/switch.json", "w", encoding="utf-8") as f: - json.dump(switch, f, indent=4, ensure_ascii=False) - - -def load_switch(): - """加载开关数据""" - logging.debug("加载开关数据") - # 读取plugins/switch.json - - switch = {} - - # 检查文件是否存在 - if not os.path.exists("plugins/switch.json"): - # 
不存在则创建 - with open("plugins/switch.json", "w", encoding="utf-8") as f: - json.dump(switch, f, indent=4, ensure_ascii=False) - - with open("plugins/switch.json", "r", encoding="utf-8") as f: - switch = json.load(f) - - if switch is None: - switch = {} - - switch_modified = False - - switch_copy = switch.copy() - # 检查switch中多余的和path不相符的 - for plugin_name in switch_copy: - if plugin_name not in host.__plugins__: - del switch[plugin_name] - switch_modified = True - elif switch[plugin_name]["path"] != host.__plugins__[plugin_name]["path"]: - # 删除此不相符的 - del switch[plugin_name] - switch_modified = True - - # 检查plugin中多余的 - for plugin_name in host.__plugins__: - if plugin_name not in switch: - switch[plugin_name] = { - "path": host.__plugins__[plugin_name]["path"], - "enabled": host.__plugins__[plugin_name]["enabled"], - } - switch_modified = True - - # 应用开关数据 - apply_switch(switch) - - # 如果switch有修改,保存 - if switch_modified: - dump_switch() diff --git a/pkg/openai/__init__.py b/pkg/provider/__init__.py similarity index 100% rename from pkg/openai/__init__.py rename to pkg/provider/__init__.py diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py new file mode 100644 index 00000000..44866e2e --- /dev/null +++ b/pkg/provider/entities.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import typing +import enum +import pydantic + + +class FunctionCall(pydantic.BaseModel): + name: str + + arguments: str + + +class ToolCall(pydantic.BaseModel): + id: str + + type: str + + function: FunctionCall + + +class Message(pydantic.BaseModel): + role: str # user, system, assistant, tool, command + + name: typing.Optional[str] = None + + content: typing.Optional[str] = None + + function_call: typing.Optional[FunctionCall] = None + + tool_calls: typing.Optional[list[ToolCall]] = None + + tool_call_id: typing.Optional[str] = None diff --git a/pkg/provider/requester/__init__.py b/pkg/provider/requester/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/pkg/provider/requester/api.py b/pkg/provider/requester/api.py new file mode 100644 index 00000000..88ba78cd --- /dev/null +++ b/pkg/provider/requester/api.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import abc +import typing + +from ...core import app +from ...core import entities as core_entities +from .. import entities as llm_entities + +class LLMAPIRequester(metaclass=abc.ABCMeta): + """LLM API请求器 + """ + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def request( + self, + query: core_entities.Query, + ) -> typing.AsyncGenerator[llm_entities.Message, None]: + """请求 + """ + raise NotImplementedError diff --git a/pkg/provider/requester/apis/__init__.py b/pkg/provider/requester/apis/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/provider/requester/apis/chatcmpl.py b/pkg/provider/requester/apis/chatcmpl.py new file mode 100644 index 00000000..c41b50d6 --- /dev/null +++ b/pkg/provider/requester/apis/chatcmpl.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import asyncio +import typing +import json +from typing import AsyncGenerator + +import openai +import openai.types.chat.chat_completion as chat_completion +import httpx + +from pkg.provider.entities import Message + +from .. import api, entities, errors +from ....core import entities as core_entities +from ... 
import entities as llm_entities +from ...tools import entities as tools_entities + + +class OpenAIChatCompletion(api.LLMAPIRequester): + client: openai.AsyncClient + + async def initialize(self): + self.client = openai.AsyncClient( + api_key="", + base_url=self.ap.provider_cfg.data['openai-config']['base_url'], + timeout=self.ap.provider_cfg.data['openai-config']['request-timeout'], + http_client=httpx.AsyncClient( + proxies=self.ap.proxy_mgr.get_forward_proxies() + ) + ) + + async def _req( + self, + args: dict, + ) -> chat_completion.ChatCompletion: + self.ap.logger.debug(f"req chat_completion with args {args}") + return await self.client.chat.completions.create(**args) + + async def _make_msg( + self, + chat_completion: chat_completion.ChatCompletion, + ) -> llm_entities.Message: + chatcmpl_message = chat_completion.choices[0].message.dict() + + message = llm_entities.Message(**chatcmpl_message) + + return message + + async def _closure( + self, + req_messages: list[dict], + use_model: entities.LLMModelInfo, + use_funcs: list[tools_entities.LLMFunction] = None, + ) -> llm_entities.Message: + self.client.api_key = use_model.token_mgr.get_token() + + args = self.ap.provider_cfg.data['openai-config']['chat-completions-params'].copy() + args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + + if use_model.tool_call_supported: + tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) + + if tools: + args["tools"] = tools + + # 设置此次请求中的messages + messages = req_messages + args["messages"] = messages + + # 发送请求 + resp = await self._req(args) + + # 处理请求结果 + message = await self._make_msg(resp) + + return message + + async def _request( + self, query: core_entities.Query + ) -> typing.AsyncGenerator[llm_entities.Message, None]: + """请求""" + + pending_tool_calls = [] + + req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行 + m.dict(exclude_none=True) for m in query.prompt.messages + ] + [m.dict(exclude_none=True) 
for m in query.messages] + + # req_messages.append({"role": "user", "content": str(query.message_chain)}) + + msg = await self._closure(req_messages, query.use_model, query.use_funcs) + + yield msg + + pending_tool_calls = msg.tool_calls + + req_messages.append(msg.dict(exclude_none=True)) + + while pending_tool_calls: + for tool_call in pending_tool_calls: + func = tool_call.function + + parameters = json.loads(func.arguments) + + func_ret = await self.ap.tool_mgr.execute_func_call( + query, func.name, parameters + ) + + msg = llm_entities.Message( + role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id + ) + + yield msg + + req_messages.append(msg.dict(exclude_none=True)) + + # 处理完所有调用,继续请求 + msg = await self._closure(req_messages, query.use_model, query.use_funcs) + + yield msg + + pending_tool_calls = msg.tool_calls + + req_messages.append(msg.dict(exclude_none=True)) + + async def request(self, query: core_entities.Query) -> AsyncGenerator[Message, None]: + try: + async for msg in self._request(query): + yield msg + except asyncio.TimeoutError: + raise errors.RequesterError('请求超时') + except openai.BadRequestError as e: + raise errors.RequesterError(f'请求错误: {e.message}') + except openai.AuthenticationError as e: + raise errors.RequesterError(f'无效的 api-key: {e.message}') + except openai.NotFoundError as e: + raise errors.RequesterError(f'请求路径错误: {e.message}') + except openai.RateLimitError as e: + raise errors.RequesterError(f'请求过于频繁: {e.message}') + except openai.APIError as e: + raise errors.RequesterError(f'请求错误: {e.message}') diff --git a/pkg/provider/requester/entities.py b/pkg/provider/requester/entities.py new file mode 100644 index 00000000..d4c51d6f --- /dev/null +++ b/pkg/provider/requester/entities.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import typing + +import pydantic + +from . import api +from . 
class LLMModelInfo(pydantic.BaseModel):
    """Metadata describing one usable LLM model."""

    # public name used in config lookups (may carry a provider prefix)
    name: str

    # upstream model identifier sent to the API; falls back to `name` when None
    model_name: typing.Optional[str] = None

    # API-key pool used when requesting this model
    token_mgr: token.TokenManager

    # requester implementation that talks to this model's API
    requester: api.LLMAPIRequester

    # tokenizer used to count this model's prompt tokens
    tokenizer: 'tokenizer.LLMTokenizer'

    # whether OpenAI-style tool (function) calls are supported
    tool_call_supported: typing.Optional[bool] = False

    # context-window budget for this model
    max_tokens: typing.Optional[int] = 2048

    class Config:
        # token_mgr / requester / tokenizer are plain project classes,
        # not pydantic models
        arbitrary_types_allowed = True


class RequesterError(Exception):
    """Base class for all Requester errors."""

    def __init__(self, message: str):
        # every requester failure carries a uniform prefix
        super().__init__("模型请求失败: " + message)
class ModelManager:
    """Registry of every known LLM model and its per-model settings."""

    ap: app.Application

    model_list: list[entities.LLMModelInfo]

    def __init__(self, ap: app.Application):
        self.ap = ap
        self.model_list = []

    async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
        """Return the registered model called `name`.

        Raises:
            ValueError: when no such model is registered.
        """
        for model in self.model_list:
            if model.name == name:
                return model
        raise ValueError(f"不支持模型: {name} , 请检查配置文件")

    async def initialize(self):
        """Populate model_list with the built-in OpenAI and OneAPI entries."""
        openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
        await openai_chat_completion.initialize()

        # BUG FIX: TokenManager's first parameter is the provider *name*
        # (see token.TokenManager.__init__(provider: str, tokens)); the
        # Application object was being passed here instead.
        openai_token_mgr = token.TokenManager(
            "openai",
            list(self.ap.provider_cfg.data['openai-config']['api-keys'])
        )

        tiktoken_tokenizer = tiktoken.Tiktoken(self.ap)

        # (name, max_tokens) for official OpenAI models; all support tool calls.
        openai_models: list[tuple[str, int]] = [
            ("gpt-3.5-turbo", 4096),
            ("gpt-3.5-turbo-1106", 16385),
            ("gpt-3.5-turbo-16k", 16385),
            ("gpt-3.5-turbo-0613", 4096),
            ("gpt-3.5-turbo-16k-0613", 16385),
            ("gpt-3.5-turbo-0301", 4096),
            ("gpt-4-0125-preview", 128000),
            ("gpt-4-turbo-preview", 128000),
            ("gpt-4-1106-preview", 128000),
            ("gpt-4-vision-preview", 128000),
            ("gpt-4", 8192),
            ("gpt-4-0613", 8192),
            ("gpt-4-32k", 32768),
            ("gpt-4-32k-0613", 32768),
        ]

        self.model_list.extend(
            entities.LLMModelInfo(
                name=name,
                token_mgr=openai_token_mgr,
                requester=openai_chat_completion,
                tool_call_supported=True,
                tokenizer=tiktoken_tokenizer,
                max_tokens=max_tokens,
            )
            for name, max_tokens in openai_models
        )

        # (upstream model name, max_tokens) for models proxied through OneAPI;
        # none of these support OpenAI-style tool calls.
        one_api_models: list[tuple[str, int]] = [
            ("SparkDesk", 8192),
            ("chatglm_pro", 128000),
            ("chatglm_std", 128000),
            ("chatglm_lite", 128000),
            ("qwen-v1", 6000),
            ("qwen-plus-v1", 30000),
            ("ERNIE-Bot", 2000),
            ("ERNIE-Bot-turbo", 7000),
            ("gemini-pro", 30720),
        ]

        self.model_list.extend(
            entities.LLMModelInfo(
                name=f"OneAPI/{model_name}",
                model_name=model_name,
                token_mgr=openai_token_mgr,
                requester=openai_chat_completion,
                tool_call_supported=False,
                tokenizer=tiktoken_tokenizer,
                max_tokens=max_tokens,
            )
            for model_name, max_tokens in one_api_models
        )
class TokenManager():
    """Round-robin pool of API keys for one provider."""

    # label of the provider this pool belongs to
    provider: str

    # the raw API keys
    tokens: list[str]

    # index of the key currently handed out
    using_token_index: typing.Optional[int] = 0

    def __init__(self, provider: str, tokens: list[str]):
        self.provider = provider
        self.tokens = tokens
        self.using_token_index = 0

    def get_token(self) -> str:
        """Return the currently selected API key."""
        return self.tokens[self.using_token_index]

    def next_token(self):
        """Advance to the next key, wrapping around at the end of the pool."""
        self.using_token_index = (self.using_token_index + 1) % len(self.tokens)


class LLMTokenizer(metaclass=abc.ABCMeta):
    """Abstract token counter for LLM prompts."""

    ap: app.Application

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def initialize(self):
        """Hook for async setup; the default implementation does nothing."""
        pass

    @abc.abstractmethod
    async def count_token(
        self,
        messages: list[llm_entities.Message],
        model: entities.LLMModelInfo
    ) -> int:
        """Return how many prompt tokens `messages` cost on `model`."""
        pass
class Tiktoken(tokenizer.LLMTokenizer):
    """Token counter backed by OpenAI's tiktoken library."""

    async def count_token(
        self,
        messages: list[llm_entities.Message],
        model: entities.LLMModelInfo
    ) -> int:
        """Count the prompt tokens of `messages` with the model's encoding."""
        try:
            encoding = tiktoken.encoding_for_model(model.name)
        except KeyError:
            # unknown model name: fall back to the common cl100k_base encoding
            encoding = tiktoken.get_encoding("cl100k_base")

        total = 0
        for message in messages:
            total += len(encoding.encode(message.role))
            total += len(encoding.encode(message.content if message.content is not None else ''))
        # every reply is primed with <|start|>assistant<|message|>
        total += 3
        return total
class Prompt(pydantic.BaseModel):
    """A named system prompt: the preset messages prepended to a conversation."""

    # identifier used to look the prompt up (e.g. 'default')
    name: str

    # preset messages, in the order they are sent to the model
    messages: list[entities.Message]
class ScenarioPromptLoader(loader.PromptLoader):
    """Loads full-scenario prompts from JSON files under data/scenarios."""

    async def load(self):
        """Load every scenario file into self.prompts.

        Each file holds a "prompt" list of message dicts; entries without
        an explicit "role" default to 'system'.
        """
        for file in os.listdir("data/scenarios"):
            # ROBUSTNESS FIX: only parse *.json files, so stray files
            # (e.g. .DS_Store, editor backups) don't crash json.loads
            if not file.endswith(".json"):
                continue

            with open("data/scenarios/{}".format(file), "r", encoding="utf-8") as f:
                file_str = f.read()

            # NOTE(review): the prompt name is everything before the first
            # dot, so "a.b.json" becomes "a" — confirm this is intended
            file_name = file.split(".")[0]
            file_json = json.loads(file_str)

            messages = []
            for msg in file_json["prompt"]:
                messages.append(
                    llm_entities.Message(
                        role=msg.get("role", "system"),
                        content=msg['content'],
                    )
                )

            self.prompts.append(
                entities.Prompt(
                    name=file_name,
                    messages=messages,
                )
            )
class SingleSystemPromptLoader(loader.PromptLoader):
    """Prompt loader for single-message system prompts.

    Sources: the 'prompt' mapping in provider config, plus the plain-text
    files under data/prompts.
    """

    def _wrap(self, name: str, content: str) -> entities.Prompt:
        # Wrap one system-role message into a named Prompt.
        return entities.Prompt(
            name=name,
            messages=[
                llm_entities.Message(
                    role='system',
                    content=content
                )
            ]
        )

    async def load(self):
        """Load prompts from config and from the data/prompts directory."""

        for name, cnt in self.ap.provider_cfg.data['prompt'].items():
            self.prompts.append(self._wrap(name, cnt))

        for file in os.listdir("data/prompts"):
            with open("data/prompts/{}".format(file), "r", encoding="utf-8") as f:
                file_str = f.read()

            # NOTE(review): name is everything before the first dot, so a
            # dotted filename like "a.b.txt" becomes "a" — confirm intended
            self.prompts.append(self._wrap(file.split(".")[0], file_str))
class PromptManager:
    """Owns the configured PromptLoader and exposes prompt lookup."""

    ap: app.Application

    loader_inst: loader.PromptLoader

    default_prompt: str = 'default'

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def initialize(self):
        """Instantiate the loader selected by the 'prompt-mode' config key."""
        loader_map = {
            "normal": single.SingleSystemPromptLoader,
            "full_scenario": scenario.ScenarioPromptLoader
        }

        loader_cls = loader_map[self.ap.provider_cfg.data['prompt-mode']]

        self.loader_inst: loader.PromptLoader = loader_cls(self.ap)

        await self.loader_inst.initialize()
        await self.loader_inst.load()

    def get_all_prompts(self) -> list[loader.entities.Prompt]:
        """Return every prompt the loader produced."""
        return self.loader_inst.get_prompts()

    async def get_prompt(self, name: str) -> loader.entities.Prompt:
        """Return the prompt named exactly `name`.

        NOTE(review): returns None when nothing matches — callers must
        handle that; confirm whether raising would be more appropriate.
        """
        return next(
            (prompt for prompt in self.get_all_prompts() if prompt.name == name),
            None,
        )

    async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt:
        """Return the first prompt whose name starts with `prefix` (or None)."""
        return next(
            (prompt for prompt in self.get_all_prompts() if prompt.name.startswith(prefix)),
            None,
        )
class ToolManager:
    """LLM工具管理器 — manages the content functions plugins expose to the LLM."""

    ap: app.Application

    def __init__(self, ap: app.Application):
        self.ap = ap
        self.all_functions = []

    async def initialize(self):
        pass

    async def get_function(self, name: str) -> entities.LLMFunction:
        """Return the function registered under `name`, or None if absent."""
        for function in await self.get_all_functions():
            if function.name == name:
                return function
        return None

    async def get_all_functions(self) -> list[entities.LLMFunction]:
        """Collect the content functions of every loaded plugin."""
        all_functions: list[entities.LLMFunction] = []

        for plugin in self.ap.plugin_mgr.plugins:
            all_functions.extend(plugin.content_functions)

        return all_functions

    async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list[dict]:
        """Build the OpenAI `tools` array for the enabled functions.

        FIX: annotations corrected — this takes a *list* of functions and
        returns a list of tool schemas (the old `-> str` was wrong).
        """
        tools = []

        for function in use_funcs:
            if function.enable:
                function_schema = {
                    "type": "function",
                    "function": {
                        "name": function.name,
                        "description": function.description,
                        "parameters": function.parameters
                    }
                }
                tools.append(function_schema)

        return tools

    async def execute_func_call(
        self,
        query: core_entities.Query,
        name: str,
        parameters: dict
    ) -> typing.Any:
        """Execute one tool call and report usage for the owning plugin.

        Returns the function result, or an error string if it raised.
        """
        # FIX: pre-bind so the finally block cannot hit an unbound name
        # when get_function itself raises before assignment.
        function = None

        try:
            function = await self.get_function(name)
            if function is None:
                return None

            # pass the current query as the implicit first argument
            parameters = {
                "query": query,
                **parameters
            }

            return await function.func(**parameters)
        except Exception as e:
            self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
            traceback.print_exc()
            return f'error occurred when executing function {name}: {e}'
        finally:
            # best-effort usage reporting for the plugin owning this function
            if function is not None:
                plugin = None

                for p in self.ap.plugin_mgr.plugins:
                    if function in p.content_functions:
                        plugin = p
                        break

                if plugin is not None:
                    await self.ap.ctr_mgr.usage.post_function_record(
                        plugin={
                            'name': plugin.plugin_name,
                            'remote': plugin.plugin_source,
                            'version': plugin.plugin_version,
                            'author': plugin.plugin_author
                        },
                        function_name=function.name,
                        function_description=function.description,
                    )
time -import base64 -import typing - -from mirai.models.message import MessageComponent, MessageChain, Image -from mirai.models.message import ForwardMessageNode -from mirai.models.base import MiraiBaseModel - -from ..utils import text2img -from ..utils import context - - -class ForwardMessageDiaplay(MiraiBaseModel): - title: str = "群聊的聊天记录" - brief: str = "[聊天记录]" - source: str = "聊天记录" - preview: typing.List[str] = [] - summary: str = "查看x条转发消息" - - -class Forward(MessageComponent): - """合并转发。""" - type: str = "Forward" - """消息组件类型。""" - display: ForwardMessageDiaplay - """显示信息""" - node_list: typing.List[ForwardMessageNode] - """转发消息节点列表。""" - def __init__(self, *args, **kwargs): - if len(args) == 1: - self.node_list = args[0] - super().__init__(**kwargs) - super().__init__(*args, **kwargs) - - def __str__(self): - return '[聊天记录]' - - -def text_to_image(text: str) -> MessageComponent: - """将文本转换成图片""" - # 检查temp文件夹是否存在 - if not os.path.exists('temp'): - os.mkdir('temp') - img_path = text2img.text_to_image(text_str=text, save_as='temp/{}.png'.format(int(time.time()))) - - compressed_path, size = text2img.compress_image(img_path, outfile="temp/{}_compressed.png".format(int(time.time()))) - # 读取图片,转换成base64 - with open(compressed_path, 'rb') as f: - img = f.read() - - b64 = base64.b64encode(img) - - # 删除图片 - os.remove(img_path) - - # 判断compressed_path是否存在 - if os.path.exists(compressed_path): - os.remove(compressed_path) - # 返回图片 - return Image(base64=b64.decode('utf-8')) - - -def check_text(text: str) -> list: - """检查文本是否为长消息,并转换成该使用的消息链组件""" - - config = context.get_config_manager().data - - if len(text) > config['blob_message_threshold']: - - # logging.info("长消息: {}".format(text)) - if config['blob_message_strategy'] == 'image': - # 转换成图片 - return [text_to_image(text)] - elif config['blob_message_strategy'] == 'forward': - - # 包装转发消息 - display = ForwardMessageDiaplay( - title='群聊的聊天记录', - brief='[聊天记录]', - source='聊天记录', - preview=["bot: "+text], - 
summary="查看1条转发消息" - ) - - node = ForwardMessageNode( - sender_id=config['mirai_http_api_config']['qq'], - sender_name='bot', - message_chain=MessageChain([text]) - ) - - forward = Forward( - display=display, - node_list=[node] - ) - - return [forward] - else: - return [text] \ No newline at end of file diff --git a/pkg/qqbot/cmds/aamgr.py b/pkg/qqbot/cmds/aamgr.py deleted file mode 100644 index 6bc5c2de..00000000 --- a/pkg/qqbot/cmds/aamgr.py +++ /dev/null @@ -1,333 +0,0 @@ -import logging -import copy -import pkgutil -import traceback -import json - -import tips as tips_custom - - -__command_list__ = {} -"""命令树 - -结构: -{ - 'cmd1': { - 'description': 'cmd1 description', - 'usage': 'cmd1 usage', - 'aliases': ['cmd1 alias1', 'cmd1 alias2'], - 'privilege': 0, - 'parent': None, - 'cls': , - 'sub': [ - 'cmd1-1' - ] - }, - 'cmd1.cmd1-1: { - 'description': 'cmd1-1 description', - 'usage': 'cmd1-1 usage', - 'aliases': ['cmd1-1 alias1', 'cmd1-1 alias2'], - 'privilege': 0, - 'parent': 'cmd1', - 'cls': , - 'sub': [] - }, - 'cmd2': { - 'description': 'cmd2 description', - 'usage': 'cmd2 usage', - 'aliases': ['cmd2 alias1', 'cmd2 alias2'], - 'privilege': 0, - 'parent': None, - 'cls': , - 'sub': [ - 'cmd2-1' - ] - }, - 'cmd2.cmd2-1': { - 'description': 'cmd2-1 description', - 'usage': 'cmd2-1 usage', - 'aliases': ['cmd2-1 alias1', 'cmd2-1 alias2'], - 'privilege': 0, - 'parent': 'cmd2', - 'cls': , - 'sub': [ - 'cmd2-1-1' - ] - }, - 'cmd2.cmd2-1.cmd2-1-1': { - 'description': 'cmd2-1-1 description', - 'usage': 'cmd2-1-1 usage', - 'aliases': ['cmd2-1-1 alias1', 'cmd2-1-1 alias2'], - 'privilege': 0, - 'parent': 'cmd2.cmd2-1', - 'cls': , - 'sub': [] - }, -} -""" - -__tree_index__: dict[str, list] = {} -"""命令树索引 - -结构: -{ - 'pkg.qqbot.cmds.cmd1.CommandCmd1': 'cmd1', # 顶级命令 - 'pkg.qqbot.cmds.cmd1.CommandCmd1_1': 'cmd1.cmd1-1', # 类名: 节点路径 - 'pkg.qqbot.cmds.cmd2.CommandCmd2': 'cmd2', - 'pkg.qqbot.cmds.cmd2.CommandCmd2_1': 'cmd2.cmd2-1', - 'pkg.qqbot.cmds.cmd2.CommandCmd2_1_1': 
'cmd2.cmd2-1.cmd2-1-1', -} -""" - - -class Context: - """命令执行上下文""" - command: str - """顶级命令文本""" - - crt_command: str - """当前子命令文本""" - - params: list - """完整参数列表""" - - crt_params: list - """当前子命令参数列表""" - - session_name: str - """会话名""" - - text_message: str - """命令完整文本""" - - launcher_type: str - """命令发起者类型""" - - launcher_id: int - """命令发起者ID""" - - sender_id: int - """命令发送者ID""" - - is_admin: bool - """[过时]命令发送者是否为管理员""" - - privilege: int - """命令发送者权限等级""" - - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - - -class AbstractCommandNode: - """命令抽象类""" - - parent: type - """父命令类""" - - name: str - """命令名""" - - description: str - """命令描述""" - - usage: str - """命令用法""" - - aliases: list[str] - """命令别名""" - - privilege: int - """命令权限等级, 权限大于等于此值的用户才能执行命令""" - - @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: - """命令处理函数 - - :param ctx: 命令执行上下文 - - :return: (是否执行, 回复列表(若执行)) - - 若未执行,将自动以下一个参数查找并执行子命令 - """ - raise NotImplementedError - - @classmethod - def help(cls) -> str: - """获取命令帮助信息""" - return '命令: {}\n描述: {}\n用法: \n{}\n别名: {}\n权限: {}'.format( - cls.name, - cls.description, - cls.usage, - ', '.join(cls.aliases), - cls.privilege - ) - - @staticmethod - def register( - parent: type = None, - name: str = None, - description: str = None, - usage: str = None, - aliases: list[str] = None, - privilege: int = 0 - ): - """注册命令 - - :param cls: 命令类 - :param name: 命令名 - :param parent: 父命令类 - """ - global __command_list__, __tree_index__ - - def wrapper(cls): - cls.name = name - cls.parent = parent - cls.description = description - cls.usage = usage - cls.aliases = aliases - cls.privilege = privilege - - logging.debug("cls: {}, name: {}, parent: {}".format(cls, name, parent)) - - if parent is None: - # 顶级命令注册 - __command_list__[name] = { - 'description': cls.description, - 'usage': cls.usage, - 'aliases': cls.aliases, - 'privilege': cls.privilege, - 'parent': None, - 'cls': cls, - 'sub': [] - } - # 更新索引 - __tree_index__[cls.__module__ 
+ '.' + cls.__name__] = name - else: - # 获取父节点名称 - path = __tree_index__[parent.__module__ + '.' + parent.__name__] - - parent_node = __command_list__[path] - # 链接父子命令 - __command_list__[path]['sub'].append(name) - # 注册子命令 - __command_list__[path + '.' + name] = { - 'description': cls.description, - 'usage': cls.usage, - 'aliases': cls.aliases, - 'privilege': cls.privilege, - 'parent': path, - 'cls': cls, - 'sub': [] - } - # 更新索引 - __tree_index__[cls.__module__ + '.' + cls.__name__] = path + '.' + name - - return cls - - return wrapper - - -class CommandPrivilegeError(Exception): - """命令权限不足或不存在异常""" - pass - - -# 传入Context对象,广搜命令树,返回执行结果 -# 若命令被处理,返回reply列表 -# 若命令未被处理,继续执行下一级命令 -# 若命令不存在,报异常 -def execute(context: Context) -> list: - """执行命令 - - :param ctx: 命令执行上下文 - - :return: 回复列表 - """ - global __command_list__ - - # 拷贝ctx - ctx: Context = copy.deepcopy(context) - - # 从树取出顶级命令 - node = __command_list__ - - path = ctx.command - - while True: - try: - node = __command_list__[path] - logging.debug('执行命令: {}'.format(path)) - - # 检查权限 - if ctx.privilege < node['privilege']: - raise CommandPrivilegeError(tips_custom.command_admin_message+"{}".format(path)) - - # 执行 - execed, reply = node['cls'].process(ctx) - if execed: - return reply - else: - # 删除crt_params第一个参数 - ctx.crt_command = ctx.crt_params.pop(0) - # 下一个path - path = path + '.' 
+ ctx.crt_command - except KeyError: - traceback.print_exc() - raise CommandPrivilegeError(tips_custom.command_err_message+"{}".format(path)) - - -def register_all(): - """启动时调用此函数注册所有命令 - - 递归处理pkg.qqbot.cmds包下及其子包下所有模块的所有继承于AbstractCommand的类 - """ - # 模块:遍历其中的继承于AbstractCommand的类,进行注册 - # 包:递归处理包下的模块 - # 排除__开头的属性 - global __command_list__, __tree_index__ - - import pkg.qqbot.cmds - - def walk(module, prefix, path_prefix): - # 排除不处于pkg.qqbot.cmds中的包 - if not module.__name__.startswith('pkg.qqbot.cmds'): - return - - logging.debug('walk: {}, path: {}'.format(module.__name__, module.__path__)) - for item in pkgutil.iter_modules(module.__path__): - if item.name.startswith('__'): - continue - - if item.ispkg: - walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.', path_prefix + item.name + '/') - else: - m = __import__(module.__name__ + '.' + item.name, fromlist=['']) - # for name, cls in inspect.getmembers(m, inspect.isclass): - # # 检查是否为命令类 - # if cls.__module__ == m.__name__ and issubclass(cls, AbstractCommandNode) and cls != AbstractCommandNode: - # cls.register(cls, cls.name, cls.parent) - - walk(pkg.qqbot.cmds, '', '') - logging.debug(__command_list__) - - -def apply_privileges(): - """读取cmdpriv.json并应用命令权限""" - # 读取内容 - json_str = "" - with open('cmdpriv.json', 'r', encoding="utf-8") as f: - json_str = f.read() - - data = json.loads(json_str) - for path, priv in data.items(): - if path == 'comment': - continue - - if path not in __command_list__: - continue - - if __command_list__[path]['privilege'] != priv: - logging.debug('应用权限: {} -> {}(default: {})'.format(path, priv, __command_list__[path]['privilege'])) - - __command_list__[path]['privilege'] = priv diff --git a/pkg/qqbot/cmds/funcs/draw.py b/pkg/qqbot/cmds/funcs/draw.py deleted file mode 100644 index 5ce25ad5..00000000 --- a/pkg/qqbot/cmds/funcs/draw.py +++ /dev/null @@ -1,37 +0,0 @@ -import logging - -import mirai - -from .. 
import aamgr -from ....utils import context - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="draw", - description="使用DALL·E生成图片", - usage="!draw <图片提示语>", - aliases=[], - privilege=1 -) -class DrawCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - - reply = [] - - if len(ctx.params) == 0: - reply = ["[bot]err: 未提供图片描述文字"] - else: - session = pkg.openai.session.get_session(ctx.session_name) - - res = session.draw_image(" ".join(ctx.params)) - - logging.debug("draw_image result:{}".format(res)) - reply = [mirai.Image(url=res.data[0].url)] - config = context.get_config_manager().data - if config['include_image_description']: - reply.append(" ".join(ctx.params)) - - return True, reply diff --git a/pkg/qqbot/cmds/funcs/func.py b/pkg/qqbot/cmds/funcs/func.py deleted file mode 100644 index 61675931..00000000 --- a/pkg/qqbot/cmds/funcs/func.py +++ /dev/null @@ -1,32 +0,0 @@ -import logging -import json - -from .. import aamgr - -@aamgr.AbstractCommandNode.register( - parent=None, - name="func", - description="管理内容函数", - usage="!func", - aliases=[], - privilege=1 -) -class FuncCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - from pkg.plugin.models import host - - reply = [] - - reply_str = "当前已加载的内容函数:\n\n" - - logging.debug("host.__callable_functions__: {}".format(json.dumps(host.__callable_functions__, indent=4))) - - index = 1 - for func in host.__callable_functions__: - reply_str += "{}. 
{}{}:\n{}\n\n".format(index, ("(已禁用) " if not func['enabled'] else ""), func['name'], func['description']) - index += 1 - - reply = [reply_str] - - return True, reply diff --git a/pkg/qqbot/cmds/plugin/plugin.py b/pkg/qqbot/cmds/plugin/plugin.py deleted file mode 100644 index 5e699bba..00000000 --- a/pkg/qqbot/cmds/plugin/plugin.py +++ /dev/null @@ -1,198 +0,0 @@ -from ....plugin import host as plugin_host -from ....utils import updater -from .. import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="plugin", - description="插件管理", - usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>", - aliases=[], - privilege=1 -) -class PluginCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - reply = [] - plugin_list = plugin_host.__plugins__ - if len(ctx.params) == 0: - # 列出所有插件 - - reply_str = "[bot]所有插件({}):\n".format(len(plugin_host.__plugins__)) - idx = 0 - for key in plugin_host.iter_plugins_name(): - plugin = plugin_list[key] - reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\ - .format((idx+1), plugin['name'], - "[已禁用]" if not plugin['enabled'] else "", - plugin['description'], - plugin['version'], plugin['author']) - - if updater.is_repo("/".join(plugin['path'].split('/')[:-1])): - remote_url = updater.get_remote_url("/".join(plugin['path'].split('/')[:-1])) - if remote_url != "https://github.com/RockChinQ/QChatGPT" and remote_url != "https://gitee.com/RockChin/QChatGPT": - reply_str += "源码: "+remote_url+"\n" - - idx += 1 - - reply = [reply_str] - return True, reply - else: - return False, [] - - -@aamgr.AbstractCommandNode.register( - parent=PluginCommand, - name="get", - description="安装插件", - usage="!plugin get <插件仓库地址>", - aliases=[], - privilege=2 -) -class PluginGetCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import threading - import logging - import 
pkg.utils.context - - if len(ctx.crt_params) == 0: - reply = ["[bot]err: 请提供插件仓库地址"] - return True, reply - - reply = [] - def closure(): - try: - plugin_host.install_plugin(ctx.crt_params[0]) - pkg.utils.context.get_qqbot_manager().notify_admin("插件安装成功,请发送 !reload 命令重载插件") - except Exception as e: - logging.error("插件安装失败:{}".format(e)) - pkg.utils.context.get_qqbot_manager().notify_admin("插件安装失败:{}".format(e)) - - threading.Thread(target=closure, args=()).start() - reply = ["[bot]正在安装插件..."] - return True, reply - - -@aamgr.AbstractCommandNode.register( - parent=PluginCommand, - name="update", - description="更新指定插件或全部插件", - usage="!plugin update", - aliases=[], - privilege=2 -) -class PluginUpdateCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import threading - import logging - plugin_list = plugin_host.__plugins__ - - reply = [] - - if len(ctx.crt_params) > 0: - def closure(): - try: - import pkg.utils.context - - updated = [] - - if ctx.crt_params[0] == 'all': - for key in plugin_list: - plugin_host.update_plugin(key) - updated.append(key) - else: - plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(ctx.crt_params[0]) - - if plugin_path_name is not None: - plugin_host.update_plugin(ctx.crt_params[0]) - updated.append(ctx.crt_params[0]) - else: - raise Exception("未找到插件: {}".format(ctx.crt_params[0])) - - pkg.utils.context.get_qqbot_manager().notify_admin("已更新插件: {}, 请发送 !reload 重载插件".format(", ".join(updated))) - except Exception as e: - logging.error("插件更新失败:{}".format(e)) - pkg.utils.context.get_qqbot_manager().notify_admin("插件更新失败:{} 请使用 !plugin 命令确认插件名称或尝试手动更新插件".format(e)) - - reply = ["[bot]正在更新插件,请勿重复发起..."] - threading.Thread(target=closure).start() - else: - reply = ["[bot]请指定要更新的插件, 或使用 !plugin update all 更新所有插件"] - return True, reply - - -@aamgr.AbstractCommandNode.register( - parent=PluginCommand, - name="del", - description="删除插件", - usage="!plugin del <插件名>", - 
aliases=[], - privilege=2 -) -class PluginDelCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - plugin_list = plugin_host.__plugins__ - reply = [] - - if len(ctx.crt_params) < 1: - reply = ["[bot]err: 未指定插件名"] - else: - plugin_name = ctx.crt_params[0] - if plugin_name in plugin_list: - unin_path = plugin_host.uninstall_plugin(plugin_name) - reply = ["[bot]已删除插件: {} ({}), 请发送 !reload 重载插件".format(plugin_name, unin_path)] - else: - reply = ["[bot]err:未找到插件: {}, 请使用!plugin命令查看插件列表".format(plugin_name)] - - return True, reply - - -@aamgr.AbstractCommandNode.register( - parent=PluginCommand, - name="on", - description="启用指定插件", - usage="!plugin on <插件名>", - aliases=[], - privilege=2 -) -@aamgr.AbstractCommandNode.register( - parent=PluginCommand, - name="off", - description="禁用指定插件", - usage="!plugin off <插件名>", - aliases=[], - privilege=2 -) -class PluginOnOffCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.plugin.switch as plugin_switch - - plugin_list = plugin_host.__plugins__ - reply = [] - - print(ctx.params) - new_status = ctx.params[0] == 'on' - - if len(ctx.crt_params) < 1: - reply = ["[bot]err: 未指定插件名"] - else: - plugin_name = ctx.crt_params[0] - if plugin_name in plugin_list: - plugin_list[plugin_name]['enabled'] = new_status - - for func in plugin_host.__callable_functions__: - if func['name'].startswith(plugin_name+"-"): - func['enabled'] = new_status - - plugin_switch.dump_switch() - reply = ["[bot]已{}插件: {}".format("启用" if new_status else "禁用", plugin_name)] - else: - reply = ["[bot]err:未找到插件: {}, 请使用!plugin命令查看插件列表".format(plugin_name)] - - return True, reply - diff --git a/pkg/qqbot/cmds/session/default.py b/pkg/qqbot/cmds/session/default.py deleted file mode 100644 index 1a1ff756..00000000 --- a/pkg/qqbot/cmds/session/default.py +++ /dev/null @@ -1,71 +0,0 @@ -from .. 
import aamgr -from ....utils import context - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="default", - description="操作情景预设", - usage="!default\n!default set [指定情景预设为默认]", - aliases=[], - privilege=1 -) -class DefaultCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - params = ctx.params - reply = [] - - config = context.get_config_manager().data - - if len(params) == 0: - # 输出目前所有情景预设 - import pkg.openai.dprompt as dprompt - reply_str = "[bot]当前所有情景预设({}模式):\n\n".format(config['preset_mode']) - - prompts = dprompt.mode_inst().list() - - for key in prompts: - pro = prompts[key] - reply_str += "名称: {}".format(key) - - for r in pro: - reply_str += "\n - [{}]: {}".format(r['role'], r['content']) - - reply_str += "\n\n" - - reply_str += "\n当前默认情景预设:{}\n".format(dprompt.mode_inst().get_using_name()) - reply_str += "请使用 !default set <情景预设名称> 来设置默认情景预设" - reply = [reply_str] - else: - return False, [] - - return True, reply - - -@aamgr.AbstractCommandNode.register( - parent=DefaultCommand, - name="set", - description="设置默认情景预设", - usage="!default set <情景预设名称>", - aliases=[], - privilege=2 -) -class DefaultSetCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - reply = [] - - if len(ctx.crt_params) == 0: - reply = ["[bot]err: 请指定情景预设名称"] - elif len(ctx.crt_params) > 0: - import pkg.openai.dprompt as dprompt - try: - full_name = dprompt.mode_inst().set_using_name(ctx.crt_params[0]) - reply = ["[bot]已设置默认情景预设为:{}".format(full_name)] - except Exception as e: - reply = ["[bot]err: {}".format(e)] - - return True, reply diff --git a/pkg/qqbot/cmds/session/del.py b/pkg/qqbot/cmds/session/del.py deleted file mode 100644 index 45fdc4ee..00000000 --- a/pkg/qqbot/cmds/session/del.py +++ /dev/null @@ -1,51 +0,0 @@ -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="del", - description="删除当前会话的历史记录", - usage="!del <序号>\n!del all", - aliases=[], - privilege=1 -) -class DelCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - params = ctx.params - reply = [] - if len(params) == 0: - reply = ["[bot]参数不足, 格式: !del <序号>\n可以通过!list查看序号"] - else: - if params[0] == 'all': - return False, [] - elif params[0].isdigit(): - if pkg.openai.session.get_session(session_name).delete_history(int(params[0])): - reply = ["[bot]已删除历史会话 #{}".format(params[0])] - else: - reply = ["[bot]没有历史会话 #{}".format(params[0])] - else: - reply = ["[bot]参数错误, 格式: !del <序号>\n可以通过!list查看序号"] - - return True, reply - - -@aamgr.AbstractCommandNode.register( - parent=DelCommand, - name="all", - description="删除当前会话的全部历史记录", - usage="!del all", - aliases=[], - privilege=1 -) -class DelAllCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - reply = [] - pkg.openai.session.get_session(session_name).delete_all_history() - reply = ["[bot]已删除所有历史会话"] - return True, reply diff --git a/pkg/qqbot/cmds/session/delhst.py b/pkg/qqbot/cmds/session/delhst.py deleted file mode 100644 index 31791492..00000000 --- a/pkg/qqbot/cmds/session/delhst.py +++ /dev/null @@ -1,50 +0,0 @@ -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="delhst", - description="删除指定会话的所有历史记录", - usage="!delhst <会话名称>\n!delhst all", - aliases=[], - privilege=2 -) -class DelHistoryCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - import pkg.utils.context - params = ctx.params - reply = [] - if len(params) == 0: - reply = [ - "[bot]err:请输入要删除的会话名: group_<群号> 或者 person_, 或使用 !delhst all 删除所有会话的历史记录"] - else: - if params[0] == 'all': - return False, [] - else: - if pkg.utils.context.get_database_manager().delete_all_history(params[0]): - reply = ["[bot]已删除会话 {} 的所有历史记录".format(params[0])] - else: - reply = ["[bot]未找到会话 {} 的历史记录".format(params[0])] - - return True, reply - - -@aamgr.AbstractCommandNode.register( - parent=DelHistoryCommand, - name="all", - description="删除所有会话的全部历史记录", - usage="!delhst all", - aliases=[], - privilege=2 -) -class DelAllHistoryCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.utils.context - reply = [] - pkg.utils.context.get_database_manager().delete_all_session_history() - reply = ["[bot]已删除所有会话的历史记录"] - return True, reply - \ No newline at end of file diff --git a/pkg/qqbot/cmds/session/last.py b/pkg/qqbot/cmds/session/last.py deleted file mode 100644 index 93459c44..00000000 --- a/pkg/qqbot/cmds/session/last.py +++ /dev/null @@ -1,29 +0,0 @@ -import datetime - -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="last", - description="切换前一次对话", - usage="!last", - aliases=[], - privilege=1 -) -class LastCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - - reply = [] - result = pkg.openai.session.get_session(session_name).last_session() - if result is None: - reply = ["[bot]没有前一次的对话"] - else: - datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( - '%Y-%m-%d %H:%M:%S') - reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format(datetime_str)] - - return True, reply diff --git a/pkg/qqbot/cmds/session/list.py b/pkg/qqbot/cmds/session/list.py deleted file mode 100644 index fb00976d..00000000 --- a/pkg/qqbot/cmds/session/list.py +++ /dev/null @@ -1,65 +0,0 @@ -import datetime -import json - -from .. import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name='list', - description='列出当前会话的所有历史记录', - usage='!list\n!list [页数]', - aliases=[], - privilege=1 -) -class ListCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - params = ctx.params - reply = [] - - pkg.openai.session.get_session(session_name).persistence() - page = 0 - - if len(params) > 0: - try: - page = int(params[0]) - except ValueError: - pass - - results = pkg.openai.session.get_session(session_name).list_history(page=page) - if len(results) == 0: - reply_str = "[bot]第{}页没有历史会话".format(page) - else: - reply_str = "[bot]历史会话 第{}页:\n".format(page) - current = -1 - for i in range(len(results)): - # 时间(使用create_timestamp转换) 序号 部分内容 - datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp']) - msg = "" - - msg = json.loads(results[i]['prompt']) - - if len(msg) >= 2: - reply_str += "#{} 创建:{} {}\n".format(i + page * 10, - 
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), - msg[0]['content']) - else: - reply_str += "#{} 创建:{} {}\n".format(i + page * 10, - datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), - "无内容") - if results[i]['create_timestamp'] == pkg.openai.session.get_session( - session_name).create_timestamp: - current = i + page * 10 - - reply_str += "\n以上信息倒序排列" - if current != -1: - reply_str += ",当前会话是 #{}\n".format(current) - else: - reply_str += ",当前处于全新会话或不在此页" - - reply = [reply_str] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/session/next.py b/pkg/qqbot/cmds/session/next.py deleted file mode 100644 index 7704acf6..00000000 --- a/pkg/qqbot/cmds/session/next.py +++ /dev/null @@ -1,29 +0,0 @@ -import datetime - -from .. import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="next", - description="切换后一次对话", - usage="!next", - aliases=[], - privilege=1 -) -class NextCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - reply = [] - - result = pkg.openai.session.get_session(session_name).next_session() - if result is None: - reply = ["[bot]没有后一次的对话"] - else: - datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( - '%Y-%m-%d %H:%M:%S') - reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format(datetime_str)] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/session/prompt.py b/pkg/qqbot/cmds/session/prompt.py deleted file mode 100644 index adb2e583..00000000 --- a/pkg/qqbot/cmds/session/prompt.py +++ /dev/null @@ -1,31 +0,0 @@ -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="prompt", - description="获取当前会话的前文", - usage="!prompt", - aliases=[], - privilege=1 -) -class PromptCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import pkg.openai.session - session_name = ctx.session_name - params = ctx.params - reply = [] - - msgs = "" - session: list = pkg.openai.session.get_session(session_name).prompt - for msg in session: - if len(params) != 0 and params[0] in ['-all', '-a']: - msgs = msgs + "{}: {}\n\n".format(msg['role'], msg['content']) - elif len(msg['content']) > 30: - msgs = msgs + "[{}]: {}...\n\n".format(msg['role'], msg['content'][:30]) - else: - msgs = msgs + "[{}]: {}\n\n".format(msg['role'], msg['content']) - reply = ["[bot]当前对话所有内容:\n{}".format(msgs)] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/session/resend.py b/pkg/qqbot/cmds/session/resend.py deleted file mode 100644 index 941afb55..00000000 --- a/pkg/qqbot/cmds/session/resend.py +++ /dev/null @@ -1,33 +0,0 @@ -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="resend", - description="重新获取上一次问题的回复", - usage="!resend", - aliases=[], - privilege=1 -) -class ResendCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - from ....openai import session as openai_session - from ....utils import context - from ....qqbot import message - - session_name = ctx.session_name - reply = [] - - session = openai_session.get_session(session_name) - to_send = session.undo() - - mgr = context.get_qqbot_manager() - - config = context.get_config_manager().data - - reply = message.process_normal_message(to_send, mgr, config, - ctx.launcher_type, ctx.launcher_id, - ctx.sender_id) - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/session/reset.py b/pkg/qqbot/cmds/session/reset.py deleted file mode 100644 index a93f3415..00000000 --- a/pkg/qqbot/cmds/session/reset.py +++ /dev/null @@ -1,35 +0,0 @@ -import tips as tips_custom - -from .. 
import aamgr -from ....openai import session -from ....utils import context - - -@aamgr.AbstractCommandNode.register( - parent=None, - name='reset', - description='重置当前会话', - usage='!reset', - aliases=[], - privilege=1 -) -class ResetCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - params = ctx.params - session_name = ctx.session_name - - reply = "" - - if len(params) == 0: - session.get_session(session_name).reset(explicit=True) - reply = [tips_custom.command_reset_message] - else: - try: - import pkg.openai.dprompt as dprompt - session.get_session(session_name).reset(explicit=True, use_prompt=params[0]) - reply = [tips_custom.command_reset_name_message+"{}".format(dprompt.mode_inst().get_full_name(params[0]))] - except Exception as e: - reply = ["[bot]会话重置失败:{}".format(e)] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/cconfig.py b/pkg/qqbot/cmds/system/cconfig.py deleted file mode 100644 index 321d68c2..00000000 --- a/pkg/qqbot/cmds/system/cconfig.py +++ /dev/null @@ -1,93 +0,0 @@ -import json - -from .. 
import aamgr - - -def config_operation(cmd, params): - reply = [] - import pkg.utils.context - # config = pkg.utils.context.get_config() - cfg_mgr = pkg.utils.context.get_config_manager() - - false = False - true = True - - reply_str = "" - if len(params) == 0: - reply = ["[bot]err:请输入!cmd cfg查看使用方法"] - else: - cfg_name = params[0] - if cfg_name == 'all': - reply_str = "[bot]所有配置项:\n\n" - for cfg in cfg_mgr.data.keys(): - if not cfg.startswith('__') and not cfg == 'logging': - # 根据配置项类型进行格式化,如果是字典则转换为json并格式化 - if isinstance(cfg_mgr.data[cfg], str): - reply_str += "{}: \"{}\"\n".format(cfg, cfg_mgr.data[cfg]) - elif isinstance(cfg_mgr.data[cfg], dict): - # 不进行unicode转义,并格式化 - reply_str += "{}: {}\n".format(cfg, - json.dumps(cfg_mgr.data[cfg], - ensure_ascii=False, indent=4)) - else: - reply_str += "{}: {}\n".format(cfg, cfg_mgr.data[cfg]) - reply = [reply_str] - else: - cfg_entry_path = cfg_name.split('.') - - try: - if len(params) == 1: # 未指定配置值,返回配置项值 - cfg_entry = cfg_mgr.data[cfg_entry_path[0]] - if len(cfg_entry_path) > 1: - for i in range(1, len(cfg_entry_path)): - cfg_entry = cfg_entry[cfg_entry_path[i]] - - if isinstance(cfg_entry, str): - reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, cfg_entry) - elif isinstance(cfg_entry, dict): - reply_str = "[bot]配置项{}: {}\n".format(cfg_name, - json.dumps(cfg_entry, - ensure_ascii=False, indent=4)) - else: - reply_str = "[bot]配置项{}: {}\n".format(cfg_name, cfg_entry) - reply = [reply_str] - else: - cfg_value = " ".join(params[1:]) - - cfg_value = eval(cfg_value) - - cfg_entry = cfg_mgr.data[cfg_entry_path[0]] - if len(cfg_entry_path) > 1: - for i in range(1, len(cfg_entry_path) - 1): - cfg_entry = cfg_entry[cfg_entry_path[i]] - if isinstance(cfg_entry[cfg_entry_path[-1]], type(cfg_value)): - cfg_entry[cfg_entry_path[-1]] = cfg_value - reply = ["[bot]配置项{}修改成功".format(cfg_name)] - else: - reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)] - else: - cfg_mgr.data[cfg_entry_path[0]] = cfg_value - reply = 
["[bot]配置项{}修改成功".format(cfg_name)] - except KeyError: - reply = ["[bot]err:未找到配置项 {}".format(cfg_name)] - except NameError: - reply = ["[bot]err:值{}不合法(字符串需要使用双引号包裹)".format(cfg_value)] - except ValueError: - reply = ["[bot]err:未找到配置项 {}".format(cfg_name)] - - return reply - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="cfg", - description="配置项管理", - usage="!cfg <配置项> [配置值]\n!cfg all", - aliases=[], - privilege=2 -) -class CfgCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - return True, config_operation(ctx.command, ctx.params) - \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/cmd.py b/pkg/qqbot/cmds/system/cmd.py deleted file mode 100644 index f0a33648..00000000 --- a/pkg/qqbot/cmds/system/cmd.py +++ /dev/null @@ -1,39 +0,0 @@ -from .. import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="cmd", - description="显示命令列表", - usage="!cmd\n!cmd <命令名称>", - aliases=[], - privilege=1 -) -class CmdCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - command_list = aamgr.__command_list__ - - reply = [] - - if len(ctx.params) == 0: - reply_str = "[bot]当前所有命令:\n\n" - - # 遍历顶级命令 - for key in command_list: - command = command_list[key] - if command['parent'] is None: - reply_str += "!{} - {}\n".format(key, command['description']) - - reply_str += "\n请使用 !cmd <命令名称> 来查看命令的详细信息" - - reply = [reply_str] - else: - command_name = ctx.params[0] - if command_name in command_list: - reply = [command_list[command_name]['cls'].help()] - else: - reply = ["[bot]命令 {} 不存在".format(command_name)] - - return True, reply - \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/help.py b/pkg/qqbot/cmds/system/help.py deleted file mode 100644 index 14027b8b..00000000 --- a/pkg/qqbot/cmds/system/help.py +++ /dev/null @@ -1,24 +0,0 @@ -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="help", - description="显示自定义的帮助信息", - usage="!help", - aliases=[], - privilege=1 -) -class HelpCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import tips - reply = ["[bot] "+tips.help_message + "\n请输入 !cmd 查看命令列表"] - - # 警告config.help_message过时 - import config - if hasattr(config, "help_message"): - reply[0] += "\n\n警告:config.py中的help_message已过时,不再生效,请使用tips.py中的help_message替代" - - return True, reply - \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/reload.py b/pkg/qqbot/cmds/system/reload.py deleted file mode 100644 index 378dcef9..00000000 --- a/pkg/qqbot/cmds/system/reload.py +++ /dev/null @@ -1,25 +0,0 @@ -import threading - -from .. import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="reload", - description="执行热重载", - usage="!reload", - aliases=[], - privilege=2 -) -class ReloadCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - reply = [] - - import pkg.utils.reloader - def reload_task(): - pkg.utils.reloader.reload_all() - - threading.Thread(target=reload_task, daemon=True).start() - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/update.py b/pkg/qqbot/cmds/system/update.py deleted file mode 100644 index d4cca3f3..00000000 --- a/pkg/qqbot/cmds/system/update.py +++ /dev/null @@ -1,38 +0,0 @@ -import threading -import traceback - -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="update", - description="更新程序", - usage="!update", - aliases=[], - privilege=2 -) -class UpdateCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - reply = [] - import pkg.utils.updater - import pkg.utils.reloader - import pkg.utils.context - - def update_task(): - try: - if pkg.utils.updater.update_all(): - pkg.utils.context.get_qqbot_manager().notify_admin("更新完成, 请手动重启程序。") - else: - pkg.utils.context.get_qqbot_manager().notify_admin("无新版本") - except Exception as e0: - traceback.print_exc() - pkg.utils.context.get_qqbot_manager().notify_admin("更新失败:{}".format(e0)) - return - - threading.Thread(target=update_task, daemon=True).start() - - reply = ["[bot]正在更新,请耐心等待,请勿重复发起更新..."] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/usage.py b/pkg/qqbot/cmds/system/usage.py deleted file mode 100644 index 15f79b49..00000000 --- a/pkg/qqbot/cmds/system/usage.py +++ /dev/null @@ -1,33 +0,0 @@ -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="usage", - description="获取使用情况", - usage="!usage", - aliases=[], - privilege=1 -) -class UsageCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - import config - import pkg.utils.context - - reply = [] - - reply_str = "[bot]各api-key使用情况:\n\n" - - api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key - for key_name in api_keys: - text_length = pkg.utils.context.get_openai_manager().audit_mgr \ - .get_text_length_of_key(api_keys[key_name]) - image_count = pkg.utils.context.get_openai_manager().audit_mgr \ - .get_image_count_of_key(api_keys[key_name]) - reply_str += "{}:\n - 文本长度:{}\n - 图片数量:{}\n".format(key_name, int(text_length), - int(image_count)) - - reply = [reply_str] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/version.py b/pkg/qqbot/cmds/system/version.py deleted file mode 100644 index 67bf3ef2..00000000 --- a/pkg/qqbot/cmds/system/version.py +++ /dev/null @@ -1,27 +0,0 @@ -from .. 
import aamgr - - -@aamgr.AbstractCommandNode.register( - parent=None, - name="version", - description="查看版本信息", - usage="!version", - aliases=[], - privilege=1 -) -class VersionCommand(aamgr.AbstractCommandNode): - @classmethod - def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: - reply = [] - import pkg.utils.updater - - reply_str = "[bot]当前版本:\n{}\n".format(pkg.utils.updater.get_current_version_info()) - try: - if pkg.utils.updater.is_new_version_available(): - reply_str += "\n有新版本可用,请使用命令 !update 进行更新" - except: - pass - - reply = [reply_str] - - return True, reply \ No newline at end of file diff --git a/pkg/qqbot/command.py b/pkg/qqbot/command.py deleted file mode 100644 index dba2d204..00000000 --- a/pkg/qqbot/command.py +++ /dev/null @@ -1,49 +0,0 @@ -# 命令处理模块 -import logging - -from ..qqbot.cmds import aamgr as cmdmgr - - -def process_command(session_name: str, text_message: str, mgr, config: dict, - launcher_type: str, launcher_id: int, sender_id: int, is_admin: bool) -> list: - reply = [] - try: - logging.info( - "[{}]发起命令:{}".format(session_name, text_message[:min(20, len(text_message))] + ( - "..." 
if len(text_message) > 20 else ""))) - - cmd = text_message[1:].strip().split(' ')[0] - - params = text_message[1:].strip().split(' ')[1:] - - # 把!~开头的转换成!cfg - if cmd.startswith('~'): - params = [cmd[1:]] + params - cmd = 'cfg' - - # 包装参数 - context = cmdmgr.Context( - command=cmd, - crt_command=cmd, - params=params, - crt_params=params[:], - session_name=session_name, - text_message=text_message, - launcher_type=launcher_type, - launcher_id=launcher_id, - sender_id=sender_id, - is_admin=is_admin, - privilege=2 if is_admin else 1, # 普通用户1,管理员2 - ) - try: - reply = cmdmgr.execute(context) - except cmdmgr.CommandPrivilegeError as e: - reply = ["{}".format(e)] - - return reply - except Exception as e: - mgr.notify_admin("{}命令执行失败:{}".format(session_name, e)) - logging.exception(e) - reply = ["[bot]err:{}".format(e)] - - return reply diff --git a/pkg/qqbot/filter.py b/pkg/qqbot/filter.py deleted file mode 100644 index c3a58093..00000000 --- a/pkg/qqbot/filter.py +++ /dev/null @@ -1,87 +0,0 @@ -# 敏感词过滤模块 -import re -import requests -import json -import logging - -from ..utils import context - - -class ReplyFilter: - sensitive_words = [] - mask = "*" - mask_word = "" - - # 默认值( 兼容性考虑 ) - baidu_check = False - baidu_api_key = "" - baidu_secret_key = "" - inappropriate_message_tips = "[百度云]请珍惜机器人,当前返回内容不合规" - - def __init__(self, sensitive_words: list, mask: str = "*", mask_word: str = ""): - self.sensitive_words = sensitive_words - self.mask = mask - self.mask_word = mask_word - - config = context.get_config_manager().data - - self.baidu_check = config['baidu_check'] - self.baidu_api_key = config['baidu_api_key'] - self.baidu_secret_key = config['baidu_secret_key'] - self.inappropriate_message_tips = config['inappropriate_message_tips'] - - def is_illegal(self, message: str) -> bool: - processed = self.process(message) - if processed != message: - return True - return False - - def process(self, message: str) -> str: - - # 本地关键词屏蔽 - for word in self.sensitive_words: - 
match = re.findall(word, message) - if len(match) > 0: - for i in range(len(match)): - if self.mask_word == "": - message = message.replace(match[i], self.mask * len(match[i])) - else: - message = message.replace(match[i], self.mask_word) - - # 百度云审核 - if self.baidu_check: - - # 百度云审核URL - baidu_url = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token=" + \ - str(requests.post("https://aip.baidubce.com/oauth/2.0/token", - params={"grant_type": "client_credentials", - "client_id": self.baidu_api_key, - "client_secret": self.baidu_secret_key}).json().get("access_token")) - - # 百度云审核 - payload = "text=" + message - logging.info("向百度云发送:" + payload) - headers = {'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'} - - if isinstance(payload, str): - payload = payload.encode('utf-8') - - response = requests.request("POST", baidu_url, headers=headers, data=payload) - response_dict = json.loads(response.text) - - if "error_code" in response_dict: - error_msg = response_dict.get("error_msg") - logging.warning(f"百度云判定出错,错误信息:{error_msg}") - conclusion = f"百度云判定出错,错误信息:{error_msg}\n以下是原消息:{message}" - else: - conclusion = response_dict["conclusion"] - if conclusion in ("合规"): - logging.info(f"百度云判定结果:{conclusion}") - return message - else: - logging.warning(f"百度云判定结果:{conclusion}") - conclusion = self.inappropriate_message_tips - # 返回百度云审核结果 - return conclusion - - return message diff --git a/pkg/qqbot/ignore.py b/pkg/qqbot/ignore.py deleted file mode 100644 index e1adc777..00000000 --- a/pkg/qqbot/ignore.py +++ /dev/null @@ -1,18 +0,0 @@ -import re - -from ..utils import context - - -def ignore(msg: str) -> bool: - """检查消息是否应该被忽略""" - config = context.get_config_manager().data - - if 'prefix' in config['ignore_rules']: - for rule in config['ignore_rules']['prefix']: - if msg.startswith(rule): - return True - - if 'regexp' in config['ignore_rules']: - for rule in config['ignore_rules']['regexp']: - if 
re.search(rule, msg): - return True diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py deleted file mode 100644 index 8cd663ff..00000000 --- a/pkg/qqbot/manager.py +++ /dev/null @@ -1,427 +0,0 @@ -import json -import os -import logging - -from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \ - FriendMessage, Image, MessageChain, Plain -import func_timeout - -from ..openai import session as openai_session - -from ..qqbot import filter as qqbot_filter -from ..qqbot import process as processor -from ..utils import context -from ..plugin import host as plugin_host -from ..plugin import models as plugin_models -import tips as tips_custom -from ..qqbot import adapter as msadapter - - -# 检查消息是否符合泛响应匹配机制 -def check_response_rule(group_id:int, text: str): - config = context.get_config_manager().data - - rules = config['response_rules'] - - # 检查是否有特定规则 - if 'prefix' not in config['response_rules']: - if str(group_id) in config['response_rules']: - rules = config['response_rules'][str(group_id)] - else: - rules = config['response_rules']['default'] - - # 检查前缀匹配 - if 'prefix' in rules: - for rule in rules['prefix']: - if text.startswith(rule): - return True, text.replace(rule, "", 1) - - # 检查正则表达式匹配 - if 'regexp' in rules: - for rule in rules['regexp']: - import re - match = re.match(rule, text) - if match: - return True, text - - return False, "" - - -def response_at(group_id: int): - config = context.get_config_manager().data - - use_response_rule = config['response_rules'] - - # 检查是否有特定规则 - if 'prefix' not in config['response_rules']: - if str(group_id) in config['response_rules']: - use_response_rule = config['response_rules'][str(group_id)] - else: - use_response_rule = config['response_rules']['default'] - - if 'at' not in use_response_rule: - return True - - return use_response_rule['at'] - - -def random_responding(group_id): - config = context.get_config_manager().data - - use_response_rule = config['response_rules'] - - # 检查是否有特定规则 - if 'prefix' 
not in config['response_rules']: - if str(group_id) in config['response_rules']: - use_response_rule = config['response_rules'][str(group_id)] - else: - use_response_rule = config['response_rules']['default'] - - if 'random_rate' in use_response_rule: - import random - return random.random() < use_response_rule['random_rate'] - return False - - -# 控制QQ消息输入输出的类 -class QQBotManager: - retry = 3 - - adapter: msadapter.MessageSourceAdapter = None - - bot_account_id: int = 0 - - reply_filter = None - - enable_banlist = False - - enable_private = True - enable_group = True - - ban_person = [] - ban_group = [] - - def __init__(self, first_time_init=True): - config = context.get_config_manager().data - - self.timeout = config['process_message_timeout'] - self.retry = config['retry_times'] - - # 由于YiriMirai的bot对象是单例的,且shutdown方法暂时无法使用 - # 故只在第一次初始化时创建bot对象,重载之后使用原bot对象 - # 因此,bot的配置不支持热重载 - if first_time_init: - logging.debug("Use adapter:" + config['msg_source_adapter']) - if config['msg_source_adapter'] == 'yirimirai': - from pkg.qqbot.sources.yirimirai import YiriMiraiAdapter - - mirai_http_api_config = config['mirai_http_api_config'] - self.bot_account_id = config['mirai_http_api_config']['qq'] - self.adapter = YiriMiraiAdapter(mirai_http_api_config) - elif config['msg_source_adapter'] == 'nakuru': - from pkg.qqbot.sources.nakuru import NakuruProjectAdapter - self.adapter = NakuruProjectAdapter(config['nakuru_config']) - self.bot_account_id = self.adapter.bot_account_id - else: - self.adapter = context.get_qqbot_manager().adapter - self.bot_account_id = context.get_qqbot_manager().bot_account_id - - # 保存 account_id 到审计模块 - from ..utils.center import apigroup - apigroup.APIGroup._runtime_info['account_id'] = "{}".format(self.bot_account_id) - - context.set_qqbot_manager(self) - - # 注册诸事件 - # Caution: 注册新的事件处理器之后,请务必在unsubscribe_all中编写相应的取消订阅代码 - def on_friend_message(event: FriendMessage): - - def friend_message_handler(): - # 触发事件 - args = { - "launcher_type": "person", 
- "launcher_id": event.sender.id, - "sender_id": event.sender.id, - "message_chain": event.message_chain, - } - plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args) - - if plugin_event.is_prevented_default(): - return - - self.on_person_message(event) - - context.get_thread_ctl().submit_user_task( - friend_message_handler, - ) - self.adapter.register_listener( - FriendMessage, - on_friend_message - ) - - def on_stranger_message(event: StrangerMessage): - - def stranger_message_handler(): - # 触发事件 - args = { - "launcher_type": "person", - "launcher_id": event.sender.id, - "sender_id": event.sender.id, - "message_chain": event.message_chain, - } - plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args) - - if plugin_event.is_prevented_default(): - return - - self.on_person_message(event) - - context.get_thread_ctl().submit_user_task( - stranger_message_handler, - ) - # nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件 - if config['msg_source_adapter'] == 'yirimirai': - self.adapter.register_listener( - StrangerMessage, - on_stranger_message - ) - - def on_group_message(event: GroupMessage): - - def group_message_handler(event: GroupMessage): - # 触发事件 - args = { - "launcher_type": "group", - "launcher_id": event.group.id, - "sender_id": event.sender.id, - "message_chain": event.message_chain, - } - plugin_event = plugin_host.emit(plugin_models.GroupMessageReceived, **args) - - if plugin_event.is_prevented_default(): - return - - self.on_group_message(event) - - context.get_thread_ctl().submit_user_task( - group_message_handler, - event - ) - self.adapter.register_listener( - GroupMessage, - on_group_message - ) - - def unsubscribe_all(): - """取消所有订阅 - - 用于在热重载流程中卸载所有事件处理器 - """ - self.adapter.unregister_listener( - FriendMessage, - on_friend_message - ) - if config['msg_source_adapter'] == 'yirimirai': - self.adapter.unregister_listener( - StrangerMessage, - on_stranger_message - ) - self.adapter.unregister_listener( - GroupMessage, - 
on_group_message - ) - - self.unsubscribe_all = unsubscribe_all - - # 加载禁用列表 - if os.path.exists("banlist.py"): - import banlist - self.enable_banlist = banlist.enable - self.ban_person = banlist.person - self.ban_group = banlist.group - logging.info("加载禁用列表: person: {}, group: {}".format(self.ban_person, self.ban_group)) - - if hasattr(banlist, "enable_private"): - self.enable_private = banlist.enable_private - if hasattr(banlist, "enable_group"): - self.enable_group = banlist.enable_group - - config = context.get_config_manager().data - if os.path.exists("sensitive.json") \ - and config['sensitive_word_filter'] is not None \ - and config['sensitive_word_filter']: - with open("sensitive.json", "r", encoding="utf-8") as f: - sensitive_json = json.load(f) - self.reply_filter = qqbot_filter.ReplyFilter( - sensitive_words=sensitive_json['words'], - mask=sensitive_json['mask'] if 'mask' in sensitive_json else '*', - mask_word=sensitive_json['mask_word'] if 'mask_word' in sensitive_json else '' - ) - else: - self.reply_filter = qqbot_filter.ReplyFilter([]) - - def send(self, event, msg, check_quote=True, check_at_sender=True): - config = context.get_config_manager().data - - if check_at_sender and config['at_sender']: - msg.insert( - 0, - Plain(" \n") - ) - - # 当回复的正文中包含换行时,quote可能会自带at,此时就不再单独添加at,只添加换行 - if "\n" not in str(msg[1]) or config['msg_source_adapter'] == 'nakuru': - msg.insert( - 0, - At( - event.sender.id - ) - ) - - self.adapter.reply_message( - event, - msg, - quote_origin=True if config['quote_origin'] and check_quote else False - ) - - # 私聊消息处理 - def on_person_message(self, event: MessageEvent): - reply = '' - - config = context.get_config_manager().data - - if not self.enable_private: - logging.debug("已在banlist.py中禁用所有私聊") - elif event.sender.id == self.bot_account_id: - pass - else: - if Image in event.message_chain: - pass - else: - # 超时则重试,重试超过次数则放弃 - failed = 0 - for i in range(self.retry): - try: - - 
@func_timeout.func_set_timeout(config['process_message_timeout']) - def time_ctrl_wrapper(): - reply = processor.process_message('person', event.sender.id, str(event.message_chain), - event.message_chain, - event.sender.id) - return reply - - reply = time_ctrl_wrapper() - break - except func_timeout.FunctionTimedOut: - logging.warning("person_{}: 超时,重试中({})".format(event.sender.id, i)) - openai_session.get_session('person_{}'.format(event.sender.id)).release_response_lock() - if "person_{}".format(event.sender.id) in processor.processing: - processor.processing.remove('person_{}'.format(event.sender.id)) - failed += 1 - continue - - if failed == self.retry: - openai_session.get_session('person_{}'.format(event.sender.id)).release_response_lock() - self.notify_admin("{} 请求超时".format("person_{}".format(event.sender.id))) - reply = [tips_custom.reply_message] - - if reply: - return self.send(event, reply, check_quote=False, check_at_sender=False) - - # 群消息处理 - def on_group_message(self, event: GroupMessage): - reply = '' - - config = context.get_config_manager().data - - def process(text=None) -> str: - replys = "" - if At(self.bot_account_id) in event.message_chain: - event.message_chain.remove(At(self.bot_account_id)) - - # 超时则重试,重试超过次数则放弃 - failed = 0 - for i in range(self.retry): - try: - @func_timeout.func_set_timeout(config['process_message_timeout']) - def time_ctrl_wrapper(): - replys = processor.process_message('group', event.group.id, - str(event.message_chain).strip() if text is None else text, - event.message_chain, - event.sender.id) - return replys - - replys = time_ctrl_wrapper() - break - except func_timeout.FunctionTimedOut: - logging.warning("group_{}: 超时,重试中({})".format(event.group.id, i)) - openai_session.get_session('group_{}'.format(event.group.id)).release_response_lock() - if "group_{}".format(event.group.id) in processor.processing: - processor.processing.remove('group_{}'.format(event.group.id)) - failed += 1 - continue - - if failed == 
self.retry: - openai_session.get_session('group_{}'.format(event.group.id)).release_response_lock() - self.notify_admin("{} 请求超时".format("group_{}".format(event.group.id))) - replys = [tips_custom.replys_message] - - return replys - - if not self.enable_group: - logging.debug("已在banlist.py中禁用所有群聊") - elif Image in event.message_chain: - pass - else: - if At(self.bot_account_id) in event.message_chain and response_at(event.group.id): - # 直接调用 - reply = process() - else: - check, result = check_response_rule(event.group.id, str(event.message_chain).strip()) - - if check: - reply = process(result.strip()) - # 检查是否随机响应 - elif random_responding(event.group.id): - logging.info("随机响应group_{}消息".format(event.group.id)) - reply = process() - - if reply: - return self.send(event, reply) - - # 通知系统管理员 - def notify_admin(self, message: str): - config = context.get_config_manager().data - if config['admin_qq'] != 0 and config['admin_qq'] != []: - logging.info("通知管理员:{}".format(message)) - if type(config['admin_qq']) == int: - self.adapter.send_message( - "person", - config['admin_qq'], - MessageChain([Plain("[bot]{}".format(message))]) - ) - else: - for adm in config['admin_qq']: - self.adapter.send_message( - "person", - adm, - MessageChain([Plain("[bot]{}".format(message))]) - ) - - def notify_admin_message_chain(self, message): - config = context.get_config_manager().data - if config['admin_qq'] != 0 and config['admin_qq'] != []: - logging.info("通知管理员:{}".format(message)) - if type(config['admin_qq']) == int: - self.adapter.send_message( - "person", - config['admin_qq'], - message - ) - else: - for adm in config['admin_qq']: - self.adapter.send_message( - "person", - adm, - message - ) diff --git a/pkg/qqbot/message.py b/pkg/qqbot/message.py deleted file mode 100644 index beff6645..00000000 --- a/pkg/qqbot/message.py +++ /dev/null @@ -1,134 +0,0 @@ -# 普通消息处理模块 -import logging - -import openai - -from ..utils import context -from ..openai import session as openai_session - 
-from ..plugin import host as plugin_host -from ..plugin import models as plugin_models -import tips as tips_custom - - -def handle_exception(notify_admin: str = "", set_reply: str = "") -> list: - """处理异常,当notify_admin不为空时,会通知管理员,返回通知用户的消息""" - config = context.get_config_manager().data - context.get_qqbot_manager().notify_admin(notify_admin) - if config['hide_exce_info_to_user']: - return [tips_custom.alter_tip_message] if tips_custom.alter_tip_message else [] - else: - return [set_reply] - - -def process_normal_message(text_message: str, mgr, config: dict, launcher_type: str, - launcher_id: int, sender_id: int) -> list: - session_name = f"{launcher_type}_{launcher_id}" - logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + ( - "..." if len(text_message) > 20 else ""))) - - session = openai_session.get_session(session_name) - - unexpected_exception_times = 0 - - max_unexpected_exception_times = 3 - - reply = [] - while True: - if unexpected_exception_times >= max_unexpected_exception_times: - reply = handle_exception(notify_admin=f"{session_name},多次尝试失败。", set_reply=f"[bot]多次尝试失败,请重试或联系管理员") - break - try: - prefix = "[GPT]" if config['show_prefix'] else "" - - text, finish_reason, funcs = session.query(text_message) - - # 触发插件事件 - args = { - "launcher_type": launcher_type, - "launcher_id": launcher_id, - "sender_id": sender_id, - "session": session, - "prefix": prefix, - "response_text": text, - "finish_reason": finish_reason, - "funcs_called": funcs, - } - - event = plugin_host.emit(plugin_models.NormalMessageResponded, **args) - - if event.get_return_value("prefix") is not None: - prefix = event.get_return_value("prefix") - - if event.get_return_value("reply") is not None: - reply = event.get_return_value("reply") - - if not event.is_prevented_default(): - reply = [prefix + text] - - except openai.APIConnectionError as e: - err_msg = str(e) - if err_msg.__contains__('Error communicating with OpenAI'): - reply = 
handle_exception("{}会话调用API失败:{}\n您的网络无法访问OpenAI接口或网络代理不正常".format(session_name, e), - "[bot]err:调用API失败,请重试或联系管理员,或等待修复") - else: - reply = handle_exception("{}会话调用API失败:{}".format(session_name, e), "[bot]err:调用API失败,请重试或联系管理员,或等待修复") - except openai.RateLimitError as e: - logging.debug(type(e)) - logging.debug(e.error['message']) - - if 'message' in e.error and e.error['message'].__contains__('You exceeded your current quota'): - # 尝试切换api-key - current_key_name = context.get_openai_manager().key_mgr.get_key_name( - context.get_openai_manager().key_mgr.using_key - ) - context.get_openai_manager().key_mgr.set_current_exceeded() - - # 触发插件事件 - args = { - 'key_name': current_key_name, - 'usage': context.get_openai_manager().audit_mgr - .get_usage(context.get_openai_manager().key_mgr.get_using_key_md5()), - 'exceeded_keys': context.get_openai_manager().key_mgr.exceeded, - } - event = plugin_host.emit(plugin_models.KeyExceeded, **args) - - if not event.is_prevented_default(): - switched, name = context.get_openai_manager().key_mgr.auto_switch() - - if not switched: - reply = handle_exception( - "api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key;如果你认为这是误判,请尝试重启程序。".format( - current_key_name), "[bot]err:API调用额度超额,请联系管理员,或等待修复") - else: - openai.api_key = context.get_openai_manager().key_mgr.get_using_key() - mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name)) - reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"] - continue - elif 'message' in e.error and e.error['message'].__contains__('You can retry your request'): - # 重试 - unexpected_exception_times += 1 - continue - elif 'message' in e.error and e.error['message']\ - .__contains__('The server had an error while processing your request'): - # 重试 - unexpected_exception_times += 1 - continue - else: - reply = handle_exception("{}会话调用API失败:{}".format(session_name, e), - "[bot]err:RateLimitError,请重试或联系作者,或等待修复") - except openai.BadRequestError as e: - if config['auto_reset'] and "This 
model's maximum context length is" in str(e): - session.reset(persist=True) - reply = [tips_custom.session_auto_reset_message] - else: - reply = handle_exception("{}API调用参数错误:{}\n".format( - session_name, e), "[bot]err:API调用参数错误,请联系管理员,或等待修复") - except openai.APIStatusError as e: - reply = handle_exception("{}API调用服务不可用:{}".format(session_name, e), "[bot]err:API调用服务不可用,请重试或联系管理员,或等待修复") - except Exception as e: - logging.exception(e) - reply = handle_exception("{}会话处理异常:{}".format(session_name, e), "[bot]err:{}".format(e)) - break - - return reply diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py deleted file mode 100644 index b5962701..00000000 --- a/pkg/qqbot/process.py +++ /dev/null @@ -1,191 +0,0 @@ -# 此模块提供了消息处理的具体逻辑的接口 -import asyncio -import time -import traceback - -import mirai -import logging - -# 这里不使用动态引入config -# 因为在这里动态引入会卡死程序 -# 而此模块静态引用config与动态引入的表现一致 -# 已弃用,由于超时时间现已动态使用 -# import config as config_init_import - -from ..qqbot import ratelimit -from ..qqbot import command, message -from ..openai import session as openai_session -from ..utils import context - -from ..plugin import host as plugin_host -from ..plugin import models as plugin_models -from ..qqbot import ignore -from ..qqbot import banlist -from ..qqbot import blob -import tips as tips_custom - -processing = [] - - -def is_admin(qq: int) -> bool: - """兼容list和int类型的管理员判断""" - config = context.get_config_manager().data - if type(config['admin_qq']) == list: - return qq in config['admin_qq'] - else: - return qq == config['admin_qq'] - - -def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: mirai.MessageChain, - sender_id: int) -> mirai.MessageChain: - global processing - - mgr = context.get_qqbot_manager() - - reply = [] - session_name = "{}_{}".format(launcher_type, launcher_id) - - # 检查发送方是否被禁用 - if banlist.is_banned(launcher_type, launcher_id, sender_id): - logging.info("根据禁用列表忽略{}_{}的消息".format(launcher_type, launcher_id)) - return [] - - if 
ignore.ignore(text_message): - logging.info("根据忽略规则忽略消息: {}".format(text_message)) - return [] - - config = context.get_config_manager().data - - if not config['wait_last_done'] and session_name in processing: - return mirai.MessageChain([mirai.Plain(tips_custom.message_drop_tip)]) - - # 检查是否被禁言 - if launcher_type == 'group': - is_muted = mgr.adapter.is_muted(launcher_id) - if is_muted: - logging.info("机器人被禁言,跳过消息处理(group_{})".format(launcher_id)) - return reply - - if config['income_msg_check']: - if mgr.reply_filter.is_illegal(text_message): - return mirai.MessageChain(mirai.Plain("[bot] 消息中存在不合适的内容, 请更换措辞")) - - openai_session.get_session(session_name).acquire_response_lock() - - text_message = text_message.strip() - - - # 为强制消息延迟计时 - start_time = time.time() - - # 处理消息 - try: - - processing.append(session_name) - try: - msg_type = '' - if text_message.startswith('!') or text_message.startswith("!"): # 命令 - msg_type = 'command' - # 触发插件事件 - args = { - 'launcher_type': launcher_type, - 'launcher_id': launcher_id, - 'sender_id': sender_id, - 'command': text_message[1:].strip().split(' ')[0], - 'params': text_message[1:].strip().split(' ')[1:], - 'text_message': text_message, - 'is_admin': is_admin(sender_id), - } - event = plugin_host.emit(plugin_models.PersonCommandSent - if launcher_type == 'person' - else plugin_models.GroupCommandSent, **args) - - if event.get_return_value("alter") is not None: - text_message = event.get_return_value("alter") - - # 取出插件提交的返回值赋值给reply - if event.get_return_value("reply") is not None: - reply = event.get_return_value("reply") - - if not event.is_prevented_default(): - reply = command.process_command(session_name, text_message, - mgr, config, launcher_type, launcher_id, sender_id, is_admin(sender_id)) - - else: # 消息 - msg_type = 'message' - # 限速丢弃检查 - # print(ratelimit.__crt_minute_usage__[session_name]) - if config['rate_limit_strategy'] == "drop": - if ratelimit.is_reach_limit(session_name): - logging.info("根据限速策略丢弃[{}]消息: 
{}".format(session_name, text_message)) - - return mirai.MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else [] - - before = time.time() - # 触发插件事件 - args = { - "launcher_type": launcher_type, - "launcher_id": launcher_id, - "sender_id": sender_id, - "text_message": text_message, - } - event = plugin_host.emit(plugin_models.PersonNormalMessageReceived - if launcher_type == 'person' - else plugin_models.GroupNormalMessageReceived, **args) - - if event.get_return_value("alter") is not None: - text_message = event.get_return_value("alter") - - # 取出插件提交的返回值赋值给reply - if event.get_return_value("reply") is not None: - reply = event.get_return_value("reply") - - if not event.is_prevented_default(): - reply = message.process_normal_message(text_message, - mgr, config, launcher_type, launcher_id, sender_id) - - # 限速等待时间 - if config['rate_limit_strategy'] == "wait": - time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before)) - - ratelimit.add_usage(session_name) - - if reply is not None and len(reply) > 0 and (type(reply[0]) == str or type(reply[0]) == mirai.Plain): - if type(reply[0]) == mirai.Plain: - reply[0] = reply[0].text - logging.info( - "回复[{}]文字消息:{}".format(session_name, - reply[0][:min(100, len(reply[0]))] + ( - "..." 
if len(reply[0]) > 100 else ""))) - if msg_type == 'message': - reply = [mgr.reply_filter.process(reply[0])] - - reply = blob.check_text(reply[0]) - else: - logging.info("回复[{}]消息".format(session_name)) - - finally: - processing.remove(session_name) - finally: - openai_session.get_session(session_name).release_response_lock() - - # 检查延迟时间 - if config['force_delay_range'][1] == 0: - delay_time = 0 - else: - import random - - # 从延迟范围中随机取一个值(浮点) - rdm = random.uniform(config['force_delay_range'][0], config['force_delay_range'][1]) - - spent = time.time() - start_time - - # 如果花费时间小于延迟时间,则延迟 - delay_time = rdm - spent if rdm - spent > 0 else 0 - - # 延迟 - if delay_time > 0: - logging.info("[风控] 强制延迟{:.2f}秒(如需关闭,请到config.py修改force_delay_range字段)".format(delay_time)) - time.sleep(delay_time) - - return mirai.MessageChain(reply) diff --git a/pkg/qqbot/ratelimit.py b/pkg/qqbot/ratelimit.py deleted file mode 100644 index 96d289ff..00000000 --- a/pkg/qqbot/ratelimit.py +++ /dev/null @@ -1,89 +0,0 @@ -# 限速相关模块 -import time -import logging -import threading - -from ..utils import context - - -__crt_minute_usage__ = {} -"""当前分钟每个会话的对话次数""" - - -__timer_thr__: threading.Thread = None - - -def get_limitation(session_name: str) -> int: - """获取会话的限制次数""" - config = context.get_config_manager().data - - if session_name in config['rate_limitation']: - return config['rate_limitation'][session_name] - else: - return config['rate_limitation']["default"] - - -def add_usage(session_name: str): - """增加会话的对话次数""" - global __crt_minute_usage__ - if session_name in __crt_minute_usage__: - __crt_minute_usage__[session_name] += 1 - else: - __crt_minute_usage__[session_name] = 1 - - -def start_timer(): - """启动定时器""" - global __timer_thr__ - __timer_thr__ = threading.Thread(target=run_timer, daemon=True) - __timer_thr__.start() - - -def run_timer(): - """启动定时器,每分钟清空一次对话次数""" - global __crt_minute_usage__ - global __timer_thr__ - - # 等待直到整分钟 - time.sleep(60 - time.time() % 60) - - while True: - if 
__timer_thr__ != threading.current_thread(): - break - - logging.debug("清空当前分钟的对话次数") - __crt_minute_usage__ = {} - time.sleep(60) - - -def get_usage(session_name: str) -> int: - """获取会话的对话次数""" - global __crt_minute_usage__ - if session_name in __crt_minute_usage__: - return __crt_minute_usage__[session_name] - else: - return 0 - - -def get_rest_wait_time(session_name: str, spent: float) -> float: - """获取会话此回合的剩余等待时间""" - global __crt_minute_usage__ - - min_seconds_per_round = 60.0 / get_limitation(session_name) - - if session_name in __crt_minute_usage__: - return max(0, min_seconds_per_round - spent) - else: - return 0 - - -def is_reach_limit(session_name: str) -> bool: - """判断会话是否超过限制""" - global __crt_minute_usage__ - - if session_name in __crt_minute_usage__: - return __crt_minute_usage__[session_name] >= get_limitation(session_name) - else: - return False - -start_timer() diff --git a/pkg/utils/__init__.py b/pkg/utils/__init__.py index 5b1c9803..e69de29b 100644 --- a/pkg/utils/__init__.py +++ b/pkg/utils/__init__.py @@ -1 +0,0 @@ -from .threadctl import ThreadCtl \ No newline at end of file diff --git a/pkg/utils/announce.py b/pkg/utils/announce.py new file mode 100644 index 00000000..d17ac62a --- /dev/null +++ b/pkg/utils/announce.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import json +import typing +import os +import base64 + +import pydantic +import requests + +from ..core import app + + +class Announcement(pydantic.BaseModel): + """公告""" + + id: int + + time: str + + timestamp: int + + content: str + + enabled: typing.Optional[bool] = True + + def to_dict(self) -> dict: + return { + "id": self.id, + "time": self.time, + "timestamp": self.timestamp, + "content": self.content, + "enabled": self.enabled + } + + +class AnnouncementManager: + """公告管理器""" + + ap: app.Application = None + + def __init__(self, ap: app.Application): + self.ap = ap + + async def fetch_all( + self + ) -> list[Announcement]: + """获取所有公告""" + resp = requests.get( + 
url="https://api.github.com/repos/RockChinQ/QChatGPT/contents/res/announcement.json", + proxies=self.ap.proxy_mgr.get_forward_proxies(), + timeout=5 + ) + obj_json = resp.json() + b64_content = obj_json["content"] + # 解码 + content = base64.b64decode(b64_content).decode("utf-8") + + return [Announcement(**item) for item in json.loads(content)] + + async def fetch_saved( + self + ) -> list[Announcement]: + if not os.path.exists("res/announcement_saved.json"): + with open("res/announcement_saved.json", "w", encoding="utf-8") as f: + f.write("[]") + + with open("res/announcement_saved.json", "r", encoding="utf-8") as f: + content = f.read() + + if not content: + content = '[]' + + return [Announcement(**item) for item in json.loads(content)] + + async def write_saved( + self, + content: list[Announcement] + ): + + with open("res/announcement_saved.json", "w", encoding="utf-8") as f: + f.write(json.dumps([ + item.to_dict() for item in content + ], indent=4, ensure_ascii=False)) + + async def fetch_new( + self + ) -> list[Announcement]: + """获取新公告""" + all = await self.fetch_all() + saved = await self.fetch_saved() + + to_show: list[Announcement] = [] + + for item in all: + # 遍历saved检查是否有相同id的公告 + for saved_item in saved: + if saved_item.id == item.id: + break + else: + if item.enabled: + # 没有相同id的公告 + to_show.append(item) + + await self.write_saved(all) + return to_show + + async def show_announcements( + self + ): + """显示公告""" + try: + announcements = await self.fetch_new() + for ann in announcements: + self.ap.logger.info(f'[公告] {ann.time}: {ann.content}') + + if announcements: + + await self.ap.ctr_mgr.main.post_announcement_showed( + ids=[item.id for item in announcements] + ) + except Exception as e: + self.ap.logger.warning(f'获取公告时出错: {e}') diff --git a/pkg/utils/announcement.py b/pkg/utils/announcement.py deleted file mode 100644 index 4bff412d..00000000 --- a/pkg/utils/announcement.py +++ /dev/null @@ -1,68 +0,0 @@ -import base64 -import os -import json - 
-import requests - - -def read_latest() -> list: - import pkg.utils.network as network - resp = requests.get( - url="https://api.github.com/repos/RockChinQ/QChatGPT/contents/res/announcement.json", - proxies=network.wrapper_proxies() - ) - obj_json = resp.json() - b64_content = obj_json["content"] - # 解码 - content = base64.b64decode(b64_content).decode("utf-8") - return json.loads(content) - - -def read_saved() -> list: - # 已保存的在res/announcement_saved - # 检查是否存在 - if not os.path.exists("res/announcement_saved.json"): - with open("res/announcement_saved.json", "w", encoding="utf-8") as f: - f.write("[]") - - with open("res/announcement_saved.json", "r", encoding="utf-8") as f: - content = f.read() - - return json.loads(content) - - -def write_saved(content: list): - # 已保存的在res/announcement_saved - with open("res/announcement_saved.json", "w", encoding="utf-8") as f: - f.write(json.dumps(content, indent=4, ensure_ascii=False)) - - -def fetch_new() -> list: - latest = read_latest() - saved = read_saved() - - to_show: list = [] - - for item in latest: - # 遍历saved检查是否有相同id的公告 - for saved_item in saved: - if saved_item["id"] == item["id"]: - break - else: - # 没有相同id的公告 - to_show.append(item) - - write_saved(latest) - return to_show - - -if __name__ == '__main__': - - resp = requests.get( - url="https://api.github.com/repos/RockChinQ/QChatGPT/contents/res/announcement.json", - ) - obj_json = resp.json() - b64_content = obj_json["content"] - # 解码 - content = base64.b64decode(b64_content).decode("utf-8") - print(json.dumps(json.loads(content), indent=4, ensure_ascii=False)) diff --git a/pkg/utils/center/apigroup.py b/pkg/utils/center/apigroup.py deleted file mode 100644 index 94812d59..00000000 --- a/pkg/utils/center/apigroup.py +++ /dev/null @@ -1,88 +0,0 @@ -import abc -import uuid -import json -import logging -import threading - -import requests - - -class APIGroup(metaclass=abc.ABCMeta): - """API 组抽象类""" - _basic_info: dict = None - _runtime_info: dict = None - - prefix 
= None - - def __init__(self, prefix: str): - self.prefix = prefix - - def do( - self, - method: str, - path: str, - data: dict = None, - params: dict = None, - headers: dict = {}, - **kwargs - ): - """执行一个请求""" - def thr_wrapper( - self, - method: str, - path: str, - data: dict = None, - params: dict = None, - headers: dict = {}, - **kwargs - ): - try: - url = self.prefix + path - data = json.dumps(data) - headers['Content-Type'] = 'application/json' - - ret = requests.request( - method, - url, - data=data, - params=params, - headers=headers, - **kwargs - ) - - logging.debug("data: %s", data) - - logging.debug("ret: %s", ret.json()) - except Exception as e: - logging.debug("上报数据失败: %s", e) - - thr = threading.Thread(target=thr_wrapper, args=( - self, - method, - path, - data, - params, - headers, - ), kwargs=kwargs) - thr.start() - - - def gen_rid( - self - ): - """生成一个请求 ID""" - return str(uuid.uuid4()) - - def basic_info( - self - ): - """获取基本信息""" - basic_info = APIGroup._basic_info.copy() - basic_info['rid'] = self.gen_rid() - return basic_info - - def runtime_info( - self - ): - """获取运行时信息""" - return APIGroup._runtime_info diff --git a/pkg/utils/context.py b/pkg/utils/context.py deleted file mode 100644 index e6a2734a..00000000 --- a/pkg/utils/context.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -import threading -from . 
import threadctl - -from ..database import manager as db_mgr -from ..openai import manager as openai_mgr -from ..qqbot import manager as qqbot_mgr -from ..config import manager as config_mgr -from ..plugin import host as plugin_host -from .center import v2 as center_v2 - - -context = { - 'inst': { - 'database.manager.DatabaseManager': None, - 'openai.manager.OpenAIInteract': None, - 'qqbot.manager.QQBotManager': None, - 'config.manager.ConfigManager': None, - }, - 'pool_ctl': None, - 'logger_handler': None, - 'config': None, - 'plugin_host': None, -} -context_lock = threading.Lock() - -### context耦合度非常高,需要大改 ### -def set_config(inst): - context_lock.acquire() - context['config'] = inst - context_lock.release() - - -def get_config(): - context_lock.acquire() - t = context['config'] - context_lock.release() - return t - - -def set_database_manager(inst: db_mgr.DatabaseManager): - context_lock.acquire() - context['inst']['database.manager.DatabaseManager'] = inst - context_lock.release() - - -def get_database_manager() -> db_mgr.DatabaseManager: - context_lock.acquire() - t = context['inst']['database.manager.DatabaseManager'] - context_lock.release() - return t - - -def set_openai_manager(inst: openai_mgr.OpenAIInteract): - context_lock.acquire() - context['inst']['openai.manager.OpenAIInteract'] = inst - context_lock.release() - - -def get_openai_manager() -> openai_mgr.OpenAIInteract: - context_lock.acquire() - t = context['inst']['openai.manager.OpenAIInteract'] - context_lock.release() - return t - - -def set_qqbot_manager(inst: qqbot_mgr.QQBotManager): - context_lock.acquire() - context['inst']['qqbot.manager.QQBotManager'] = inst - context_lock.release() - - -def get_qqbot_manager() -> qqbot_mgr.QQBotManager: - context_lock.acquire() - t = context['inst']['qqbot.manager.QQBotManager'] - context_lock.release() - return t - - -def set_config_manager(inst: config_mgr.ConfigManager): - context_lock.acquire() - context['inst']['config.manager.ConfigManager'] = inst 
- context_lock.release() - - -def get_config_manager() -> config_mgr.ConfigManager: - context_lock.acquire() - t = context['inst']['config.manager.ConfigManager'] - context_lock.release() - return t - - -def set_plugin_host(inst: plugin_host.PluginHost): - context_lock.acquire() - context['plugin_host'] = inst - context_lock.release() - - -def get_plugin_host() -> plugin_host.PluginHost: - context_lock.acquire() - t = context['plugin_host'] - context_lock.release() - return t - - -def set_thread_ctl(inst: threadctl.ThreadCtl): - context_lock.acquire() - context['pool_ctl'] = inst - context_lock.release() - - -def get_thread_ctl() -> threadctl.ThreadCtl: - context_lock.acquire() - t: threadctl.ThreadCtl = context['pool_ctl'] - context_lock.release() - return t - - -def set_center_v2_api(inst: center_v2.V2CenterAPI): - context_lock.acquire() - context['center_v2_api'] = inst - context_lock.release() - - -def get_center_v2_api() -> center_v2.V2CenterAPI: - context_lock.acquire() - t: center_v2.V2CenterAPI = context['center_v2_api'] - context_lock.release() - return t \ No newline at end of file diff --git a/pkg/utils/log.py b/pkg/utils/log.py deleted file mode 100644 index 6be28b5b..00000000 --- a/pkg/utils/log.py +++ /dev/null @@ -1,67 +0,0 @@ -import os -import time -import logging -import shutil - -from . 
import context - - -log_file_name = "qchatgpt.log" - - -log_colors_config = { - 'DEBUG': 'green', # cyan white - 'INFO': 'white', - 'WARNING': 'yellow', - 'ERROR': 'red', - 'CRITICAL': 'cyan', -} - - -def init_runtime_log_file(): - """为此次运行生成日志文件 - 格式: qchatgpt-yyyy-MM-dd-HH-mm-ss.log - """ - global log_file_name - - # 检查logs目录是否存在 - if not os.path.exists("logs"): - os.mkdir("logs") - - log_file_name = "logs/qchatgpt-%s.log" % time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) - - -def reset_logging(): - global log_file_name - - import pkg.utils.context - import colorlog - - if pkg.utils.context.context['logger_handler'] is not None: - logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler']) - - for handler in logging.getLogger().handlers: - logging.getLogger().removeHandler(handler) - - config_mgr = context.get_config_manager() - - logging_level = logging.INFO if config_mgr is None else config_mgr.data['logging_level'] - - logging.basicConfig(level=logging_level, # 设置日志输出格式 - filename=log_file_name, # log日志输出的文件位置和文件名 - format="[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s", - # 日志输出的格式 - # -8表示占位符,让输出左对齐,输出长度都为8位 - datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式 - ) - sh = logging.StreamHandler() - sh.setLevel(logging_level) - sh.setFormatter(colorlog.ColoredFormatter( - fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : " - "%(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - log_colors=log_colors_config - )) - logging.getLogger().addHandler(sh) - pkg.utils.context.context['logger_handler'] = sh - return sh diff --git a/pkg/utils/network.py b/pkg/utils/network.py deleted file mode 100644 index a4498854..00000000 --- a/pkg/utils/network.py +++ /dev/null @@ -1,11 +0,0 @@ -from . 
import context - - -def wrapper_proxies() -> dict: - """获取代理""" - config = context.get_config_manager().data - - return { - "http": config['openai_config']['proxy'], - "https": config['openai_config']['proxy'] - } if 'proxy' in config['openai_config'] and (config['openai_config']['proxy'] is not None) else None diff --git a/pkg/utils/pkgmgr.py b/pkg/utils/pkgmgr.py index 741c8f48..ed0d3dbf 100644 --- a/pkg/utils/pkgmgr.py +++ b/pkg/utils/pkgmgr.py @@ -1,39 +1,27 @@ from pip._internal import main as pipmain -from . import log +# from . import log def install(package): pipmain(['install', package]) - log.reset_logging() + # log.reset_logging() def install_upgrade(package): pipmain(['install', '--upgrade', package, "-i", "https://pypi.tuna.tsinghua.edu.cn/simple", "--trusted-host", "pypi.tuna.tsinghua.edu.cn"]) - log.reset_logging() + # log.reset_logging() def run_pip(params: list): pipmain(params) - log.reset_logging() + # log.reset_logging() def install_requirements(file): pipmain(['install', '-r', file, "-i", "https://pypi.tuna.tsinghua.edu.cn/simple", "--trusted-host", "pypi.tuna.tsinghua.edu.cn"]) - log.reset_logging() - - -def ensure_dulwich(): - # 尝试三次 - for i in range(3): - try: - import dulwich - return - except ImportError: - install('dulwich') - - raise ImportError("无法自动安装dulwich库") + # log.reset_logging() if __name__ == "__main__": diff --git a/pkg/utils/proxy.py b/pkg/utils/proxy.py new file mode 100644 index 00000000..7ebd3171 --- /dev/null +++ b/pkg/utils/proxy.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import os +import sys + +from ..core import app + + +class ProxyManager: + ap: app.Application + + forward_proxies: dict[str, str] + + def __init__(self, ap: app.Application): + self.ap = ap + + self.forward_proxies = {} + + async def initialize(self): + self.forward_proxies = { + "http://": os.getenv("HTTP_PROXY") or os.getenv("http_proxy"), + "https://": os.getenv("HTTPS_PROXY") or os.getenv("https_proxy"), + } + + if 'http' in 
self.ap.system_cfg.data['network-proxies']: + self.forward_proxies['http://'] = self.ap.system_cfg.data['network-proxies']['http'] + if 'https' in self.ap.system_cfg.data['network-proxies']: + self.forward_proxies['https://'] = self.ap.system_cfg.data['network-proxies']['https'] + + def get_forward_proxies(self) -> dict: + return self.forward_proxies.copy() diff --git a/pkg/utils/reloader.py b/pkg/utils/reloader.py deleted file mode 100644 index eefe33b0..00000000 --- a/pkg/utils/reloader.py +++ /dev/null @@ -1,71 +0,0 @@ -import logging -import importlib -import pkgutil -import asyncio - -from . import context -from ..plugin import host as plugin_host - - -def walk(module, prefix='', path_prefix=''): - """遍历并重载所有模块""" - for item in pkgutil.iter_modules(module.__path__): - if item.ispkg: - - walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.', path_prefix + item.name + '/') - else: - logging.info('reload module: {}, path: {}'.format(prefix + item.name, path_prefix + item.name + '.py')) - plugin_host.__current_module_path__ = "plugins/" + path_prefix + item.name + '.py' - importlib.reload(__import__(module.__name__ + '.' 
+ item.name, fromlist=[''])) - - -def reload_all(notify=True): - # 解除bot的事件注册 - import pkg - context.get_qqbot_manager().unsubscribe_all() - # 执行关闭流程 - logging.info("执行程序关闭流程") - import main - main.stop() - - # 删除所有已注册的命令 - import pkg.qqbot.cmds.aamgr as cmdsmgr - cmdsmgr.__command_list__ = {} - cmdsmgr.__tree_index__ = {} - - # 重载所有模块 - context.context['exceeded_keys'] = context.get_openai_manager().key_mgr.exceeded - this_context = context.context - walk(pkg) - importlib.reload(__import__("config-template")) - importlib.reload(__import__('config')) - importlib.reload(__import__('main')) - importlib.reload(__import__('banlist')) - importlib.reload(__import__('tips')) - context.context = this_context - - # 重载插件 - import plugins - walk(plugins) - - # 初始化相关文件 - main.check_file() - - # 执行启动流程 - logging.info("执行程序启动流程") - - context.get_thread_ctl().reload( - admin_pool_num=4, - user_pool_num=8 - ) - - def run_wrapper(): - asyncio.run(main.start_process(False)) - - context.get_thread_ctl().submit_sys_task( - run_wrapper - ) - - logging.info('程序启动完成') - if notify: - context.get_qqbot_manager().notify_admin("重载完成") diff --git a/pkg/utils/text2img.py b/pkg/utils/text2img.py deleted file mode 100644 index 5be723ed..00000000 --- a/pkg/utils/text2img.py +++ /dev/null @@ -1,208 +0,0 @@ -import logging -import re -import os -import traceback - -from PIL import Image, ImageDraw, ImageFont - -from ..utils import context - - -text_render_font: ImageFont = None - -def initialize(): - global text_render_font - logging.debug("初始化文字转图片模块...") - config = context.get_config_manager().data - - if config['blob_message_strategy'] == "image": # 仅在启用了image时才加载字体 - use_font = config['font_path'] - try: - - # 检查是否存在 - if not os.path.exists(use_font): - # 若是windows系统,使用微软雅黑 - if os.name == "nt": - use_font = "C:/Windows/Fonts/msyh.ttc" - if not os.path.exists(use_font): - logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") - config['blob_message_strategy'] = "forward" 
- else: - logging.info("使用Windows自带字体:" + use_font) - text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") - else: - logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") - config['blob_message_strategy'] = "forward" - else: - text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") - except: - traceback.print_exc() - logging.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) - config['blob_message_strategy'] = "forward" - - logging.debug("字体文件加载完成。") - - -def indexNumber(path=''): - """ - 查找字符串中数字所在串中的位置 - :param path:目标字符串 - :return:: : [['1', 16], ['2', 35], ['1', 51]] - """ - kv = [] - nums = [] - beforeDatas = re.findall('[\d]+', path) - for num in beforeDatas: - indexV = [] - times = path.count(num) - if times > 1: - if num not in nums: - indexs = re.finditer(num, path) - for index in indexs: - iV = [] - i = index.span()[0] - iV.append(num) - iV.append(i) - kv.append(iV) - nums.append(num) - else: - index = path.find(num) - indexV.append(num) - indexV.append(index) - kv.append(indexV) - # 根据数字位置排序 - indexSort = [] - resultIndex = [] - for vi in kv: - indexSort.append(vi[1]) - indexSort.sort() - for i in indexSort: - for v in kv: - if i == v[1]: - resultIndex.append(v) - return resultIndex - - -def get_size(file): - # 获取文件大小:KB - size = os.path.getsize(file) - return size / 1024 - - -def get_outfile(infile, outfile): - if outfile: - return outfile - dir, suffix = os.path.splitext(infile) - outfile = '{}-out{}'.format(dir, suffix) - return outfile - - -def compress_image(infile, outfile='', kb=100, step=20, quality=90): - """不改变图片尺寸压缩到指定大小 - :param infile: 压缩源文件 - :param outfile: 压缩文件保存地址 - :param mb: 压缩目标,KB - :param step: 每次调整的压缩比率 - :param quality: 初始压缩比率 - :return: 压缩文件地址,压缩文件大小 - """ - o_size = get_size(infile) - if o_size <= kb: - return infile, o_size - outfile = get_outfile(infile, outfile) - while o_size > kb: - im = Image.open(infile) - im.save(outfile, quality=quality) - 
if quality - step < 0: - break - quality -= step - o_size = get_size(outfile) - return outfile, get_size(outfile) - - -def text_to_image(text_str: str, save_as="temp.png", width=800): - global text_render_font - - logging.debug("正在将文本转换为图片...") - - text_str = text_str.replace("\t", " ") - - # 分行 - lines = text_str.split('\n') - - # 计算并分割 - final_lines = [] - - text_width = width-80 - - logging.debug("lines: {}, text_width: {}".format(lines, text_width)) - for line in lines: - logging.debug(type(text_render_font)) - # 如果长了就分割 - line_width = text_render_font.getlength(line) - logging.debug("line_width: {}".format(line_width)) - if line_width < text_width: - final_lines.append(line) - continue - else: - rest_text = line - while True: - # 分割最前面的一行 - point = int(len(rest_text) * (text_width / line_width)) - - # 检查断点是否在数字中间 - numbers = indexNumber(rest_text) - - for number in numbers: - if number[1] < point < number[1] + len(number[0]) and number[1] != 0: - point = number[1] - break - - final_lines.append(rest_text[:point]) - rest_text = rest_text[point:] - line_width = text_render_font.getlength(rest_text) - if line_width < text_width: - final_lines.append(rest_text) - break - else: - continue - # 准备画布 - img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)) - draw = ImageDraw.Draw(img, mode='RGBA') - - logging.debug("正在绘制图片...") - # 绘制正文 - line_number = 0 - offset_x = 20 - offset_y = 30 - for final_line in final_lines: - draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=text_render_font) - # 遍历此行,检查是否有emoji - idx_in_line = 0 - for ch in final_line: - # if self.is_emoji(ch): - # emoji_img_valid = ensure_emoji(hex(ord(ch))[2:]) - # if emoji_img_valid: # emoji图像可用,绘制到指定位置 - # emoji_image = Image.open("emojis/{}.png".format(hex(ord(ch))[2:]), mode='r').convert('RGBA') - # emoji_image = emoji_image.resize((32, 32)) - - # x, y = emoji_image.size - - # final_emoji_img = Image.new('RGBA', emoji_image.size, 
(255, 255, 255)) - # final_emoji_img.paste(emoji_image, (0, 0, x, y), emoji_image) - - # img.paste(final_emoji_img, box=(int(offset_x + idx_in_line * 32), offset_y + 35 * line_number)) - - # 检查字符占位宽 - char_code = ord(ch) - if char_code >= 127: - idx_in_line += 1 - else: - idx_in_line += 0.5 - - line_number += 1 - - logging.debug("正在保存图片...") - img.save(save_as) - - return save_as diff --git a/pkg/utils/threadctl.py b/pkg/utils/threadctl.py deleted file mode 100644 index ab764cc3..00000000 --- a/pkg/utils/threadctl.py +++ /dev/null @@ -1,93 +0,0 @@ -import threading -import time -from concurrent.futures import ThreadPoolExecutor - - -class Pool: - """线程池结构""" - pool_num:int = None - ctl:ThreadPoolExecutor = None - task_list:list = None - task_list_lock:threading.Lock = None - monitor_type = True - - def __init__(self, pool_num): - self.pool_num = pool_num - self.ctl = ThreadPoolExecutor(max_workers = self.pool_num) - self.task_list = [] - self.task_list_lock = threading.Lock() - - def __thread_monitor__(self): - while self.monitor_type: - for t in self.task_list: - if not t.done(): - continue - try: - self.task_list.pop(self.task_list.index(t)) - except: - continue - time.sleep(1) - - -class ThreadCtl: - def __init__(self, sys_pool_num, admin_pool_num, user_pool_num): - """线程池控制类 - sys_pool_num:分配系统使用的线程池数量(>=8) - admin_pool_num:用于处理管理员消息的线程池数量(>=1) - user_pool_num:分配用于处理用户消息的线程池的数量(>=1) - """ - if sys_pool_num < 5: - raise Exception("Too few system threads(sys_pool_num needs >= 8, but received {})".format(sys_pool_num)) - if admin_pool_num < 1: - raise Exception("Too few admin threads(admin_pool_num needs >= 1, but received {})".format(admin_pool_num)) - if user_pool_num < 1: - raise Exception("Too few user threads(user_pool_num needs >= 1, but received {})".format(admin_pool_num)) - self.__sys_pool__ = Pool(sys_pool_num) - self.__admin_pool__ = Pool(admin_pool_num) - self.__user_pool__ = Pool(user_pool_num) - 
self.submit_sys_task(self.__sys_pool__.__thread_monitor__) - self.submit_sys_task(self.__admin_pool__.__thread_monitor__) - self.submit_sys_task(self.__user_pool__.__thread_monitor__) - - def __submit__(self, pool: Pool, fn, /, *args, **kwargs ): - t = pool.ctl.submit(fn, *args, **kwargs) - pool.task_list_lock.acquire() - pool.task_list.append(t) - pool.task_list_lock.release() - return t - - def submit_sys_task(self, fn, /, *args, **kwargs): - return self.__submit__( - self.__sys_pool__, - fn, *args, **kwargs - ) - - def submit_admin_task(self, fn, /, *args, **kwargs): - return self.__submit__( - self.__admin_pool__, - fn, *args, **kwargs - ) - - def submit_user_task(self, fn, /, *args, **kwargs): - return self.__submit__( - self.__user_pool__, - fn, *args, **kwargs - ) - - def shutdown(self): - self.__user_pool__.ctl.shutdown(cancel_futures=True) - self.__user_pool__.monitor_type = False - self.__admin_pool__.ctl.shutdown(cancel_futures=True) - self.__admin_pool__.monitor_type = False - self.__sys_pool__.monitor_type = False - self.__sys_pool__.ctl.shutdown(wait=True, cancel_futures=False) - - def reload(self, admin_pool_num, user_pool_num): - self.__user_pool__.ctl.shutdown(cancel_futures=True) - self.__user_pool__.monitor_type = False - self.__admin_pool__.ctl.shutdown(cancel_futures=True) - self.__admin_pool__.monitor_type = False - self.__admin_pool__ = Pool(admin_pool_num) - self.__user_pool__ = Pool(user_pool_num) - self.submit_sys_task(self.__admin_pool__.__thread_monitor__) - self.submit_sys_task(self.__user_pool__.__thread_monitor__) diff --git a/pkg/utils/updater.py b/pkg/utils/updater.py deleted file mode 100644 index ec6e93a8..00000000 --- a/pkg/utils/updater.py +++ /dev/null @@ -1,287 +0,0 @@ -from __future__ import annotations - -import datetime -import logging -import os.path -import time - -import requests - -from . import constants -from . import network -from . 
import context - - -def check_dulwich_closure(): - try: - import pkg.utils.pkgmgr - pkg.utils.pkgmgr.ensure_dulwich() - except: - pass - - try: - import dulwich - except ModuleNotFoundError: - raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") - - -def is_newer(new_tag: str, old_tag: str): - """判断版本是否更新,忽略第四位版本和第一位版本""" - if new_tag == old_tag: - return False - - new_tag = new_tag.split(".") - old_tag = old_tag.split(".") - - # 判断主版本是否相同 - if new_tag[0] != old_tag[0]: - return False - - if len(new_tag) < 4: - return True - - # 合成前三段,判断是否相同 - new_tag = ".".join(new_tag[:3]) - old_tag = ".".join(old_tag[:3]) - - return new_tag != old_tag - - -def get_release_list() -> list: - """获取发行列表""" - rls_list_resp = requests.get( - url="https://api.github.com/repos/RockChinQ/QChatGPT/releases", - proxies=network.wrapper_proxies() - ) - - rls_list = rls_list_resp.json() - - return rls_list - - -def get_current_tag() -> str: - """获取当前tag""" - current_tag = constants.semantic_version - if os.path.exists("current_tag"): - with open("current_tag", "r") as f: - current_tag = f.read() - - return current_tag - - -def compare_version_str(v0: str, v1: str) -> int: - """比较两个版本号""" - - # 删除版本号前的v - if v0.startswith("v"): - v0 = v0[1:] - if v1.startswith("v"): - v1 = v1[1:] - - v0:list = v0.split(".") - v1:list = v1.split(".") - - # 如果两个版本号节数不同,把短的后面用0补齐 - if len(v0) < len(v1): - v0.extend(["0"]*(len(v1)-len(v0))) - elif len(v0) > len(v1): - v1.extend(["0"]*(len(v0)-len(v1))) - - # 从高位向低位比较 - for i in range(len(v0)): - if int(v0[i]) > int(v1[i]): - return 1 - elif int(v0[i]) < int(v1[i]): - return -1 - - return 0 - - -def update_all(cli: bool = False) -> bool: - """检查更新并下载源码""" - start_time = time.time() - - current_tag = get_current_tag() - old_tag = current_tag - - rls_list = get_release_list() - - latest_rls = {} - rls_notes = [] - latest_tag_name = "" - for rls in rls_list: - rls_notes.append(rls['name']) # 使用发行名称作为note - if latest_tag_name == "": - 
latest_tag_name = rls['tag_name'] - - if rls['tag_name'] == current_tag: - break - - if latest_rls == {}: - latest_rls = rls - if not cli: - logging.info("更新日志: {}".format(rls_notes)) - else: - print("更新日志: {}".format(rls_notes)) - - if latest_rls == {} and not is_newer(latest_tag_name, current_tag): # 没有新版本 - return False - - # 下载最新版本的zip到temp目录 - if not cli: - logging.info("开始下载最新版本: {}".format(latest_rls['zipball_url'])) - else: - print("开始下载最新版本: {}".format(latest_rls['zipball_url'])) - zip_url = latest_rls['zipball_url'] - zip_resp = requests.get( - url=zip_url, - proxies=network.wrapper_proxies() - ) - zip_data = zip_resp.content - - # 检查temp/updater目录 - if not os.path.exists("temp"): - os.mkdir("temp") - if not os.path.exists("temp/updater"): - os.mkdir("temp/updater") - with open("temp/updater/{}.zip".format(latest_rls['tag_name']), "wb") as f: - f.write(zip_data) - - if not cli: - logging.info("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name']))) - else: - print("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name']))) - - # 解压zip到temp/updater// - import zipfile - # 检查目标文件夹 - if os.path.exists("temp/updater/{}".format(latest_rls['tag_name'])): - import shutil - shutil.rmtree("temp/updater/{}".format(latest_rls['tag_name'])) - os.mkdir("temp/updater/{}".format(latest_rls['tag_name'])) - with zipfile.ZipFile("temp/updater/{}.zip".format(latest_rls['tag_name']), 'r') as zip_ref: - zip_ref.extractall("temp/updater/{}".format(latest_rls['tag_name'])) - - # 覆盖源码 - source_root = "" - # 找到temp/updater//中的第一个子目录路径 - for root, dirs, files in os.walk("temp/updater/{}".format(latest_rls['tag_name'])): - if root != "temp/updater/{}".format(latest_rls['tag_name']): - source_root = root - break - - # 覆盖源码 - import shutil - for root, dirs, files in os.walk(source_root): - # 覆盖所有子文件子目录 - for file in files: - src = os.path.join(root, file) - dst = src.replace(source_root, ".") - if os.path.exists(dst): - os.remove(dst) - - # 
检查目标文件夹是否存在 - if not os.path.exists(os.path.dirname(dst)): - os.makedirs(os.path.dirname(dst)) - # 检查目标文件是否存在 - if not os.path.exists(dst): - # 创建目标文件 - open(dst, "w").close() - - shutil.copy(src, dst) - - # 把current_tag写入文件 - current_tag = latest_rls['tag_name'] - with open("current_tag", "w") as f: - f.write(current_tag) - - context.get_center_v2_api().main.post_update_record( - spent_seconds=int(time.time()-start_time), - infer_reason="update", - old_version=old_tag, - new_version=current_tag, - ) - - # 通知管理员 - if not cli: - import pkg.utils.context - pkg.utils.context.get_qqbot_manager().notify_admin("已更新到最新版本: {}\n更新日志:\n{}\n完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看。\n请手动重启程序以使用新版本。".format(current_tag, "\n".join(rls_notes[:-1]))) - else: - print("已更新到最新版本: {}\n更新日志:\n{}\n完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看。请手动重启程序以使用新版本。".format(current_tag, "\n".join(rls_notes[:-1]))) - return True - - -def is_repo(path: str) -> bool: - """检查是否是git仓库""" - check_dulwich_closure() - - from dulwich import porcelain - try: - porcelain.open_repo(path) - return True - except: - return False - - -def get_remote_url(repo_path: str) -> str: - """获取远程仓库地址""" - check_dulwich_closure() - - from dulwich import porcelain - repo = porcelain.open_repo(repo_path) - return str(porcelain.get_remote_repo(repo, "origin")[1]) - - -def get_current_version_info() -> str: - """获取当前版本信息""" - rls_list = get_release_list() - current_tag = get_current_tag() - for rls in rls_list: - if rls['tag_name'] == current_tag: - return rls['name'] + "\n" + rls['body'] - return "未知版本" - - -def is_new_version_available() -> bool: - """检查是否有新版本""" - # 从github获取release列表 - rls_list = get_release_list() - if rls_list is None: - return False - - # 获取当前版本 - current_tag = get_current_tag() - - # 检查是否有新版本 - latest_tag_name = "" - for rls in rls_list: - if latest_tag_name == "": - latest_tag_name = rls['tag_name'] - break - - return is_newer(latest_tag_name, current_tag) - - -def 
get_rls_notes() -> list: - """获取更新日志""" - # 从github获取release列表 - rls_list = get_release_list() - if rls_list is None: - return None - - # 获取当前版本 - current_tag = get_current_tag() - - # 检查是否有新版本 - rls_notes = [] - for rls in rls_list: - if rls['tag_name'] == current_tag: - break - - rls_notes.append(rls['name']) - - return rls_notes - - -if __name__ == "__main__": - update_all() diff --git a/pkg/utils/version.py b/pkg/utils/version.py new file mode 100644 index 00000000..cd1eb912 --- /dev/null +++ b/pkg/utils/version.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +import os +import time + +import requests + +from ..core import app +from . import constants + + +class VersionManager: + + ap: app.Application + + def __init__( + self, + ap: app.Application + ): + self.ap = ap + + async def initialize( + self + ): + pass + + def get_current_version( + self + ) -> str: + current_tag = constants.semantic_version + + return current_tag + + async def update_all(self): + """检查更新并下载源码""" + start_time = time.time() + + current_tag = self.get_current_version() + old_tag = current_tag + + rls_list = await self.get_release_list() + + latest_rls = {} + rls_notes = [] + latest_tag_name = "" + for rls in rls_list: + rls_notes.append(rls['name']) # 使用发行名称作为note + if latest_tag_name == "": + latest_tag_name = rls['tag_name'] + + if rls['tag_name'] == current_tag: + break + + if latest_rls == {}: + latest_rls = rls + self.ap.logger.info("更新日志: {}".format(rls_notes)) + + if latest_rls == {} and not self.is_newer(latest_tag_name, current_tag): # 没有新版本 + return False + + # 下载最新版本的zip到temp目录 + self.ap.logger.info("开始下载最新版本: {}".format(latest_rls['zipball_url'])) + + zip_url = latest_rls['zipball_url'] + zip_resp = requests.get( + url=zip_url, + proxies=self.ap.proxy_mgr.get_forward_proxies() + ) + zip_data = zip_resp.content + + # 检查temp/updater目录 + if not os.path.exists("temp"): + os.mkdir("temp") + if not os.path.exists("temp/updater"): + os.mkdir("temp/updater") + with 
open("temp/updater/{}.zip".format(latest_rls['tag_name']), "wb") as f: + f.write(zip_data) + + self.ap.logger.info("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name']))) + + # 解压zip到temp/updater// + import zipfile + # 检查目标文件夹 + if os.path.exists("temp/updater/{}".format(latest_rls['tag_name'])): + import shutil + shutil.rmtree("temp/updater/{}".format(latest_rls['tag_name'])) + os.mkdir("temp/updater/{}".format(latest_rls['tag_name'])) + with zipfile.ZipFile("temp/updater/{}.zip".format(latest_rls['tag_name']), 'r') as zip_ref: + zip_ref.extractall("temp/updater/{}".format(latest_rls['tag_name'])) + + # 覆盖源码 + source_root = "" + # 找到temp/updater//中的第一个子目录路径 + for root, dirs, files in os.walk("temp/updater/{}".format(latest_rls['tag_name'])): + if root != "temp/updater/{}".format(latest_rls['tag_name']): + source_root = root + break + + # 覆盖源码 + import shutil + for root, dirs, files in os.walk(source_root): + # 覆盖所有子文件子目录 + for file in files: + src = os.path.join(root, file) + dst = src.replace(source_root, ".") + if os.path.exists(dst): + os.remove(dst) + + # 检查目标文件夹是否存在 + if not os.path.exists(os.path.dirname(dst)): + os.makedirs(os.path.dirname(dst)) + # 检查目标文件是否存在 + if not os.path.exists(dst): + # 创建目标文件 + open(dst, "w").close() + + shutil.copy(src, dst) + + # 把current_tag写入文件 + current_tag = latest_rls['tag_name'] + with open("current_tag", "w") as f: + f.write(current_tag) + + await self.ap.ctr_mgr.main.post_update_record( + spent_seconds=int(time.time()-start_time), + infer_reason="update", + old_version=old_tag, + new_version=current_tag, + ) + + async def is_new_version_available(self) -> bool: + """检查是否有新版本""" + # 从github获取release列表 + rls_list = await self.get_release_list() + if rls_list is None: + return False + + # 获取当前版本 + current_tag = self.get_current_version() + + # 检查是否有新版本 + latest_tag_name = "" + for rls in rls_list: + if latest_tag_name == "": + latest_tag_name = rls['tag_name'] + break + + return 
self.is_newer(latest_tag_name, current_tag) + + + def is_newer(self, new_tag: str, old_tag: str): + """判断版本是否更新,忽略第四位版本和第一位版本""" + if new_tag == old_tag: + return False + + new_tag = new_tag.split(".") + old_tag = old_tag.split(".") + + # 判断主版本是否相同 + if new_tag[0] != old_tag[0]: + return False + + if len(new_tag) < 4: + return True + + # 合成前三段,判断是否相同 + new_tag = ".".join(new_tag[:3]) + old_tag = ".".join(old_tag[:3]) + + return new_tag != old_tag + + + def compare_version_str(v0: str, v1: str) -> int: + """比较两个版本号""" + + # 删除版本号前的v + if v0.startswith("v"): + v0 = v0[1:] + if v1.startswith("v"): + v1 = v1[1:] + + v0:list = v0.split(".") + v1:list = v1.split(".") + + # 如果两个版本号节数不同,把短的后面用0补齐 + if len(v0) < len(v1): + v0.extend(["0"]*(len(v1)-len(v0))) + elif len(v0) > len(v1): + v1.extend(["0"]*(len(v0)-len(v1))) + + # 从高位向低位比较 + for i in range(len(v0)): + if int(v0[i]) > int(v1[i]): + return 1 + elif int(v0[i]) < int(v1[i]): + return -1 + + return 0 + + async def show_version_update( + self + ): + try: + + if await self.ap.ver_mgr.is_new_version_available(): + self.ap.logger.info("有新版本可用,请使用 !update 命令更新") + + except Exception as e: + self.ap.logger.warning(f"检查版本更新时出错: {e}") diff --git a/requirements.txt b/requirements.txt index c3e29401..9ddd53fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,16 @@ requests openai -dulwich~=0.21.6 colorlog~=6.6.0 yiri-mirai-rc +aiocqhttp +qq-botpy websockets urllib3 -func_timeout~=4.3.5 Pillow nakuru-project-idk CallingGPT tiktoken PyYaml -aiohttp \ No newline at end of file +aiohttp +pydantic +aioconsole \ No newline at end of file diff --git a/res/QChatGPT-1211.png b/res/QChatGPT-1211.png deleted file mode 100644 index 09c40c92..00000000 Binary files a/res/QChatGPT-1211.png and /dev/null differ diff --git a/res/announcement b/res/announcement deleted file mode 100644 index 9d70c654..00000000 --- a/res/announcement +++ /dev/null @@ -1 +0,0 @@ -2023/3/31 21:35 
【插件兼容性问题】若您使用了revLibs插件,并将主程序升级到了v2.3.0,请立即使用管理员账号向机器人账号发送!plugin update命令更新逆向库插件,以解决由于情景预设重构引起的兼容性问题。 diff --git a/res/docs/docker_deploy.md b/res/docs/docker_deploy.md deleted file mode 100644 index fb76fde8..00000000 --- a/res/docs/docker_deploy.md +++ /dev/null @@ -1,97 +0,0 @@ - -> [!WARNING] -> 此文档已过时,请查看[QChatGPT 容器化部署指南](docker_deployment.md) - -## 操作步骤 - -### 1.安装docker和docker compose - -[各种设备的安装Docker方法](https://yeasy.gitbook.io/docker_practice/install) - -[安装Compose方法](https://yeasy.gitbook.io/docker_practice/compose) - -> `Docker Desktop for Mac/Windows` 自带 `docker-compose` 二进制文件,安装 Docker 之后可以直接使用。 -> -> 可以选择很多下载方法,反正只要安装了就可以了 - -### 2. 登录qq(下面所有步骤建议在项目文件夹下操作) - -#### 2.1 输入指令 - -``` -docker run -d -it --name mcl --network host -v ${PWD}/qq/plugins:/app/plugins -v ${PWD}/qq/config:/app/config -v ${PWD}/qq/data:/app/data -v ${PWD}/qq/bots:/app/bots --restart unless-stopped kagurazakanyaa/mcl:latest -``` - -这里使用了[KagurazakaNyaa/mirai-console-loader-docker](https://github.com/KagurazakaNyaa/mirai-console-loader-docker)的镜像 - -#### 2.2 进入容器 - -``` -docker ps -``` -在输出中查看容器的ID,例如: -```sh -CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES -bce1e5568f46 kagurazakanyaa/mcl "./mcl -u" 10 minutes ago Up 10 minutes 0.0.0.0:8080->8080/tcp, :::8080->8080/tcp admiring_mendeleev -``` -查看`IMAGE`名为`kagurazakanyaa/mcl`的容器的`CONTAINER ID`,在这里是`bce1e5568f46`,于是使用以下命令将其切到前台: -``` -docker attach bce1e5568f46 -``` -如需将其切到后台运行,请使用组合键`Ctrl+P+Q` - -#### 2.3 编写配置文件 - -- 在` /qq/config/net.mamoe.mirai-api-http` 文件夹中找到`setting.yml`,这是`mirai-api-http`的配置文件 - - 将这个文件的内容修改为: - -``` -adapters: - - ws -debug: true -enableVerify: true -verifyKey: yirimirai -singleMode: false -cacheSize: 4096 -adapterSettings: - ws: - host: localhost - port: 8080 - reservedSyncId: -1 -``` - -`verifyKey`要求与`bot`的`config.py`中的`verifyKey`相同 - - `port`: 8080要和2.4 config.py配置里面的端口号相同 - -#### 2.4 登录 - -#### 在mirai上登录QQ - -``` -login <机器人QQ号> <机器人QQ密码> -``` - -> 
具体见[此教程](https://yiri-mirai.wybxc.cc/tutorials/01/configuration#4-登录-qq) - -#### 配置自动登录(可选) - -当机器人账号登录成功以后,执行 - -``` -autologin add <机器人QQ号> <机器人密码> -autologin setConfig <机器人QQ号> protocol ANDROID_PAD -``` - -> 出现`无法登录`报错时候[无法登录的临时处理方案](https://mirai.mamoe.net/topic/223/无法登录的临时处理方案) - -**完成后, `Ctrl+P+Q`退出(不会关掉容器,容器还会运行)** - -### 3. 部署QChatGPT - -配置好config.py,保存到当前目录下,运行下面的 - -``` - docker run -it -d --name QChatGPT --network host -v ${PWD}/config.py:/QChatGPT/config.py -v ${PWD}/banlist.py:/QChatGPT/banlist.py -v ${PWD}/sensitive.json:/QChatGPT/sensitive.json mikumifa/qchatgpt-docker -``` - diff --git a/res/docs/docker_deployment.md b/res/docs/docker_deployment.md deleted file mode 100644 index 13420dfa..00000000 --- a/res/docs/docker_deployment.md +++ /dev/null @@ -1,64 +0,0 @@ -# QChatGPT 容器化部署指南 - -> [!WARNING] -> 请您确保您**确实**需要 Docker 部署,您**必须**具有以下能力: -> - 了解 `Docker` 和 `Docker Compose` 的使用 -> - 了解容器间网络通信配置方式 -> - 了解容器文件挂载机制 -> - 了解容器调试操作 -> - 动手能力强、资料查找能力强 -> -> 若您不完全具有以上能力,请勿使用 Docker 部署,由于误操作导致的配置不正确,我们将不会解答您的问题并不负任何责任。 -> **非常不建议**您在除 Linux 之外的系统上使用 Docker 进行部署。 - -## 概览 - -QChatGPT 主程序需要连接`QQ登录框架`以与QQ通信,您可以选择 [Mirai](https://github.com/mamoe/mirai)(还需要配置mirai-api-http,请查看此仓库README中手动部署部分) 或 [go-cqhttp](https://github.com/Mrs4s/go-cqhttp),我们仅发布 QChatGPT主程序 的镜像,您需要自行配置QQ登录框架(可以参考[README.md](https://github.com/RockChinQ/QChatGPT#-%E9%85%8D%E7%BD%AEqq%E7%99%BB%E5%BD%95%E6%A1%86%E6%9E%B6)中的教程,或自行寻找其镜像)并在 QChatGPT 的配置文件中设置连接地址。 - -> [!NOTE] -> 请先确保 Docker 和 Docker Compose 已安装 - -## 准备文件 - -> QChatGPT 目前暂不可以在没有配置模板文件的情况下自动生成文件,您需要按照以下步骤手动创建需要挂载的文件。 -> 如无特殊说明,模板文件均在此仓库中。 - -> 如果您不想挨个创建,也可以直接clone本仓库到本地,执行`python main.py`后即可自动根据模板生成所需文件。 - -现在请在一个空目录创建以下文件或目录: - -### 📄`config.py` - -复制根目录的`config-template.py`所有内容,创建`config.py`并根据其中注释进行修改。 - -### 📄`banlist.py` - -复制`res/templates/banlist-template.py`所有内容,创建`banlist.py`,这是黑名单配置文件,根据需要修改。 - -### 📄`cmdpriv.json` - -复制`res/templates/cmdpriv-template.json`所有内容,创建`cmdpriv.json`,这是各命令的权限配置文件,根据需要修改。 - -### 
📄`sensitive.json` - -复制`res/templates/sensitive-template.json`所有内容,创建`sensitive.json`,这是敏感词配置,根据需要修改。 - -### 📄`tips.py` - -复制`tips-custom-template.py`所有内容,创建`tips.py`,这是部分提示语的配置,根据需要修改。 - -## 运行 - -已预先准备好`docker-compose.yaml`,您需要根据您的网络配置进行适当修改,使容器内的 QChatGPT 程序可以正常与 Mirai 或 go-cqhttp 通信。 - -将`docker-compose.yaml`复制到本目录,根据网络环境进行配置,并执行: - -```bash -docker compose up -``` - -若无报错即配置完成,您可以Ctrl+C关闭后使用`docker compose up -d`将其置于后台运行 - -## 注意 - -- 安装的插件都会保存在`plugins`(映射到本目录`plugins`),安装插件时可能会自动安装相应的依赖,此时若`重新创建`容器,已安装的插件将被加载,但所需的增量依赖并未安装,会导致引入问题。您可以删除插件目录后重启,再次安装插件,以便程序可以自动安装插件所需依赖。 \ No newline at end of file diff --git a/res/plugin_hello_group.jpg b/res/plugin_hello_group.jpg deleted file mode 100644 index 74428ff4..00000000 Binary files a/res/plugin_hello_group.jpg and /dev/null differ diff --git a/res/plugin_hello_person.png b/res/plugin_hello_person.png deleted file mode 100644 index 2a2514ce..00000000 Binary files a/res/plugin_hello_person.png and /dev/null differ diff --git a/res/screenshots/group_gpt3.5.png b/res/screenshots/group_gpt3.5.png deleted file mode 100644 index 2b95eca0..00000000 Binary files a/res/screenshots/group_gpt3.5.png and /dev/null differ diff --git a/res/screenshots/person_gpt3.5.png b/res/screenshots/person_gpt3.5.png deleted file mode 100644 index d21d3bb4..00000000 Binary files a/res/screenshots/person_gpt3.5.png and /dev/null differ diff --git a/res/screenshots/person_newbing.png b/res/screenshots/person_newbing.png deleted file mode 100644 index f8fddd53..00000000 Binary files a/res/screenshots/person_newbing.png and /dev/null differ diff --git a/res/screenshots/webwlkr_plugin.png b/res/screenshots/webwlkr_plugin.png deleted file mode 100644 index ba1778f5..00000000 Binary files a/res/screenshots/webwlkr_plugin.png and /dev/null differ diff --git a/res/scripts/generate_cmdpriv_template.py b/res/scripts/generate_cmdpriv_template.py deleted file mode 100644 index f76f3c24..00000000 --- a/res/scripts/generate_cmdpriv_template.py +++ /dev/null 
@@ -1,17 +0,0 @@ -import pkg.qqbot.cmds.aamgr as cmdsmgr -import json - -# 执行命令模块的注册 -cmdsmgr.register_all() - -# 生成限权文件模板 -template: dict[str, int] = { - "comment": "以下为命令权限,请设置到cmdpriv.json中。关于此功能的说明,请查看:https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E5%91%BD%E4%BB%A4%E6%9D%83%E9%99%90%E6%8E%A7%E5%88%B6", -} - -for key in cmdsmgr.__command_list__: - template[key] = cmdsmgr.__command_list__[key]['privilege'] - -# 写入cmdpriv-template.json -with open('res/templates/cmdpriv-template.json', 'w') as f: - f.write(json.dumps(template, indent=4, ensure_ascii=False)) \ No newline at end of file diff --git a/res/scripts/generate_override_all.py b/res/scripts/generate_override_all.py deleted file mode 100644 index 69674c38..00000000 --- a/res/scripts/generate_override_all.py +++ /dev/null @@ -1,23 +0,0 @@ -# 使用config-template生成override.json的字段全集模板文件override-all.json -# 关于override.json机制,请参考:https://github.com/RockChinQ/QChatGPT/pull/271 -import json -import importlib - - -template = importlib.import_module("config-template") -output_json = { - "comment": "这是override.json支持的字段全集, 关于override.json机制, 请查看https://github.com/RockChinQ/QChatGPT/pull/271" -} - - -for k, v in template.__dict__.items(): - if k.startswith("__"): - continue - # 如果是module - if type(v) == type(template): - continue - print(k, v, type(v)) - output_json[k] = v - -with open("override-all.json", "w", encoding="utf-8") as f: - json.dump(output_json, f, indent=4, ensure_ascii=False) diff --git a/res/templates/banlist-template.py b/res/templates/banlist-template.py deleted file mode 100644 index dcaf375e..00000000 --- a/res/templates/banlist-template.py +++ /dev/null @@ -1,30 +0,0 @@ -# 是否处理群聊消息 -# 为False时忽略所有群聊消息 -# 优先级高于下方禁用列表 -enable_group = True - -# 是否处理私聊消息 -# 为False时忽略所有私聊消息 -# 优先级高于下方禁用列表 -enable_private = True - -# 是否启用禁用列表 -enable = True - -# 禁用规则(黑名单) -# person为个人,其中的QQ号会被禁止与机器人进行私聊或群聊交互 -# 示例: person = [2854196310, 1234567890, 9876543210] -# group为群组,其中的群号会被禁止与机器人进行交互 
-# 示例: group = [123456789, 987654321, 1234567890] -# -# 支持正则表达式,字符串都将被识别为正则表达式,例如: -# person = [12345678, 87654321, "2854.*"] -# group = [123456789, 987654321, "1234.*"] -# 若要排除某个QQ号或群号(即允许使用),可以在前面加上"!",例如: -# person = ["!1234567890"] -# group = ["!987654321"] -# 排除规则优先级高于包含规则,即如果同时存在包含规则和排除规则,排除规则将生效,例如: -# person = ["1234.*", "!1234567890"] -# 那么1234567890将不会被禁用,而其他以1234开头的QQ号都会被禁用 -person = [2854196310] # 2854196310是Q群管家机器人的QQ号,默认屏蔽以免出现循环 -group = [204785790, 691226829] # 本项目交流群的群号,默认屏蔽,避免在交流群测试机器人 diff --git a/res/templates/cmdpriv-template.json b/res/templates/cmdpriv-template.json deleted file mode 100644 index 33603878..00000000 --- a/res/templates/cmdpriv-template.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "comment": "以下为命令权限,请设置到cmdpriv.json中。关于此功能的说明,请查看:https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E5%91%BD%E4%BB%A4%E6%9D%83%E9%99%90%E6%8E%A7%E5%88%B6", - "draw": 1, - "func": 1, - "plugin": 1, - "plugin.get": 2, - "plugin.update": 2, - "plugin.del": 2, - "plugin.off": 2, - "plugin.on": 2, - "default": 1, - "default.set": 2, - "del": 1, - "del.all": 1, - "delhst": 2, - "delhst.all": 2, - "last": 1, - "list": 1, - "next": 1, - "prompt": 1, - "resend": 1, - "reset": 1, - "cfg": 2, - "cmd": 1, - "help": 1, - "reload": 2, - "update": 2, - "usage": 1, - "version": 1 -} \ No newline at end of file diff --git a/res/webwlkr-demo.gif b/res/webwlkr-demo.gif deleted file mode 100644 index 86db2848..00000000 Binary files a/res/webwlkr-demo.gif and /dev/null differ diff --git "a/res/\345\261\217\345\271\225\346\210\252\345\233\276 2022-12-08 150511.png" "b/res/\345\261\217\345\271\225\346\210\252\345\233\276 2022-12-08 150511.png" deleted file mode 100644 index f51bd0d1..00000000 Binary files "a/res/\345\261\217\345\271\225\346\210\252\345\233\276 2022-12-08 150511.png" and /dev/null differ diff --git "a/res/\345\261\217\345\271\225\346\210\252\345\233\276 2022-12-08 150949.png" "b/res/\345\261\217\345\271\225\346\210\252\345\233\276 
2022-12-08 150949.png" deleted file mode 100644 index d4506578..00000000 Binary files "a/res/\345\261\217\345\271\225\346\210\252\345\233\276 2022-12-08 150949.png" and /dev/null differ diff --git "a/res/\345\261\217\345\271\225\346\210\252\345\233\276 2022-12-29 194948.png" "b/res/\345\261\217\345\271\225\346\210\252\345\233\276 2022-12-29 194948.png" deleted file mode 100644 index fe8b2145..00000000 Binary files "a/res/\345\261\217\345\271\225\346\210\252\345\233\276 2022-12-29 194948.png" and /dev/null differ diff --git a/templates/__init__.py b/templates/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/templates/command.json b/templates/command.json new file mode 100644 index 00000000..55360fc6 --- /dev/null +++ b/templates/command.json @@ -0,0 +1,3 @@ +{ + "privilege": {} +} \ No newline at end of file diff --git a/templates/pipeline.json b/templates/pipeline.json new file mode 100644 index 00000000..7284fdaf --- /dev/null +++ b/templates/pipeline.json @@ -0,0 +1,36 @@ +{ + "access-control":{ + "mode": "blacklist", + "blacklist": [], + "whitelist": [] + }, + "respond-rules": { + "default": { + "at": true, + "prefix": [ + "/ai", "!ai", "!ai", "ai" + ], + "regexp": [], + "random": 0.0 + } + }, + "income-msg-check": true, + "ignore-rules": { + "prefix": ["/"], + "regexp": [] + }, + "check-sensitive-words": true, + "baidu-cloud-examine": { + "enable": false, + "api-key": "", + "api-secret": "" + }, + "submit-messages-tokens": 3072, + "rate-limit": { + "strategy": "drop", + "algo": "fixwin", + "fixwin": { + "default": 60 + } + } +} \ No newline at end of file diff --git a/templates/platform.json b/templates/platform.json new file mode 100644 index 00000000..6b4de843 --- /dev/null +++ b/templates/platform.json @@ -0,0 +1,46 @@ +{ + "platform-adapters": [ + { + "adapter": "yiri-mirai", + "enable": false, + "host": "127.0.0.1", + "port": 8080, + "verifyKey": "yirimirai", + "qq": 123456789 + }, + { + "adapter": "nakuru", + "enable": false, + 
"host": "127.0.0.1", + "ws_port": 8080, + "http_port": 5700, + "token": "" + }, + { + "adapter": "aiocqhttp", + "enable": false, + "host": "127.0.0.1", + "port": 8080 + }, + { + "adapter": "qq-botpy", + "enable": false, + "appid": "", + "secret": "", + "intents": [ + "public_guild_messages", + "direct_message" + ] + } + ], + "track-function-calls": true, + "quote-origin": false, + "at-sender": false, + "force-delay": [0, 0], + "long-text-process": { + "threshold": 256, + "strategy": "forward", + "font-path": "" + }, + "hide-exception-info": true +} \ No newline at end of file diff --git a/templates/plugin-settings.json b/templates/plugin-settings.json new file mode 100644 index 00000000..1d807ed1 --- /dev/null +++ b/templates/plugin-settings.json @@ -0,0 +1,3 @@ +{ + "plugins": [] +} \ No newline at end of file diff --git a/templates/provider.json b/templates/provider.json new file mode 100644 index 00000000..23360277 --- /dev/null +++ b/templates/provider.json @@ -0,0 +1,17 @@ +{ + "enable-chat": true, + "openai-config": { + "api-keys": [ + "sk-1234567890" + ], + "base_url": "https://api.openai.com/v1", + "chat-completions-params": { + "model": "gpt-3.5-turbo" + }, + "request-timeout": 120 + }, + "prompt-mode": "normal", + "prompt": { + "default": "如果用户之后想获取帮助,请你说”输入!help获取帮助“。" + } +} \ No newline at end of file diff --git a/scenario/default-template.json b/templates/scenario-template.json similarity index 100% rename from scenario/default-template.json rename to templates/scenario-template.json diff --git a/res/templates/sensitive-template.json b/templates/sensitive-words.json similarity index 100% rename from res/templates/sensitive-template.json rename to templates/sensitive-words.json diff --git a/templates/system.json b/templates/system.json new file mode 100644 index 00000000..72d29b98 --- /dev/null +++ b/templates/system.json @@ -0,0 +1,14 @@ +{ + "admin-sessions": [], + "network-proxies": { + "http": null, + "https": null + }, + "report-usage": true, + 
"logging-level": "info", + "session-concurrency": { + "default": 1 + }, + "pipeline-concurrency": 20, + "help-message": "QChatGPT - 😎高稳定性、🧩支持插件、🌏实时联网的 ChatGPT QQ 机器人🤖\n链接:https://q.rkcn.top" +} \ No newline at end of file diff --git a/tests/bs_test/bs_test.py b/tests/bs_test/bs_test.py deleted file mode 100644 index 8a8e7eac..00000000 --- a/tests/bs_test/bs_test.py +++ /dev/null @@ -1,42 +0,0 @@ - -import requests -from bs4 import BeautifulSoup -import os -import random -import sys - - -user_agents = [ - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36', - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0', - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:88.0) Gecko/20100101 Firefox/88.0', - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36', - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1.2 Safari/537.36', - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1 Safari/537.36', - 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0', - 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0' -] - -r = requests.get( - sys.argv[1], - headers={ - "User-Agent": random.choice(user_agents) - } -) -soup = BeautifulSoup(r.text, 'html.parser') -# print(soup.get_text()) - -raw = soup.get_text() - -import re - -# strip每一行 -# raw = '\n'.join([line.strip() for line in raw.split('\n')]) - -# # 删除所有空行或只有空格的行 -# raw = re.sub(r'\n\s*\n', '\n', raw) - - -print(raw) \ No newline at end of file diff --git 
a/tests/compatibility_tests/models_and_interfaces.py b/tests/compatibility_tests/models_and_interfaces.py deleted file mode 100644 index 1ace18d4..00000000 --- a/tests/compatibility_tests/models_and_interfaces.py +++ /dev/null @@ -1,46 +0,0 @@ -import openai -import time - -# 测试completion api -models = [ - 'gpt-3.5-turbo', - 'gpt-3.5-turbo-0301', - 'text-davinci-003', - 'text-davinci-002', - 'code-davinci-002', - 'code-cushman-001', - 'text-curie-001', - 'text-babbage-001', - 'text-ada-001', -] - -openai.api_key = "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" - -for model in models: - print('Testing model: ', model) - - # completion api - try: - response = openai.Completion.create( - model=model, - prompt="Say this is a test", - max_tokens=7, - temperature=0 - ) - print(' completion api: ', response['choices'][0]['text'].strip()) - except Exception as e: - print(' completion api err: ', e) - - # chat completion api - try: - completion = openai.ChatCompletion.create( - model="gpt-3.5-turbo", - messages=[ - {"role": "user", "content": "Hello!"} - ] - ) - print(" chat api: ",completion.choices[0].message['content'].strip()) - except Exception as e: - print(' chat api err: ', e) - - time.sleep(60) diff --git a/tests/gpt3_test.py b/tests/gpt3_test.py deleted file mode 100644 index 0b586a7a..00000000 --- a/tests/gpt3_test.py +++ /dev/null @@ -1,14 +0,0 @@ -import openai - -openai.api_key = "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" - -resp = openai.Completion.create( - prompt="user:你好,今天天气怎么样?\nbot:", - model="text-davinci-003", - temperature=0.9, # 数值越低得到的回答越理性,取值范围[0, 1] - top_p=1, # 生成的文本的文本与要求的符合度, 取值范围[0, 1] - frequency_penalty=0.2, - presence_penalty=1.0, -) - -print(resp) \ No newline at end of file diff --git a/tests/identifier_test/host_identifier.py b/tests/identifier_test/host_identifier.py deleted file mode 100644 index 64834931..00000000 --- a/tests/identifier_test/host_identifier.py +++ /dev/null @@ -1,43 +0,0 @@ -import os -import uuid 
-import json - -# 向 ~/.qchatgpt 写入一个 标识符 - -if not os.path.exists(os.path.expanduser('~/.qchatgpt')): - os.mkdir(os.path.expanduser('~/.qchatgpt')) - -identifier = { - "host_id": "host_"+str(uuid.uuid4()), -} - -if not os.path.exists(os.path.expanduser('~/.qchatgpt/host.json')): - print('create ~/.qchatgpt/host.json') - with open(os.path.expanduser('~/.qchatgpt/host.json'), 'w') as f: - json.dump(identifier, f) -else: - print('load ~/.qchatgpt/host.json') - with open(os.path.expanduser('~/.qchatgpt/host.json'), 'r') as f: - identifier = json.load(f) - -print(identifier) - -instance_id = { - "host_id": identifier['host_id'], - "instance_id": "instance_"+str(uuid.uuid4()), -} - -# 实例 id -if os.path.exists("res/instance_id.json"): - with open("res/instance_id.json", 'r') as f: - instance_id = json.load(f) - - if instance_id['host_id'] != identifier['host_id']: - os.remove("res/instance_id.json") - -if not os.path.exists("res/instance_id.json"): - print('create res/instance_id.json') - with open("res/instance_id.json", 'w') as f: - json.dump(instance_id, f) - -print(instance_id) \ No newline at end of file diff --git a/tests/plugin_examples/__init__.py b/tests/plugin_examples/__init__.py deleted file mode 100644 index b063f0ca..00000000 --- a/tests/plugin_examples/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# 插件示例 -# 将此目录下的目录放入plugins目录即可使用 -# 每个示例插件的功能请查看其包内的__init__.py或README.md diff --git a/tests/plugin_examples/auto_approval/main.py b/tests/plugin_examples/auto_approval/main.py deleted file mode 100644 index ba9e924e..00000000 --- a/tests/plugin_examples/auto_approval/main.py +++ /dev/null @@ -1,44 +0,0 @@ -from mirai import Mirai - -import pkg.qqbot.manager -from pkg.plugin.models import * -from pkg.plugin.host import PluginHost - -from mirai.models import MemberJoinRequestEvent - -""" -加群自动审批 -""" - -__group_id__ = 1025599757 -__application_contains__ = ['github', 'gitee', 'Github', 'Gitee', 'GitHub'] - - -# 注册插件 -@register(name="加群审批", description="自动审批加群申请", 
version="0.1", author="RockChinQ") -class AutoApproval(Plugin): - - bot: Mirai = None - - # 插件加载时触发 - def __init__(self, plugin_host: PluginHost): - qqmgr = plugin_host.get_runtime_context().get_qqbot_manager() - assert isinstance(qqmgr, pkg.qqbot.manager.QQBotManager) - self.bot = qqmgr.bot - - # 向YiriMirai注册 加群申请 事件处理函数 - @qqmgr.bot.on(MemberJoinRequestEvent) - async def process(event: MemberJoinRequestEvent): - assert isinstance(qqmgr, pkg.qqbot.manager.QQBotManager) - if event.group_id == __group_id__: - if any([x in event.message for x in __application_contains__]): - logging.info("自动同意加群申请") - await qqmgr.bot.allow(event) - - self.process = process - - # 插件卸载时触发 - def __del__(self): - # 关闭时向YiriMirai注销 加群申请 事件处理函数 - if self.bot is not None: - self.bot.bus.unsubscribe(MemberJoinRequestEvent, self.process) diff --git a/tests/plugin_examples/cmdcn/cmdcn.py b/tests/plugin_examples/cmdcn/cmdcn.py deleted file mode 100644 index 788b0935..00000000 --- a/tests/plugin_examples/cmdcn/cmdcn.py +++ /dev/null @@ -1,51 +0,0 @@ -from pkg.plugin.models import * -from pkg.plugin.host import EventContext, PluginHost - -""" -基本命令的中文形式支持 -""" - - -__mapping__ = { - "帮助": "help", - "重置": "reset", - "前一次": "last", - "后一次": "next", - "会话内容": "prompt", - "列出会话": "list", - "重新回答": "resend", - "使用量": "usage", - "绘画": "draw", - "版本": "version", - "热重载": "reload", - "热更新": "update", - "配置": "cfg", -} - - -@register(name="CmdCN", description="命令中文支持", version="0.1", author="RockChinQ") -class CmdCnPlugin(Plugin): - - def __init__(self, plugin_host: PluginHost): - pass - - # 私聊发送指令 - @on(PersonCommandSent) - def person_command_sent(self, event: EventContext, **kwargs): - cmd = kwargs['command'] - if cmd in __mapping__: - - # 返回替换后的指令 - event.add_return("alter", "!"+__mapping__[cmd]+" "+" ".join(kwargs['params'])) - - # 群聊发送指令 - @on(GroupCommandSent) - def group_command_sent(self, event: EventContext, **kwargs): - cmd = kwargs['command'] - if cmd in __mapping__: - - # 返回替换后的指令 - 
event.add_return("alter", "!"+__mapping__[cmd]+" "+" ".join(kwargs['params'])) - - def __del__(self): - pass diff --git a/tests/plugin_examples/hello_plugin/main.py b/tests/plugin_examples/hello_plugin/main.py deleted file mode 100644 index 3a5ba8bb..00000000 --- a/tests/plugin_examples/hello_plugin/main.py +++ /dev/null @@ -1,50 +0,0 @@ -from pkg.plugin.models import * -from pkg.plugin.host import EventContext, PluginHost - -""" -在收到私聊或群聊消息"hello"时,回复"hello, <发送者id>!"或"hello, everyone!" -""" - - -# 注册插件 -@register(name="Hello", description="hello world", version="0.1", author="RockChinQ") -class HelloPlugin(Plugin): - - # 插件加载时触发 - # plugin_host (pkg.plugin.host.PluginHost) 提供了与主程序交互的一些方法,详细请查看其源码 - def __init__(self, plugin_host: PluginHost): - pass - - # 当收到个人消息时触发 - @on(PersonNormalMessageReceived) - def person_normal_message_received(self, event: EventContext, **kwargs): - msg = kwargs['text_message'] - if msg == "hello": # 如果消息为hello - - # 输出调试信息 - logging.debug("hello, {}".format(kwargs['sender_id'])) - - # 回复消息 "hello, <发送者id>!" - event.add_return("reply", ["hello, {}!".format(kwargs['sender_id'])]) - - # 阻止该事件默认行为(向接口获取回复) - event.prevent_default() - - # 当收到群消息时触发 - @on(GroupNormalMessageReceived) - def group_normal_message_received(self, event: EventContext, **kwargs): - msg = kwargs['text_message'] - if msg == "hello": # 如果消息为hello - - # 输出调试信息 - logging.debug("hello, {}".format(kwargs['sender_id'])) - - # 回复消息 "hello, everyone!" 
- event.add_return("reply", ["hello, everyone!"]) - - # 阻止该事件默认行为(向接口获取回复) - event.prevent_default() - - # 插件卸载时触发 - def __del__(self): - pass diff --git a/tests/plugin_examples/urlikethisijustsix/urlt.py b/tests/plugin_examples/urlikethisijustsix/urlt.py deleted file mode 100644 index 8bd1c2f5..00000000 --- a/tests/plugin_examples/urlikethisijustsix/urlt.py +++ /dev/null @@ -1,44 +0,0 @@ -import random - -from mirai import Plain - -from pkg.plugin.models import * -from pkg.plugin.host import EventContext, PluginHost - -""" -私聊或群聊消息为以下列出的一些冒犯性词语时,自动回复__random_reply__中的一句话 -""" - - -__words__ = ['sb', "傻逼", "dinner", "操你妈", "cnm", "fuck you", "fuckyou", - "f*ck you", "弱智", "若智", "答辩", "依托答辩", "低能儿", "nt", "脑瘫", "闹谈", "老坛"] - -__random_reply__ = ['好好好', "啊对对对", "好好好好", "你说得对", "谢谢夸奖"] - - -@register(name="啊对对对", description="你都这样了,我就顺从你吧", version="0.1", author="RockChinQ") -class AdddPlugin(Plugin): - - def __init__(self, plugin_host: PluginHost): - pass - - # 绑定私聊消息事件和群消息事件 - @on(PersonNormalMessageReceived) - @on(GroupNormalMessageReceived) - def normal_message_received(self, event: EventContext, **kwargs): - msg = kwargs['text_message'] - - # 如果消息中包含关键词 - if msg in __words__: - # 随机一个回复 - idx = random.randint(0, len(__random_reply__)-1) - - # 返回回复的消息 - event.add_return("reply", [Plain(__random_reply__[idx])]) - - # 阻止向接口获取回复 - event.prevent_default() - event.prevent_postorder() - - def __del__(self): - pass diff --git a/tests/proxy_test/forward_proxy_test.py b/tests/proxy_test/forward_proxy_test.py deleted file mode 100644 index dbe5399f..00000000 --- a/tests/proxy_test/forward_proxy_test.py +++ /dev/null @@ -1,24 +0,0 @@ -import os - -import openai - -client = openai.Client( - api_key=os.environ["OPENAI_API_KEY"], -) - -openai.proxies = { - 'http': 'http://127.0.0.1:7890', - 'https': 'http://127.0.0.1:7890', -} - -resp = client.chat.completions.create( - model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": "Hello, how are you?", - } - ] -) - 
-print(resp) \ No newline at end of file diff --git a/tests/repo_regexp_test.py b/tests/repo_regexp_test.py deleted file mode 100644 index 5bf78f9d..00000000 --- a/tests/repo_regexp_test.py +++ /dev/null @@ -1,7 +0,0 @@ -import re - -repo_url = "git@github.com:RockChinQ/WebwlkrPlugin.git" - -repo = re.findall(r'(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)', repo_url) - -print(repo) \ No newline at end of file diff --git a/tests/ssh_client_test/ssh_client.py b/tests/ssh_client_test/ssh_client.py deleted file mode 100644 index a8054a9b..00000000 --- a/tests/ssh_client_test/ssh_client.py +++ /dev/null @@ -1,57 +0,0 @@ -import os -import sys -import paramiko -import time -import select - - -class sshClient: - #创建一个ssh客户端,和服务器连接上,准备发消息 - def __init__(self,host,port,user,password): - self.trans = paramiko.Transport((host, port)) - self.trans.start_client() - self.trans.auth_password(username=user, password=password) - self.channel = self.trans.open_session() - self.channel.get_pty() - self.channel.invoke_shell() - - #给服务器发送一个命令 - def sendCmd(self,cmd): - self.channel.sendall(cmd) - - #接收的时候,有时候服务器处理的比较慢,需要设置一个延时等待一下。 - def recvResponse(self,timeout): - data=b'' - while True: - try: - #使用select,不断的读取数据,直到没有多余的数据了,超时返回。 - readable,w,e= select.select([self.channel],[],[],timeout) - if self.channel in readable: - data = self.channel.recv(1024) - else: - sys.stdout.write(data.decode()) - sys.stdout.flush() - return data.decode() - except TimeoutError: - sys.stdout.write(data.decode()) - sys.stdout.flush() - return data.decode - #关闭客户端 - def close(self): - self.channel.close() - self.trans.close() - -host='host' -port=22#your port -user='root' -pwd='pass' - -ssh = sshClient(host,port,user,pwd) -response = ssh.recvResponse(1) -response = ssh.sendCmd("ls\n") -ssh.sendCmd("cd /home\n") -response = ssh.recvResponse(1) -ssh.sendCmd("ls\n") -response = ssh.recvResponse(1) - -ssh.close() diff --git a/tests/test_session_console.py 
b/tests/test_session_console.py deleted file mode 100644 index e0446ee8..00000000 --- a/tests/test_session_console.py +++ /dev/null @@ -1,16 +0,0 @@ -import config -import unittest -import pkg.openai.session -import pkg.openai.manager - - -class TestOpenAISession(unittest.TestCase): - def test_session_console(self): - interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'], config.completion_api_params) - - session = pkg.openai.session.Session('test') - print(session.append('你好')) - print("#{}#".format(session.prompt)) - - print(session.append('你叫什么名字')) - print("#{}#".format(session.prompt)) diff --git a/tests/token_test/tiktoken_test.py b/tests/token_test/tiktoken_test.py deleted file mode 100644 index c66de117..00000000 --- a/tests/token_test/tiktoken_test.py +++ /dev/null @@ -1,124 +0,0 @@ -import tiktoken -import openai -import json -import os - - -openai.api_key = os.getenv("OPENAI_API_KEY") - - -def encode(text: str, model: str): - import tiktoken - enc = tiktoken.get_encoding("cl100k_base") - assert enc.decode(enc.encode("hello world")) == "hello world" - - # To get the tokeniser corresponding to a specific model in the OpenAI API: - enc = tiktoken.encoding_for_model(model) - - return enc.encode(text) - - -# def ask(prompt: str, model: str = "gpt-3.5-turbo"): -# # To get the tokeniser corresponding to a specific model in the OpenAI API: -# enc = tiktoken.encoding_for_model(model) - -# resp = openai.ChatCompletion.create( -# model=model, -# messages=[ -# { -# "role": "user", -# "content": prompt -# } -# ] -# ) - -# return enc.encode(prompt), enc.encode(resp['choices'][0]['message']['content']), resp - -def ask( - messages: list, - model: str = "gpt-3.5-turbo" -): - enc = tiktoken.encoding_for_model(model) - - resp = openai.ChatCompletion.create( - model=model, - messages=messages - ) - - txt = "" - - for r in messages: - txt += r['role'] + r['content'] + "\n" - - txt += "assistant: " - - return enc.encode(txt), 
enc.encode(resp['choices'][0]['message']['content']), resp - - -def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"): - """Return the number of tokens used by a list of messages.""" - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - print("Warning: model not found. Using cl100k_base encoding.") - encoding = tiktoken.get_encoding("cl100k_base") - if model in { - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k-0613", - "gpt-4-0314", - "gpt-4-32k-0314", - "gpt-4-0613", - "gpt-4-32k-0613", - }: - tokens_per_message = 3 - tokens_per_name = 1 - elif model == "gpt-3.5-turbo-0301": - tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n - tokens_per_name = -1 # if there's a name, the role is omitted - elif "gpt-3.5-turbo" in model: - print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") - return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613") - elif "gpt-4" in model: - print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") - return num_tokens_from_messages(messages, model="gpt-4-0613") - else: - raise NotImplementedError( - f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" - ) - num_tokens = 0 - for message in messages: - num_tokens += tokens_per_message - for key, value in message.items(): - num_tokens += len(encoding.encode(value)) - if key == "name": - num_tokens += tokens_per_name - num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> - return num_tokens - -messages = [ - { - "role": "user", - "content": "你叫什么名字?" - },{ - "role": "assistant", - "content": "我是AI助手,没有具体的名字。你可以叫我GPT-3。有什么可以帮到你的吗?" - },{ - "role": "user", - "content": "你是由谁开发的?" 
- },{ - "role": "assistant", - "content": "我是由OpenAI开发的,一家人工智能研究实验室。OpenAI的使命是促进人工智能的发展,使其为全人类带来积极影响。我是由OpenAI团队使用GPT-3模型训练而成的。" - },{ - "role": "user", - "content": "很高兴见到你。" - } -] - - -pro, rep, resp=ask(messages) - -print(len(pro), len(rep)) -print(resp) -print(resp['choices'][0]['message']['content']) - -print(num_tokens_from_messages(messages, model="gpt-3.5-turbo")) \ No newline at end of file diff --git a/tips-custom-template.py b/tips-custom-template.py deleted file mode 100644 index 129f957f..00000000 --- a/tips-custom-template.py +++ /dev/null @@ -1,37 +0,0 @@ -import config -# ---------------------------------------------自定义提示语--------------------------------------------- - -# 消息处理出错时向用户发送的提示信息,仅当config.py中hide_exce_info_to_user为True时生效 -# 设置为空字符串时,不发送提示信息 -alter_tip_message = '[bot]err:出错了,请稍后再试' - -# drop策略时,超过限速均值时,丢弃的对话的提示信息,仅当config.py中rate_limitation_strategy为"drop"时生效 -# 若设置为空字符串,则不发送提示信息 -rate_limit_drop_tip = "本分钟对话次数超过限速次数,此对话被丢弃" - -# 只允许同时处理一条消息时,新消息被丢弃时的提示信息 -# 当config.py中的wait_last_done为False时生效 -# 若设置为空字符串,则不发送提示信息 -message_drop_tip = "[bot]当前有一条消息正在处理,请等待处理完成" - -# 命令 !help帮助消息 -help_message = """此机器人通过调用大型语言模型生成回复,不具有情感。 -你可以用自然语言与其交流,回复的消息中[GPT]开头的为模型生成的语言,[bot]开头的为程序提示。 -欢迎到github.com/RockChinQ/QChatGPT 给个star""" - -# 私聊消息超时提示 -reply_message = "[bot]err:请求超时" -# 群聊消息超时提示 -replys_message = "[bot]err:请求超时" - -# 命令权限不足提示 -command_admin_message = "[bot]err:权限不足: " -# 命令无效提示 -command_err_message = "[bot]err:命令不存在:" - -# 会话重置提示 -command_reset_message = "[bot]会话已重置" -command_reset_name_message = "[bot]会话已重置,使用场景预设:" - -# 会话自动重置时的提示 -session_auto_reset_message = "[bot]会话token超限,已自动重置,请重新发送消息"