Skip to content

Commit

Permalink
feat: 添加对 Gitee AI 的支持
Browse files Browse the repository at this point in the history
  • Loading branch information
RockChinQ committed Nov 21, 2024
1 parent 753066c commit 875adfc
Show file tree
Hide file tree
Showing 13 changed files with 112 additions and 23 deletions.
26 changes: 26 additions & 0 deletions pkg/core/migrations/m015_gitee_ai_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

from .. import migration


@migration.migration_class("gitee-ai-config", 15)
class GiteeAIConfigMigration(migration.Migration):
"""迁移"""

async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'gitee-ai-chat-completions' not in self.ap.provider_cfg.data['requester'] or 'gitee-ai' not in self.ap.provider_cfg.data['keys']

async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['gitee-ai-chat-completions'] = {
"base-url": "https://ai.gitee.com/v1",
"args": {},
"timeout": 120
}

self.ap.provider_cfg.data['keys']['gitee-ai'] = [
"XXXXX"
]

await self.ap.provider_cfg.dump_config()
2 changes: 2 additions & 0 deletions pkg/core/stages/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config, m013_http_api_config, m014_force_delay_config
from ..migrations import m015_gitee_ai_config


@stage.stage_class("MigrationStage")
Expand All @@ -28,3 +29,4 @@ async def run(self, ap: app.Application):

if await migration_instance.need_migrate():
await migration_instance.run()
print(f'已执行迁移 {migration_instance.name}')
4 changes: 2 additions & 2 deletions pkg/provider/modelmgr/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pydantic

from . import api
from . import requester
from . import token


Expand All @@ -17,7 +17,7 @@ class LLMModelInfo(pydantic.BaseModel):

token_mgr: token.TokenManager

requester: api.LLMAPIRequester
requester: requester.LLMAPIRequester

tool_call_supported: typing.Optional[bool] = False

Expand Down
14 changes: 7 additions & 7 deletions pkg/provider/modelmgr/modelmgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import aiohttp

from . import entities
from . import entities, requester
from ...core import app

from . import token, api
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat
from . import token
from .requesters import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat, giteeaichatcmpl

FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"

Expand All @@ -18,7 +18,7 @@ class ModelManager:

model_list: list[entities.LLMModelInfo]

requesters: dict[str, api.LLMAPIRequester]
requesters: dict[str, requester.LLMAPIRequester]

token_mgrs: dict[str, token.TokenManager]

Expand All @@ -42,7 +42,7 @@ async def initialize(self):
for k, v in self.ap.provider_cfg.data['keys'].items():
self.token_mgrs[k] = token.TokenManager(k, v)

for api_cls in api.preregistered_requesters:
for api_cls in requester.preregistered_requesters:
api_inst = api_cls(self.ap)
await api_inst.initialize()
self.requesters[api_inst.name] = api_inst
Expand Down Expand Up @@ -94,15 +94,15 @@ async def initialize(self):

model_name = model.get('model_name', default_model_info.model_name)
token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
req = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)
vision_supported = model.get('vision_supported', default_model_info.vision_supported)

model_info = entities.LLMModelInfo(
name=model['name'],
model_name=model_name,
token_mgr=token_mgr,
requester=requester,
requester=req,
tool_call_supported=tool_call_supported,
vision_supported=vision_supported
)
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@

import anthropic

from .. import api, entities, errors
from .. import entities, errors, requester

from .. import api, entities, errors
from .. import entities, errors
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....utils import image


@api.requester_class("anthropic-messages")
class AnthropicMessages(api.LLMAPIRequester):
@requester.requester_class("anthropic-messages")
class AnthropicMessages(requester.LLMAPIRequester):
"""Anthropic Messages API 请求器"""

client: anthropic.AsyncAnthropic
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
import aiohttp
import async_lru

from .. import api, entities, errors
from .. import entities, errors, requester
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....utils import image


@api.requester_class("openai-chat-completions")
class OpenAIChatCompletions(api.LLMAPIRequester):
@requester.requester_class("openai-chat-completions")
class OpenAIChatCompletions(requester.LLMAPIRequester):
"""OpenAI ChatCompletion API 请求器"""

client: openai.AsyncClient
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from ....core import app

from . import chatcmpl
from .. import api, entities, errors
from .. import entities, errors, requester
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities


@api.requester_class("deepseek-chat-completions")
@requester.requester_class("deepseek-chat-completions")
class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Deepseek ChatCompletion API 请求器"""

Expand Down
53 changes: 53 additions & 0 deletions pkg/provider/modelmgr/requesters/giteeaichatcmpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

import json

import asyncio
import aiohttp
import typing

from . import chatcmpl
from .. import entities, errors, requester
from ....core import app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from .. import entities as modelmgr_entities


@requester.requester_class("gitee-ai-chat-completions")
class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Gitee AI ChatCompletions API 请求器"""

def __init__(self, ap: app.Application):
self.ap = ap
self.requester_cfg = ap.provider_cfg.data['requester']['gitee-ai-chat-completions'].copy()

async def _closure(
self,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()

args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name

if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)

if tools:
args["tools"] = tools

# gitee 不支持多模态,把content都转换成纯文字
for m in req_messages:
if 'content' in m and isinstance(m["content"], list):
m["content"] = " ".join([c["text"] for c in m["content"]])

args["messages"] = req_messages

resp = await self._req(args)

message = await self._make_msg(resp)

return message
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from ....core import app

from . import chatcmpl
from .. import api, entities, errors
from .. import entities, errors, requester
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities


@api.requester_class("moonshot-chat-completions")
@requester.requester_class("moonshot-chat-completions")
class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Moonshot ChatCompletion API 请求器"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import async_lru
import ollama

from .. import api, entities, errors
from .. import entities, errors, requester
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....core import app
Expand All @@ -17,8 +17,8 @@
REQUESTER_NAME: str = "ollama-chat"


@api.requester_class(REQUESTER_NAME)
class OllamaChatCompletions(api.LLMAPIRequester):
@requester.requester_class(REQUESTER_NAME)
class OllamaChatCompletions(requester.LLMAPIRequester):
"""Ollama平台 ChatCompletion API请求器"""
client: ollama.AsyncClient
request_cfg: dict
Expand Down
8 changes: 8 additions & 0 deletions templates/provider.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
],
"deepseek": [
"sk-1234567890"
],
"gitee-ai": [
"XXXXX"
]
},
"requester": {
Expand Down Expand Up @@ -42,6 +45,11 @@
"base-url": "http://127.0.0.1:11434",
"args": {},
"timeout": 600
},
"gitee-ai-chat-completions": {
"base-url": "https://ai.gitee.com/v1",
"args": {},
"timeout": 120
}
},
"model": "gpt-3.5-turbo",
Expand Down

0 comments on commit 875adfc

Please sign in to comment.