Skip to content

Commit

Permalink
Merge pull request #587 from RockChinQ/hotfix/openai-1.0-adaptation
Browse files Browse the repository at this point in the history
Feat: 适配openai>=1.0.0
  • Loading branch information
RockChinQ authored Nov 10, 2023
2 parents 57de96e + 8a1d4fe commit 45e4096
Show file tree
Hide file tree
Showing 12 changed files with 111 additions and 97 deletions.
2 changes: 1 addition & 1 deletion config-template.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@
trace_function_calls = False

# 群内回复消息时是否引用原消息
quote_origin = True
quote_origin = False

# 群内回复消息时是否at发送者
at_sender = False
Expand Down
9 changes: 6 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,16 @@ def start(first_time_init=False):

# 配置OpenAI proxy
import openai
openai.proxy = None # 先重置,因为重载后可能需要清除proxy
openai.proxies = None # 先重置,因为重载后可能需要清除proxy
if "http_proxy" in config.openai_config and config.openai_config["http_proxy"] is not None:
openai.proxy = config.openai_config["http_proxy"]
openai.proxies = {
"http": config.openai_config["http_proxy"],
"https": config.openai_config["http_proxy"]
}

# 配置openai api_base
if "reverse_proxy" in config.openai_config and config.openai_config["reverse_proxy"] is not None:
openai.api_base = config.openai_config["reverse_proxy"]
openai.base_url = config.openai_config["reverse_proxy"]

# 主启动流程
database = pkg.database.manager.DatabaseManager()
Expand Down
2 changes: 1 addition & 1 deletion override-all.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"size": "256x256"
},
"trace_function_calls": false,
"quote_origin": true,
"quote_origin": false,
"at_sender": false,
"include_image_description": true,
"process_message_timeout": 120,
Expand Down
61 changes: 39 additions & 22 deletions pkg/openai/api/chat_completion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import openai
from openai.types.chat import chat_completion_message
import json
import logging

Expand All @@ -13,13 +14,14 @@ class ChatCompletionRequest(RequestBase):
此类保证每一次返回的角色为assistant的信息的finish_reason一定为stop。
若有函数调用响应,本类的返回瀑布是:函数调用请求->函数调用结果->...->assistant的信息->stop。
"""

model: str
messages: list[dict[str, str]]
kwargs: dict

stopped: bool = False

pending_func_call: dict = None
pending_func_call: chat_completion_message.FunctionCall = None

pending_msg: str

Expand All @@ -46,16 +48,18 @@ def append_message(self, role: str, content: str, name: str=None, function_call:

def __init__(
self,
client: openai.Client,
model: str,
messages: list[dict[str, str]],
**kwargs
):
self.client = client
self.model = model
self.messages = messages.copy()

self.kwargs = kwargs

self.req_func = openai.ChatCompletion.acreate
self.req_func = self.client.chat.completions.create

self.pending_func_call = None

Expand Down Expand Up @@ -84,80 +88,93 @@ def __next__(self) -> dict:

# 拼接kwargs
args = {**args, **self.kwargs}

from openai.types.chat import chat_completion

resp = self._req(**args)
resp: chat_completion.ChatCompletion = self._req(**args)

choice0 = resp["choices"][0]
choice0 = resp.choices[0]

# 如果不是函数调用,且finish_reason为stop,则停止迭代
if choice0['finish_reason'] == 'stop': # and choice0["finish_reason"] == "stop"
if choice0.finish_reason == 'stop': # and choice0["finish_reason"] == "stop"
self.stopped = True

if 'function_call' in choice0['message']:
self.pending_func_call = choice0['message']['function_call']
if hasattr(choice0.message, 'function_call') and choice0.message.function_call is not None:
self.pending_func_call = choice0.message.function_call

self.append_message(
role="assistant",
content=choice0['message']['content'],
function_call=choice0['message']['function_call']
content=choice0.message.content,
function_call=choice0.message.function_call
)

return {
"id": resp["id"],
"id": resp.id,
"choices": [
{
"index": choice0["index"],
"index": choice0.index,
"message": {
"role": "assistant",
"type": "function_call",
"content": choice0['message']['content'],
"function_call": choice0['message']['function_call']
"content": choice0.message.content,
"function_call": {
"name": choice0.message.function_call.name,
"arguments": choice0.message.function_call.arguments
}
},
"finish_reason": "function_call"
}
],
"usage": resp["usage"]
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
}
else:

# self.pending_msg += choice0['message']['content']
# 普通回复一定处于最后方,故不用再追加进内部messages

return {
"id": resp["id"],
"id": resp.id,
"choices": [
{
"index": choice0["index"],
"index": choice0.index,
"message": {
"role": "assistant",
"type": "text",
"content": choice0['message']['content']
"content": choice0.message.content
},
"finish_reason": choice0["finish_reason"]
"finish_reason": choice0.finish_reason
}
],
"usage": resp["usage"]
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
}
else: # 处理函数调用请求

cp_pending_func_call = self.pending_func_call.copy()

self.pending_func_call = None

func_name = cp_pending_func_call['name']
func_name = cp_pending_func_call.name
arguments = {}

try:

try:
arguments = json.loads(cp_pending_func_call['arguments'])
arguments = json.loads(cp_pending_func_call.arguments)
# 若不是json格式的异常处理
except json.decoder.JSONDecodeError:
# 获取函数的参数列表
func_schema = get_func_schema(func_name)

arguments = {
func_schema['parameters']['required'][0]: cp_pending_func_call['arguments']
func_schema['parameters']['required'][0]: cp_pending_func_call.arguments
}

logging.info("执行函数调用: name={}, arguments={}".format(func_name, arguments))
Expand Down
45 changes: 17 additions & 28 deletions pkg/openai/api/completion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import openai
from openai.types import completion, completion_choice

from .model import RequestBase

Expand All @@ -17,10 +18,12 @@ class CompletionRequest(RequestBase):

def __init__(
self,
client: openai.Client,
model: str,
messages: list[dict[str, str]],
**kwargs
):
self.client = client
self.model = model
self.prompt = ""

Expand All @@ -31,7 +34,7 @@ def __init__(

self.kwargs = kwargs

self.req_func = openai.Completion.acreate
self.req_func = self.client.completions.create

def __iter__(self):
return self
Expand Down Expand Up @@ -63,49 +66,35 @@ def __next__(self) -> dict:
if self.stopped:
raise StopIteration()

resp = self._req(
resp: completion.Completion = self._req(
model=self.model,
prompt=self.prompt,
**self.kwargs
)

if resp["choices"][0]["finish_reason"] == "stop":
if resp.choices[0].finish_reason == "stop":
self.stopped = True

choice0 = resp["choices"][0]
choice0: completion_choice.CompletionChoice = resp.choices[0]

self.prompt += choice0["text"]
self.prompt += choice0.text

return {
"id": resp["id"],
"id": resp.id,
"choices": [
{
"index": choice0["index"],
"index": choice0.index,
"message": {
"role": "assistant",
"type": "text",
"content": choice0["text"]
"content": choice0.text
},
"finish_reason": choice0["finish_reason"]
"finish_reason": choice0.finish_reason
}
],
"usage": resp["usage"]
}

if __name__ == "__main__":
import os

openai.api_key = os.environ["OPENAI_API_KEY"]

for resp in CompletionRequest(
model="text-davinci-003",
messages=[
{
"role": "user",
"content": "Hello, who are you?"
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
]
):
print(resp)
if resp["choices"][0]["finish_reason"] == "stop":
break
}
36 changes: 7 additions & 29 deletions pkg/openai/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

class RequestBase:

client: openai.Client

req_func: callable

def __init__(self, *args, **kwargs):
Expand All @@ -17,41 +19,17 @@ def _next_key(self):
import pkg.utils.context as context
switched, name = context.get_openai_manager().key_mgr.auto_switch()
logging.debug("切换api-key: switched={}, name={}".format(switched, name))
openai.api_key = context.get_openai_manager().key_mgr.get_using_key()
self.client.api_key = context.get_openai_manager().key_mgr.get_using_key()

def _req(self, **kwargs):
"""处理代理问题"""
import config

ret: dict = {}
exception: Exception = None

async def awrapper(**kwargs):
nonlocal ret, exception

try:
ret = await self.req_func(**kwargs)
logging.debug("接口请求返回:%s", str(ret))

if config.switch_strategy == 'active':
self._next_key()

return ret
except Exception as e:
exception = e

loop = asyncio.new_event_loop()

thr = threading.Thread(
target=loop.run_until_complete,
args=(awrapper(**kwargs),)
)

thr.start()
thr.join()
ret = self.req_func(**kwargs)
logging.debug("接口请求返回:%s", str(ret))

if exception is not None:
raise exception
if config.switch_strategy == 'active':
self._next_key()

return ret

Expand Down
8 changes: 6 additions & 2 deletions pkg/openai/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@ class OpenAIInteract:
"size": "256x256",
}

client: openai.Client = None

def __init__(self, api_key: str):

self.key_mgr = pkg.openai.keymgr.KeysManager(api_key)
self.audit_mgr = pkg.audit.gatherer.DataGatherer()

# logging.info("文字总使用量:%d", self.audit_mgr.get_total_text_length())

openai.api_key = self.key_mgr.get_using_key()
self.client = openai.Client(
api_key=self.key_mgr.get_using_key()
)

pkg.utils.context.set_openai_manager(self)

Expand All @@ -48,7 +52,7 @@ def request_completion(self, messages: list):
cp_parmas = config.completion_api_params.copy()
del cp_parmas['model']

request = select_request_cls(model, messages, cp_parmas)
request = select_request_cls(self.client, model, messages, cp_parmas)

# 请求接口
for resp in request:
Expand Down
9 changes: 4 additions & 5 deletions pkg/openai/modelmgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
Completion - text-davinci-003 等模型
此模块封装此两个接口的请求实现,为上层提供统一的调用方式
"""
import openai, logging, threading, asyncio
import openai.error as aiE
import tiktoken
import openai

from pkg.openai.api.model import RequestBase
from pkg.openai.api.completion import CompletionRequest
Expand Down Expand Up @@ -53,11 +52,11 @@
}


def select_request_cls(model_name: str, messages: list, args: dict) -> RequestBase:
def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> RequestBase:
if model_name in CHAT_COMPLETION_MODELS:
return ChatCompletionRequest(model_name, messages, **args)
return ChatCompletionRequest(client, model_name, messages, **args)
elif model_name in COMPLETION_MODELS:
return CompletionRequest(model_name, messages, **args)
return CompletionRequest(client, model_name, messages, **args)
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))


Expand Down
Loading

0 comments on commit 45e4096

Please sign in to comment.