Add xAI model support #958

Merged
2 commits merged on Jan 4, 2025
25 changes: 25 additions & 0 deletions pkg/core/migrations/m018_xai_config.py
@@ -0,0 +1,25 @@
from __future__ import annotations

from .. import migration


@migration.migration_class("xai-config", 18)
class XaiConfigMigration(migration.Migration):
"""迁移"""

async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'xai-chat-completions' not in self.ap.provider_cfg.data['requester']

async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['xai-chat-completions'] = {
"base-url": "https://api.x.ai/v1",
"args": {},
"timeout": 120
}
self.ap.provider_cfg.data['keys']['xai'] = [
"xai-1234567890"
]

await self.ap.provider_cfg.dump_config()
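For orientation, this is how the migration is expected to be picked up: a minimal sketch, assuming the MigrationStage collects the classes registered via @migration.migration_class, orders them by their migration number (18 here), and runs each one whose need_migrate() returns True. The registry and attribute names below are illustrative assumptions, not code from this PR.

# Hypothetical driver loop, not part of this PR; the real logic lives in pkg/core/stages/migrate.py.
async def run_pending_migrations(ap):
    # 'registered_migrations' and the 'number' attribute are assumed names for whatever
    # registry @migration.migration_class("xai-config", 18) populates.
    for migration_cls in sorted(registered_migrations, key=lambda c: c.number):
        m = migration_cls(ap)
        if await m.need_migrate():  # True here until 'xai-chat-completions' exists in the requester config
            await m.run()           # writes the xAI requester block and key placeholder, then dumps the config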
2 changes: 1 addition & 1 deletion pkg/core/stages/migrate.py
@@ -7,7 +7,7 @@
from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config, m013_http_api_config, m014_force_delay_config
from ..migrations import m015_gitee_ai_config, m016_dify_service_api, m017_dify_api_timeout_params
from ..migrations import m015_gitee_ai_config, m016_dify_service_api, m017_dify_api_timeout_params, m018_xai_config


@stage.stage_class("MigrationStage")
2 changes: 1 addition & 1 deletion pkg/provider/modelmgr/modelmgr.py
@@ -6,7 +6,7 @@
from ...core import app

from . import token
from .requesters import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat, giteeaichatcmpl
from .requesters import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat, giteeaichatcmpl, xaichatcmpl

FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"

145 changes: 145 additions & 0 deletions pkg/provider/modelmgr/requesters/xaichatcmpl.py
@@ -0,0 +1,145 @@
from __future__ import annotations

import asyncio
import typing
import json
import base64
from typing import AsyncGenerator

import openai
import openai.types.chat.chat_completion as chat_completion
import httpx
import aiohttp
import async_lru

from . import chatcmpl
from .. import entities, errors, requester
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....utils import image


@requester.requester_class("xai-chat-completions")
class XaiChatCompletions(chatcmpl.OpenAIChatCompletions):
"""xAI ChatCompletion API 请求器"""

client: openai.AsyncClient

requester_cfg: dict

def __init__(self, ap: app.Application):
self.ap = ap

self.requester_cfg = self.ap.provider_cfg.data['requester']['xai-chat-completions']
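
    # Note: every override below is left commented out, so this requester inherits the
    # behaviour of chatcmpl.OpenAIChatCompletions unchanged and only swaps in the
    # 'xai-chat-completions' config section (base-url https://api.x.ai/v1). This works
    # because the xAI chat API is OpenAI-compatible; the commented code appears to be
    # kept only as a reference for future xAI-specific overrides.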

# async def initialize(self):

# self.client = openai.AsyncClient(
# api_key="",
# base_url=self.requester_cfg['base-url'],
# timeout=self.requester_cfg['timeout'],
# http_client=httpx.AsyncClient(
# proxies=self.ap.proxy_mgr.get_forward_proxies()
# )
# )

# async def _req(
# self,
# args: dict,
# ) -> chat_completion.ChatCompletion:
# return await self.client.chat.completions.create(**args)

# async def _make_msg(
# self,
# chat_completion: chat_completion.ChatCompletion,
# ) -> llm_entities.Message:
# chatcmpl_message = chat_completion.choices[0].message.dict()

# # Ensure the role field exists and is not None
# if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
# chatcmpl_message['role'] = 'assistant'

# message = llm_entities.Message(**chatcmpl_message)

# return message

# async def _closure(
# self,
# req_messages: list[dict],
# use_model: entities.LLMModelInfo,
# use_funcs: list[tools_entities.LLMFunction] = None,
# ) -> llm_entities.Message:
# self.client.api_key = use_model.token_mgr.get_token()

# args = self.requester_cfg['args'].copy()
# args["model"] = use_model.name if use_model.model_name is None else use_model.model_name

# if use_funcs:
# tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)

# if tools:
# args["tools"] = tools

# # Set the messages for this request
# messages = req_messages.copy()

# # Check for vision (image) content
# for msg in messages:
# if 'content' in msg and isinstance(msg["content"], list):
# for me in msg["content"]:
# if me["type"] == "image_url":
# me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url'])

# args["messages"] = messages

# # Send the request
# resp = await self._req(args)

# # Process the result
# message = await self._make_msg(resp)

# return message

# async def call(
# self,
# model: entities.LLMModelInfo,
# messages: typing.List[llm_entities.Message],
# funcs: typing.List[tools_entities.LLMFunction] = None,
# ) -> llm_entities.Message:
# req_messages = [] # req_messages is only used within this class; external state is kept in sync via query.messages
# for m in messages:
# msg_dict = m.dict(exclude_none=True)
# content = msg_dict.get("content")
# if isinstance(content, list):
# # Check whether every part of the content list is text
# if all(isinstance(part, dict) and part.get("type") == "text" for part in content):
# # Merge all text parts into a single string
# msg_dict["content"] = "\n".join(part["text"] for part in content)
# req_messages.append(msg_dict)

# try:
# return await self._closure(req_messages, model, funcs)
# except asyncio.TimeoutError:
# raise errors.RequesterError('Request timed out')
# except openai.BadRequestError as e:
# if 'context_length_exceeded' in e.message:
# raise errors.RequesterError(f'Context too long, please reset the session: {e.message}')
# else:
# raise errors.RequesterError(f'Invalid request parameters: {e.message}')
# except openai.AuthenticationError as e:
# raise errors.RequesterError(f'Invalid api-key: {e.message}')
# except openai.NotFoundError as e:
# raise errors.RequesterError(f'Invalid request path: {e.message}')
# except openai.RateLimitError as e:
# raise errors.RequesterError(f'Rate limited or insufficient balance: {e.message}')
# except openai.APIError as e:
# raise errors.RequesterError(f'Request error: {e.message}')

# @async_lru.alru_cache(maxsize=128)
# async def get_base64_str(
# self,
# original_url: str,
# ) -> str:
# base64_image, image_format = await image.qq_image_url_to_base64(original_url)
# return f"data:image/{image_format};base64,{base64_image}"
32 changes: 32 additions & 0 deletions templates/metadata/llm-models.json
@@ -115,6 +115,38 @@
"name": "deepseek-coder",
"requester": "deepseek-chat-completions",
"token_mgr": "deepseek"
},
{
"name": "grok-2-latest",
"requester": "xai-chat-completions",
"token_mgr": "xai"
},
{
"name": "grok-2",
"requester": "xai-chat-completions",
"token_mgr": "xai"
},
{
"name": "grok-2-vision-1212",
"requester": "xai-chat-completions",
"token_mgr": "xai",
"vision_supported": true
},
{
"name": "grok-2-1212",
"requester": "xai-chat-completions",
"token_mgr": "xai"
},
{
"name": "grok-vision-beta",
"requester": "xai-chat-completions",
"token_mgr": "xai",
"vision_supported": true
},
{
"name": "grok-beta",
"requester": "xai-chat-completions",
"token_mgr": "xai"
}
]
}
8 changes: 8 additions & 0 deletions templates/provider.json
@@ -16,6 +16,9 @@
],
"gitee-ai": [
"XXXXX"
],
"xai": [
"xai-1234567890"
]
},
"requester": {
@@ -50,6 +53,11 @@
"base-url": "https://ai.gitee.com/v1",
"args": {},
"timeout": 120
},
"xai-chat-completions": {
"base-url": "https://api.x.ai/v1",
"args": {},
"timeout": 120
}
},
"model": "gpt-4o",
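To route traffic through xAI after this change, a user fills in a real key under keys.xai and points the top-level model at one of the grok entries from llm-models.json. A sketch of the resulting provider config, written as a Python dict for brevity (the real file is JSON, and the key value is a placeholder):

# Hypothetical end state of the provider config once xAI is enabled.
provider_cfg = {
    "keys": {"xai": ["xai-your-real-key"]},  # replace the template placeholder
    "requester": {
        "xai-chat-completions": {
            "base-url": "https://api.x.ai/v1",
            "args": {},
            "timeout": 120,
        },
    },
    "model": "grok-2-latest",  # any model whose requester is "xai-chat-completions"
}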
34 changes: 28 additions & 6 deletions templates/schema/provider.json
@@ -22,7 +22,6 @@
"openai": {
"type": "array",
"title": "OpenAI API 密钥",
"description": "OpenAI API 密钥",
"items": {
"type": "string"
},
@@ -31,7 +30,6 @@
"anthropic": {
"type": "array",
"title": "Anthropic API 密钥",
"description": "Anthropic API 密钥",
"items": {
"type": "string"
},
@@ -40,7 +38,6 @@
"moonshot": {
"type": "array",
"title": "Moonshot API 密钥",
"description": "Moonshot API 密钥",
"items": {
"type": "string"
},
@@ -49,16 +46,22 @@
"deepseek": {
"type": "array",
"title": "DeepSeek API 密钥",
"description": "DeepSeek API 密钥",
"items": {
"type": "string"
},
"default": []
},
"gitee": {
"type": "array",
"title": "Gitee API 密钥",
"description": "Gitee API 密钥",
"title": "Gitee AI API 密钥",
"items": {
"type": "string"
},
"default": []
},
"xai": {
"type": "array",
"title": "xAI API 密钥",
"items": {
"type": "string"
},
@@ -188,6 +191,25 @@
"default": 120
}
}
},
"xai-chat-completions": {
"type": "object",
"title": "xAI API 请求配置",
"description": "仅可编辑 URL 和 超时时间,额外请求参数不支持可视化编辑,请到编辑器编辑",
"properties": {
"base-url": {
"type": "string",
"title": "API URL"
},
"args": {
"type": "object"
},
"timeout": {
"type": "number",
"title": "API 请求超时时间",
"default": 120
}
}
}
}
},