Skip to content

Commit

Permalink
refactor: 修改引入风格
Browse files Browse the repository at this point in the history
  • Loading branch information
RockChinQ committed Nov 13, 2023
1 parent e3b2807 commit 665de5d
Show file tree
Hide file tree
Showing 47 changed files with 324 additions and 364 deletions.
16 changes: 8 additions & 8 deletions pkg/audit/gatherer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import requests

import pkg.utils.context
import pkg.utils.updater
from ..utils import context
from ..utils import updater


class DataGatherer:
Expand All @@ -33,7 +33,7 @@ class DataGatherer:
def __init__(self):
self.load_from_db()
try:
self.version_str = pkg.utils.updater.get_current_tag() # 从updater模块获取版本号
self.version_str = updater.get_current_tag() # 从updater模块获取版本号
except:
pass

Expand All @@ -47,7 +47,7 @@ def report_to_server(self, subservice_name: str, count: int):
def thread_func():

try:
config = pkg.utils.context.get_config()
config = context.get_config()
if not config.report_usage:
return
res = requests.get("http://reports.rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}&msg_source={}".format(subservice_name, self.version_str, count, config.msg_source_adapter))
Expand All @@ -64,7 +64,7 @@ def get_usage(self, key_md5):
def report_text_model_usage(self, model, total_tokens):
"""调用方报告文字模型请求文字使用量"""

key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存
key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存

if key_md5 not in self.usage:
self.usage[key_md5] = {}
Expand All @@ -84,7 +84,7 @@ def report_text_model_usage(self, model, total_tokens):
def report_image_model_usage(self, size):
"""调用方报告图片模型请求图片使用量"""

key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()
key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5()

if key_md5 not in self.usage:
self.usage[key_md5] = {}
Expand Down Expand Up @@ -131,9 +131,9 @@ def get_total_text_length(self):
return total

def dump_to_db(self):
pkg.utils.context.get_database_manager().dump_usage_json(self.usage)
context.get_database_manager().dump_usage_json(self.usage)

def load_from_db(self):
json_str = pkg.utils.context.get_database_manager().load_usage_json()
json_str = context.get_database_manager().load_usage_json()
if json_str is not None:
self.usage = json.loads(json_str)
9 changes: 4 additions & 5 deletions pkg/database/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import json
import logging
import time
from sqlite3 import Cursor

import sqlite3

import pkg.utils.context
from ..utils import context


class DatabaseManager:
Expand All @@ -22,7 +21,7 @@ def __init__(self):

self.reconnect()

pkg.utils.context.set_database_manager(self)
context.set_database_manager(self)

# 连接到数据库文件
def reconnect(self):
Expand All @@ -33,7 +32,7 @@ def reconnect(self):
def close(self):
self.conn.close()

def __execute__(self, *args, **kwargs) -> Cursor:
def __execute__(self, *args, **kwargs) -> sqlite3.Cursor:
# logging.debug('SQL: {}'.format(sql))
logging.debug('SQL: {}'.format(args))
c = self.cursor.execute(*args, **kwargs)
Expand Down Expand Up @@ -145,7 +144,7 @@ def set_session_expired(self, session_name: str, create_timestamp: int):
# 从数据库加载还没过期的session数据
def load_valid_sessions(self) -> dict:
# 从数据库中加载所有还没过期的session
config = pkg.utils.context.get_config()
config = context.get_config()
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `last_interact_timestamp` > {}
Expand Down
17 changes: 8 additions & 9 deletions pkg/openai/api/chat_completion.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import openai
from openai.types.chat import chat_completion_message
import json
import logging

from .model import RequestBase
import openai
from openai.types.chat import chat_completion_message

from ..funcmgr import get_func_schema_list, execute_function, get_func, get_func_schema, ContentFunctionNotFoundError
from .model import RequestBase
from .. import funcmgr


class ChatCompletionRequest(RequestBase):
Expand Down Expand Up @@ -81,7 +81,7 @@ def __next__(self) -> dict:
"messages": self.messages,
}

funcs = get_func_schema_list()
funcs = funcmgr.get_func_schema_list()

if len(funcs) > 0:
args['functions'] = funcs
Expand Down Expand Up @@ -171,7 +171,7 @@ def __next__(self) -> dict:
# 若不是json格式的异常处理
except json.decoder.JSONDecodeError:
# 获取函数的参数列表
func_schema = get_func_schema(func_name)
func_schema = funcmgr.get_func_schema(func_name)

arguments = {
func_schema['parameters']['required'][0]: cp_pending_func_call.arguments
Expand All @@ -182,7 +182,7 @@ def __next__(self) -> dict:
# 执行函数调用
ret = ""
try:
ret = execute_function(func_name, arguments)
ret = funcmgr.execute_function(func_name, arguments)

logging.info("函数执行完成。")
except Exception as e:
Expand Down Expand Up @@ -216,6 +216,5 @@ def __next__(self) -> dict:
}
}

except ContentFunctionNotFoundError:
except funcmgr.ContentFunctionNotFoundError:
raise Exception("没有找到函数: {}".format(func_name))

4 changes: 2 additions & 2 deletions pkg/openai/api/completion.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import openai
from openai.types import completion, completion_choice

from .model import RequestBase
from . import model


class CompletionRequest(RequestBase):
class CompletionRequest(model.RequestBase):
"""调用Completion接口的请求类。
调用方可以一直next completion直到finish_reason为stop。
Expand Down
2 changes: 0 additions & 2 deletions pkg/openai/api/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# 定义不同接口请求的模型
import threading
import asyncio
import logging

import openai
Expand Down
3 changes: 1 addition & 2 deletions pkg/openai/funcmgr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# 封装了function calling的一些支持函数
import logging


from pkg.plugin import host
from ..plugin import host


class ContentFunctionNotFoundError(Exception):
Expand Down
4 changes: 2 additions & 2 deletions pkg/openai/keymgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import hashlib
import logging

import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
from ..plugin import host as plugin_host
from ..plugin import models as plugin_models


class KeysManager:
Expand Down
29 changes: 14 additions & 15 deletions pkg/openai/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import openai

import pkg.openai.keymgr
import pkg.utils.context
import pkg.audit.gatherer
from pkg.openai.modelmgr import select_request_cls

from pkg.openai.api.model import RequestBase
from ..openai import keymgr
from ..utils import context
from ..audit import gatherer
from ..openai import modelmgr
from ..openai.api import model as api_model


class OpenAIInteract:
Expand All @@ -16,9 +15,9 @@ class OpenAIInteract:
将文字接口和图片接口封装供调用方使用
"""

key_mgr: pkg.openai.keymgr.KeysManager = None
key_mgr: keymgr.KeysManager = None

audit_mgr: pkg.audit.gatherer.DataGatherer = None
audit_mgr: gatherer.DataGatherer = None

default_image_api_params = {
"size": "256x256",
Expand All @@ -28,31 +27,31 @@ class OpenAIInteract:

def __init__(self, api_key: str):

self.key_mgr = pkg.openai.keymgr.KeysManager(api_key)
self.audit_mgr = pkg.audit.gatherer.DataGatherer()
self.key_mgr = keymgr.KeysManager(api_key)
self.audit_mgr = gatherer.DataGatherer()

# logging.info("文字总使用量:%d", self.audit_mgr.get_total_text_length())

self.client = openai.Client(
api_key=self.key_mgr.get_using_key()
)

pkg.utils.context.set_openai_manager(self)
context.set_openai_manager(self)

def request_completion(self, messages: list):
"""请求补全接口回复=
"""
# 选择接口请求类
config = pkg.utils.context.get_config()
config = context.get_config()

request: RequestBase
request: api_model.RequestBase

model: str = config.completion_api_params['model']

cp_parmas = config.completion_api_params.copy()
del cp_parmas['model']

request = select_request_cls(self.client, model, messages, cp_parmas)
request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas)

# 请求接口
for resp in request:
Expand All @@ -74,7 +73,7 @@ def request_image(self, prompt) -> dict:
Returns:
dict: 响应
"""
config = pkg.utils.context.get_config()
config = context.get_config()
params = config.image_api_params

response = openai.Image.create(
Expand Down
12 changes: 6 additions & 6 deletions pkg/openai/modelmgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import tiktoken
import openai

from pkg.openai.api.model import RequestBase
from pkg.openai.api.completion import CompletionRequest
from pkg.openai.api.chat_completion import ChatCompletionRequest
from ..openai.api import model as api_model
from ..openai.api import completion as api_completion
from ..openai.api import chat_completion as api_chat_completion

COMPLETION_MODELS = {
"text-davinci-003", # legacy
Expand Down Expand Up @@ -60,11 +60,11 @@
}


def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> RequestBase:
def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> api_model.RequestBase:
if model_name in CHAT_COMPLETION_MODELS:
return ChatCompletionRequest(client, model_name, messages, **args)
return api_chat_completion.ChatCompletionRequest(client, model_name, messages, **args)
elif model_name in COMPLETION_MODELS:
return CompletionRequest(client, model_name, messages, **args)
return api_completion.CompletionRequest(client, model_name, messages, **args)
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))


Expand Down
Loading

0 comments on commit 665de5d

Please sign in to comment.