Skip to content

Commit

Permalink
Merge pull request #958 from zhihuanwang/master
Browse files Browse the repository at this point in the history
增加xAI模型支持
  • Loading branch information
RockChinQ authored Jan 4, 2025
2 parents d214d80 + 0a68a77 commit 5e5a363
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 8 deletions.
25 changes: 25 additions & 0 deletions pkg/core/migrations/m018_xai_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

from .. import migration


@migration.migration_class("xai-config", 18)
class XaiConfigMigration(migration.Migration):
"""迁移"""

async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'xai-chat-completions' not in self.ap.provider_cfg.data['requester']

async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['xai-chat-completions'] = {
"base-url": "https://api.x.ai/v1",
"args": {},
"timeout": 120
}
self.ap.provider_cfg.data['keys']['xai'] = [
"xai-1234567890"
]

await self.ap.provider_cfg.dump_config()
2 changes: 1 addition & 1 deletion pkg/core/stages/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config, m013_http_api_config, m014_force_delay_config
from ..migrations import m015_gitee_ai_config, m016_dify_service_api, m017_dify_api_timeout_params
from ..migrations import m015_gitee_ai_config, m016_dify_service_api, m017_dify_api_timeout_params, m018_xai_config


@stage.stage_class("MigrationStage")
Expand Down
2 changes: 1 addition & 1 deletion pkg/provider/modelmgr/modelmgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ...core import app

from . import token
from .requesters import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat, giteeaichatcmpl
from .requesters import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat, giteeaichatcmpl, xaichatcmpl

FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"

Expand Down
145 changes: 145 additions & 0 deletions pkg/provider/modelmgr/requesters/xaichatcmpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import annotations

import asyncio
import typing
import json
import base64
from typing import AsyncGenerator

import openai
import openai.types.chat.chat_completion as chat_completion
import httpx
import aiohttp
import async_lru

from . import chatcmpl
from .. import entities, errors, requester
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....utils import image


@requester.requester_class("xai-chat-completions")
class XaiChatCompletions(chatcmpl.OpenAIChatCompletions):
"""xAI ChatCompletion API 请求器"""

client: openai.AsyncClient

requester_cfg: dict

def __init__(self, ap: app.Application):
self.ap = ap

self.requester_cfg = self.ap.provider_cfg.data['requester']['xai-chat-completions']

# async def initialize(self):

# self.client = openai.AsyncClient(
# api_key="",
# base_url=self.requester_cfg['base-url'],
# timeout=self.requester_cfg['timeout'],
# http_client=httpx.AsyncClient(
# proxies=self.ap.proxy_mgr.get_forward_proxies()
# )
# )

# async def _req(
# self,
# args: dict,
# ) -> chat_completion.ChatCompletion:
# return await self.client.chat.completions.create(**args)

# async def _make_msg(
# self,
# chat_completion: chat_completion.ChatCompletion,
# ) -> llm_entities.Message:
# chatcmpl_message = chat_completion.choices[0].message.dict()

# # 确保 role 字段存在且不为 None
# if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
# chatcmpl_message['role'] = 'assistant'

# message = llm_entities.Message(**chatcmpl_message)

# return message

# async def _closure(
# self,
# req_messages: list[dict],
# use_model: entities.LLMModelInfo,
# use_funcs: list[tools_entities.LLMFunction] = None,
# ) -> llm_entities.Message:
# self.client.api_key = use_model.token_mgr.get_token()

# args = self.requester_cfg['args'].copy()
# args["model"] = use_model.name if use_model.model_name is None else use_model.model_name

# if use_funcs:
# tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)

# if tools:
# args["tools"] = tools

# # 设置此次请求中的messages
# messages = req_messages.copy()

# # 检查vision
# for msg in messages:
# if 'content' in msg and isinstance(msg["content"], list):
# for me in msg["content"]:
# if me["type"] == "image_url":
# me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url'])

# args["messages"] = messages

# # 发送请求
# resp = await self._req(args)

# # 处理请求结果
# message = await self._make_msg(resp)

# return message

# async def call(
# self,
# model: entities.LLMModelInfo,
# messages: typing.List[llm_entities.Message],
# funcs: typing.List[tools_entities.LLMFunction] = None,
# ) -> llm_entities.Message:
# req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
# for m in messages:
# msg_dict = m.dict(exclude_none=True)
# content = msg_dict.get("content")
# if isinstance(content, list):
# # 检查 content 列表中是否每个部分都是文本
# if all(isinstance(part, dict) and part.get("type") == "text" for part in content):
# # 将所有文本部分合并为一个字符串
# msg_dict["content"] = "\n".join(part["text"] for part in content)
# req_messages.append(msg_dict)

# try:
# return await self._closure(req_messages, model, funcs)
# except asyncio.TimeoutError:
# raise errors.RequesterError('请求超时')
# except openai.BadRequestError as e:
# if 'context_length_exceeded' in e.message:
# raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
# else:
# raise errors.RequesterError(f'请求参数错误: {e.message}')
# except openai.AuthenticationError as e:
# raise errors.RequesterError(f'无效的 api-key: {e.message}')
# except openai.NotFoundError as e:
# raise errors.RequesterError(f'请求路径错误: {e.message}')
# except openai.RateLimitError as e:
# raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
# except openai.APIError as e:
# raise errors.RequesterError(f'请求错误: {e.message}')

# @async_lru.alru_cache(maxsize=128)
# async def get_base64_str(
# self,
# original_url: str,
# ) -> str:
# base64_image, image_format = await image.qq_image_url_to_base64(original_url)
# return f"data:image/{image_format};base64,{base64_image}"
32 changes: 32 additions & 0 deletions templates/metadata/llm-models.json
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,38 @@
"name": "deepseek-coder",
"requester": "deepseek-chat-completions",
"token_mgr": "deepseek"
},
{
"name": "grok-2-latest",
"requester": "xai-chat-completions",
"token_mgr": "xai"
},
{
"name": "grok-2",
"requester": "xai-chat-completions",
"token_mgr": "xai"
},
{
"name": "grok-2-vision-1212",
"requester": "xai-chat-completions",
"token_mgr": "xai",
"vision_supported": true
},
{
"name": "grok-2-1212",
"requester": "xai-chat-completions",
"token_mgr": "xai"
},
{
"name": "grok-vision-beta",
"requester": "xai-chat-completions",
"token_mgr": "xai",
"vision_supported": true
},
{
"name": "grok-beta",
"requester": "xai-chat-completions",
"token_mgr": "xai"
}
]
}
8 changes: 8 additions & 0 deletions templates/provider.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
],
"gitee-ai": [
"XXXXX"
],
"xai": [
"xai-1234567890"
]
},
"requester": {
Expand Down Expand Up @@ -50,6 +53,11 @@
"base-url": "https://ai.gitee.com/v1",
"args": {},
"timeout": 120
},
"xai-chat-completions": {
"base-url": "https://api.x.ai/v1",
"args": {},
"timeout": 120
}
},
"model": "gpt-4o",
Expand Down
34 changes: 28 additions & 6 deletions templates/schema/provider.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
"openai": {
"type": "array",
"title": "OpenAI API 密钥",
"description": "OpenAI API 密钥",
"items": {
"type": "string"
},
Expand All @@ -31,7 +30,6 @@
"anthropic": {
"type": "array",
"title": "Anthropic API 密钥",
"description": "Anthropic API 密钥",
"items": {
"type": "string"
},
Expand All @@ -40,7 +38,6 @@
"moonshot": {
"type": "array",
"title": "Moonshot API 密钥",
"description": "Moonshot API 密钥",
"items": {
"type": "string"
},
Expand All @@ -49,16 +46,22 @@
"deepseek": {
"type": "array",
"title": "DeepSeek API 密钥",
"description": "DeepSeek API 密钥",
"items": {
"type": "string"
},
"default": []
},
"gitee": {
"type": "array",
"title": "Gitee API 密钥",
"description": "Gitee API 密钥",
"title": "Gitee AI API 密钥",
"items": {
"type": "string"
},
"default": []
},
"xai": {
"type": "array",
"title": "xAI API 密钥",
"items": {
"type": "string"
},
Expand Down Expand Up @@ -188,6 +191,25 @@
"default": 120
}
}
},
"xai-chat-completions": {
"type": "object",
"title": "xAI API 请求配置",
"description": "仅可编辑 URL 和 超时时间,额外请求参数不支持可视化编辑,请到编辑器编辑",
"properties": {
"base-url": {
"type": "string",
"title": "API URL"
},
"args": {
"type": "object"
},
"timeout": {
"type": "number",
"title": "API 请求超时时间",
"default": 120
}
}
}
}
},
Expand Down

0 comments on commit 5e5a363

Please sign in to comment.