Feat: Adapt to openai>=1.0.0 #587

Merged · 6 commits · Nov 10, 2023
2 changes: 1 addition & 1 deletion config-template.py
@@ -248,7 +248,7 @@
trace_function_calls = False

# Whether to quote the original message when replying in a group chat
quote_origin = True
quote_origin = False

# Whether to @ the sender when replying in a group chat
at_sender = False
9 changes: 6 additions & 3 deletions main.py
@@ -191,13 +191,16 @@ def start(first_time_init=False):

# Configure the OpenAI proxy
import openai
openai.proxy = None # Reset first, since the proxy may need to be cleared after a reload
openai.proxies = None # Reset first, since the proxy may need to be cleared after a reload
if "http_proxy" in config.openai_config and config.openai_config["http_proxy"] is not None:
openai.proxy = config.openai_config["http_proxy"]
openai.proxies = {
"http": config.openai_config["http_proxy"],
"https": config.openai_config["http_proxy"]
}

# Configure the openai api_base
if "reverse_proxy" in config.openai_config and config.openai_config["reverse_proxy"] is not None:
openai.api_base = config.openai_config["reverse_proxy"]
openai.base_url = config.openai_config["reverse_proxy"]

# Main startup flow
database = pkg.database.manager.DatabaseManager()
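
For context on the change above: openai>=1.0.0 also lets the proxy and base URL be passed to a dedicated client instance instead of module-level attributes. The sketch below is not part of this PR; the proxy and base URL values are placeholders, and the httpx proxy syntax depends on the installed httpx version.

```python
# Minimal sketch (an assumption, not part of this PR): configure the proxy and
# reverse-proxy base URL on an openai>=1.0.0 client instead of via the
# module-level openai.proxies / openai.base_url attributes used above.
import httpx
import openai

http_proxy = "http://127.0.0.1:7890"      # placeholder proxy address
reverse_proxy = "https://example.com/v1"  # placeholder reverse proxy

client = openai.OpenAI(
    api_key="sk-...",
    base_url=reverse_proxy,
    # httpx routes both plain-HTTP and HTTPS requests through the same proxy;
    # newer httpx releases spell this `proxy=` instead of `proxies=`.
    http_client=httpx.Client(proxies={"http://": http_proxy, "https://": http_proxy}),
)
```
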
2 changes: 1 addition & 1 deletion override-all.json
@@ -63,7 +63,7 @@
"size": "256x256"
},
"trace_function_calls": false,
"quote_origin": true,
"quote_origin": false,
"at_sender": false,
"include_image_description": true,
"process_message_timeout": 120,
61 changes: 39 additions & 22 deletions pkg/openai/api/chat_completion.py
@@ -1,4 +1,5 @@
import openai
from openai.types.chat import chat_completion_message
import json
import logging

@@ -13,13 +14,14 @@ class ChatCompletionRequest(RequestBase):
This class guarantees that every returned message with role assistant has finish_reason == stop.
If there is a function-call response, this class returns, in order: function-call request -> function-call result -> ... -> assistant message -> stop.
"""

model: str
messages: list[dict[str, str]]
kwargs: dict

stopped: bool = False

pending_func_call: dict = None
pending_func_call: chat_completion_message.FunctionCall = None

pending_msg: str

@@ -46,16 +48,18 @@ def append_message(self, role: str, content: str, name: str=None, function_call:

def __init__(
self,
client: openai.Client,
model: str,
messages: list[dict[str, str]],
**kwargs
):
self.client = client
self.model = model
self.messages = messages.copy()

self.kwargs = kwargs

self.req_func = openai.ChatCompletion.acreate
self.req_func = self.client.chat.completions.create

self.pending_func_call = None

@@ -84,80 +88,93 @@ def __next__(self) -> dict:

# Merge in kwargs
args = {**args, **self.kwargs}

from openai.types.chat import chat_completion

resp = self._req(**args)
resp: chat_completion.ChatCompletion = self._req(**args)

choice0 = resp["choices"][0]
choice0 = resp.choices[0]

# If this is not a function call and finish_reason is stop, stop iterating
if choice0['finish_reason'] == 'stop': # and choice0["finish_reason"] == "stop"
if choice0.finish_reason == 'stop': # and choice0["finish_reason"] == "stop"
self.stopped = True

if 'function_call' in choice0['message']:
self.pending_func_call = choice0['message']['function_call']
if hasattr(choice0.message, 'function_call') and choice0.message.function_call is not None:
self.pending_func_call = choice0.message.function_call

self.append_message(
role="assistant",
content=choice0['message']['content'],
function_call=choice0['message']['function_call']
content=choice0.message.content,
function_call=choice0.message.function_call
)

return {
"id": resp["id"],
"id": resp.id,
"choices": [
{
"index": choice0["index"],
"index": choice0.index,
"message": {
"role": "assistant",
"type": "function_call",
"content": choice0['message']['content'],
"function_call": choice0['message']['function_call']
"content": choice0.message.content,
"function_call": {
"name": choice0.message.function_call.name,
"arguments": choice0.message.function_call.arguments
}
},
"finish_reason": "function_call"
}
],
"usage": resp["usage"]
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
}
else:

# self.pending_msg += choice0['message']['content']
# A normal reply is always the last message, so it does not need to be appended to the internal messages again

return {
"id": resp["id"],
"id": resp.id,
"choices": [
{
"index": choice0["index"],
"index": choice0.index,
"message": {
"role": "assistant",
"type": "text",
"content": choice0['message']['content']
"content": choice0.message.content
},
"finish_reason": choice0["finish_reason"]
"finish_reason": choice0.finish_reason
}
],
"usage": resp["usage"]
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
}
else: # handle the pending function-call request

cp_pending_func_call = self.pending_func_call.copy()

self.pending_func_call = None

func_name = cp_pending_func_call['name']
func_name = cp_pending_func_call.name
arguments = {}

try:

try:
arguments = json.loads(cp_pending_func_call['arguments'])
arguments = json.loads(cp_pending_func_call.arguments)
# Fallback for arguments that are not valid JSON
except json.decoder.JSONDecodeError:
# Get the function's parameter list
func_schema = get_func_schema(func_name)

arguments = {
func_schema['parameters']['required'][0]: cp_pending_func_call['arguments']
func_schema['parameters']['required'][0]: cp_pending_func_call.arguments
}

logging.info("执行函数调用: name={}, arguments={}".format(func_name, arguments))
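
The hunks above follow one pattern throughout: openai>=1.0.0 returns typed objects, so fields that used to be read with dict indexing are now read as attributes. A minimal sketch of that access pattern, assuming a configured client and a valid key (the model name and key are placeholders):

```python
# Minimal sketch of the attribute-style access used in the diff above.
import openai

client = openai.OpenAI(api_key="sk-...")  # placeholder key

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
)

choice0 = resp.choices[0]
print(choice0.message.content)    # was resp["choices"][0]["message"]["content"]
print(choice0.finish_reason)      # was choice0["finish_reason"]
print(resp.usage.total_tokens)    # was resp["usage"]["total_tokens"]

# function_call is None unless the model actually requested a function call
if choice0.message.function_call is not None:
    print(choice0.message.function_call.name)
    print(choice0.message.function_call.arguments)  # still a JSON string
```
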
45 changes: 17 additions & 28 deletions pkg/openai/api/completion.py
@@ -1,4 +1,5 @@
import openai
from openai.types import completion, completion_choice

from .model import RequestBase

@@ -17,10 +18,12 @@ class CompletionRequest(RequestBase):

def __init__(
self,
client: openai.Client,
model: str,
messages: list[dict[str, str]],
**kwargs
):
self.client = client
self.model = model
self.prompt = ""

@@ -31,7 +34,7 @@ def __init__(

self.kwargs = kwargs

self.req_func = openai.Completion.acreate
self.req_func = self.client.completions.create

def __iter__(self):
return self
@@ -63,49 +66,35 @@ def __next__(self) -> dict:
if self.stopped:
raise StopIteration()

resp = self._req(
resp: completion.Completion = self._req(
model=self.model,
prompt=self.prompt,
**self.kwargs
)

if resp["choices"][0]["finish_reason"] == "stop":
if resp.choices[0].finish_reason == "stop":
self.stopped = True

choice0 = resp["choices"][0]
choice0: completion_choice.CompletionChoice = resp.choices[0]

self.prompt += choice0["text"]
self.prompt += choice0.text

return {
"id": resp["id"],
"id": resp.id,
"choices": [
{
"index": choice0["index"],
"index": choice0.index,
"message": {
"role": "assistant",
"type": "text",
"content": choice0["text"]
"content": choice0.text
},
"finish_reason": choice0["finish_reason"]
"finish_reason": choice0.finish_reason
}
],
"usage": resp["usage"]
}

if __name__ == "__main__":
import os

openai.api_key = os.environ["OPENAI_API_KEY"]

for resp in CompletionRequest(
model="text-davinci-003",
messages=[
{
"role": "user",
"content": "Hello, who are you?"
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
]
):
print(resp)
if resp["choices"][0]["finish_reason"] == "stop":
break
}
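
The legacy completions endpoint changes the same way, and the old `__main__` demo block above is dropped. A minimal sketch of the same call under openai>=1.0.0 (the model name is illustrative; text-davinci-003 has since been retired by OpenAI):

```python
# Minimal sketch of the legacy completions endpoint with openai>=1.0.0.
import openai

client = openai.OpenAI(api_key="sk-...")  # placeholder key

resp = client.completions.create(
    model="gpt-3.5-turbo-instruct",  # illustrative completions-capable model
    prompt="Hello, who are you?",
)

print(resp.choices[0].text, resp.choices[0].finish_reason)
print(resp.usage.prompt_tokens, resp.usage.completion_tokens, resp.usage.total_tokens)
```
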
36 changes: 7 additions & 29 deletions pkg/openai/api/model.py
@@ -8,6 +8,8 @@

class RequestBase:

client: openai.Client

req_func: callable

def __init__(self, *args, **kwargs):
@@ -17,41 +19,17 @@ def _next_key(self):
import pkg.utils.context as context
switched, name = context.get_openai_manager().key_mgr.auto_switch()
logging.debug("切换api-key: switched={}, name={}".format(switched, name))
openai.api_key = context.get_openai_manager().key_mgr.get_using_key()
self.client.api_key = context.get_openai_manager().key_mgr.get_using_key()

def _req(self, **kwargs):
"""处理代理问题"""
import config

ret: dict = {}
exception: Exception = None

async def awrapper(**kwargs):
nonlocal ret, exception

try:
ret = await self.req_func(**kwargs)
logging.debug("接口请求返回:%s", str(ret))

if config.switch_strategy == 'active':
self._next_key()

return ret
except Exception as e:
exception = e

loop = asyncio.new_event_loop()

thr = threading.Thread(
target=loop.run_until_complete,
args=(awrapper(**kwargs),)
)

thr.start()
thr.join()
ret = self.req_func(**kwargs)
logging.debug("接口请求返回:%s", str(ret))

if exception is not None:
raise exception
if config.switch_strategy == 'active':
self._next_key()

return ret

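The removed thread/event-loop wrapper was only needed because the 0.x `*.acreate` functions were coroutines. In openai>=1.0.0, `client.*.create` is a plain synchronous call, and the SDK provides `AsyncOpenAI` for callers that still want async. A minimal sketch of both paths (model name and keys are placeholders, not part of this PR):

```python
# Minimal sketch: synchronous call with openai>=1.0.0, plus the async variant
# the SDK offers if coroutine-based code is still desired.
import asyncio
import openai

sync_client = openai.OpenAI(api_key="sk-...")        # placeholder key
async_client = openai.AsyncOpenAI(api_key="sk-...")  # placeholder key

def sync_req(**kwargs):
    # Direct call; no event loop or worker thread is needed anymore.
    return sync_client.chat.completions.create(**kwargs)

async def async_req(**kwargs):
    # Equivalent coroutine-based call via AsyncOpenAI.
    return await async_client.chat.completions.create(**kwargs)

if __name__ == "__main__":
    args = dict(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hi"}])
    print(sync_req(**args).choices[0].message.content)
    print(asyncio.run(async_req(**args)).choices[0].message.content)
```
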
8 changes: 6 additions & 2 deletions pkg/openai/manager.py
@@ -24,14 +24,18 @@ class OpenAIInteract:
"size": "256x256",
}

client: openai.Client = None

def __init__(self, api_key: str):

self.key_mgr = pkg.openai.keymgr.KeysManager(api_key)
self.audit_mgr = pkg.audit.gatherer.DataGatherer()

# logging.info("文字总使用量:%d", self.audit_mgr.get_total_text_length())

openai.api_key = self.key_mgr.get_using_key()
self.client = openai.Client(
api_key=self.key_mgr.get_using_key()
)

pkg.utils.context.set_openai_manager(self)

@@ -48,7 +52,7 @@ def request_completion(self, messages: list):
cp_parmas = config.completion_api_params.copy()
del cp_parmas['model']

request = select_request_cls(model, messages, cp_parmas)
request = select_request_cls(self.client, model, messages, cp_parmas)

# Call the API
for resp in request:
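
As shown above, the manager now builds one shared `openai.Client` (an alias of `openai.OpenAI` in 1.x) and threads it into the request classes; key rotation then amounts to updating `client.api_key`, which is what `_next_key` in pkg/openai/api/model.py does. A minimal sketch of that pattern (key values and the helper name are placeholders):

```python
# Minimal sketch of the shared-client pattern introduced above.
import openai

def build_client(current_key: str) -> openai.Client:
    # openai.Client is an alias of openai.OpenAI in openai>=1.0.0
    return openai.Client(api_key=current_key)

client = build_client("sk-key-1")  # placeholder key
# ... later, after the key manager switches keys:
client.api_key = "sk-key-2"  # subsequent requests authenticate with the new key
```
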
9 changes: 4 additions & 5 deletions pkg/openai/modelmgr.py
@@ -5,9 +5,8 @@
Completion - models such as text-davinci-003
This module wraps the request implementations of these two APIs and provides a unified calling interface for the upper layers
"""
import openai, logging, threading, asyncio
import openai.error as aiE
import tiktoken
import openai

from pkg.openai.api.model import RequestBase
from pkg.openai.api.completion import CompletionRequest
@@ -53,11 +52,11 @@
}


def select_request_cls(model_name: str, messages: list, args: dict) -> RequestBase:
def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> RequestBase:
if model_name in CHAT_COMPLETION_MODELS:
return ChatCompletionRequest(model_name, messages, **args)
return ChatCompletionRequest(client, model_name, messages, **args)
elif model_name in COMPLETION_MODELS:
return CompletionRequest(model_name, messages, **args)
return CompletionRequest(client, model_name, messages, **args)
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))

