From d2fb3f5130fab1276b6a390d533465eed132b522 Mon Sep 17 00:00:00 2001
From: JobSmithManipulation
<143315462+JobSmithManipulation@users.noreply.github.com>
Date: Mon, 30 Sep 2024 18:14:20 +0800
Subject: [PATCH] Support sequence2txt and TTS models in Xinference and Ollama
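
Add speech-to-text (sequence2txt) and text-to-speech support for
Xinference-hosted models, plus Ollama counterparts, and register the
new classes in the Seq2txtModel and TTSModel factory tables. Also fix
a "mole" -> "model" typo in the TenantLLM assertion message.

A minimal usage sketch (endpoint, port, and model names are
illustrative; Xinference listens on port 9997 by default):

    from rag.llm.sequence2txt_model import XinferenceSeq2txt
    from rag.llm.tts_model import XinferenceTTS

    stt = XinferenceSeq2txt("no-key", model_name="whisper-small",
                            base_url="http://localhost:9997")
    with open("sample.wav", "rb") as f:
        text, n_tokens = stt.transcription(f.read())

    tts = XinferenceTTS("no-key", "ChatTTS",
                        base_url="http://localhost:9997")
    audio_bytes = b"".join(tts.tts("hello", voice="中文女"))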
---
api/db/services/llm_service.py | 2 +-
rag/llm/__init__.py | 18 ++-
rag/llm/sequence2txt_model.py | 106 +++++++++++++++---
rag/llm/tts_model.py | 60 ++++++++++
.../setting-model/ollama-modal/index.tsx | 4 +-
5 files changed, 164 insertions(+), 26 deletions(-)
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 89e5593b36..74515f9951 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -195,7 +195,7 @@ def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
self.llm_name = llm_name
self.mdl = TenantLLMService.model_instance(
tenant_id, llm_type, llm_name, lang=lang)
- assert self.mdl, "Can't find mole for {}/{}/{}".format(
+ assert self.mdl, "Can't find model for {}/{}/{}".format(
tenant_id, llm_type, llm_name)
self.max_length = 8192
for lm in LLMService.query(llm_name=llm_name):
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 441a2a553b..ac732d021a 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -47,10 +47,9 @@
"Replicate": ReplicateEmbed,
"BaiduYiyan": BaiduYiyanEmbed,
"Voyage AI": VoyageEmbed,
- "HuggingFace":HuggingFaceEmbed,
+ "HuggingFace": HuggingFaceEmbed,
}
-
CvModel = {
"OpenAI": GptV4,
"Azure-OpenAI": AzureGptV4,
@@ -64,14 +63,13 @@
"LocalAI": LocalAICV,
"NVIDIA": NvidiaCV,
"LM-Studio": LmStudioCV,
- "StepFun":StepFunCV,
+ "StepFun": StepFunCV,
"OpenAI-API-Compatible": OpenAI_APICV,
"TogetherAI": TogetherAICV,
"01.AI": YiCV,
"Tencent Hunyuan": HunyuanCV
}
-
ChatModel = {
"OpenAI": GptTurbo,
"Azure-OpenAI": AzureChat,
@@ -99,7 +97,7 @@
"LeptonAI": LeptonAIChat,
"TogetherAI": TogetherAIChat,
"PerfXCloud": PerfXCloudChat,
- "Upstage":UpstageChat,
+ "Upstage": UpstageChat,
"novita.ai": NovitaAIChat,
"SILICONFLOW": SILICONFLOWChat,
"01.AI": YiChat,
@@ -111,7 +109,6 @@
"Google Cloud": GoogleChat,
}
-
RerankModel = {
"BAAI": DefaultRerank,
"Jina": JinaRerank,
@@ -127,7 +124,6 @@
"Voyage AI": VoyageRerank
}
-
Seq2txtModel = {
"OpenAI": GPTSeq2txt,
"Tongyi-Qianwen": QWenSeq2txt,
@@ -140,6 +136,8 @@
TTSModel = {
"Fish Audio": FishAudioTTS,
"Tongyi-Qianwen": QwenTTS,
- "OpenAI":OpenAITTS,
- "XunFei Spark":SparkTTS
-}
\ No newline at end of file
+ "OpenAI": OpenAITTS,
+ "XunFei Spark": SparkTTS,
+ "Xinference": XinferenceTTS,
+ "Ollama": OllamaTTS
+}
diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py
index a3f7f5af11..c5fca223da 100644
--- a/rag/llm/sequence2txt_model.py
+++ b/rag/llm/sequence2txt_model.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import requests
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
import io
@@ -25,6 +26,7 @@
import base64
import re
+
class Base(ABC):
def __init__(self, key, model_name):
pass
@@ -36,8 +38,8 @@ def transcription(self, audio, **kwargs):
response_format="text"
)
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
-
- def audio2base64(self,audio):
+
+ def audio2base64(self, audio):
if isinstance(audio, bytes):
return base64.b64encode(audio).decode("utf-8")
if isinstance(audio, io.BytesIO):
@@ -77,31 +79,109 @@ def transcription(self, audio, format):
return "**ERROR**: " + result.message, 0
-class OllamaSeq2txt(Base):
+class AzureSeq2txt(Base):
def __init__(self, key, model_name, lang="Chinese", **kwargs):
- self.client = Client(host=kwargs["base_url"])
+ self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
self.model_name = model_name
self.lang = lang
-class AzureSeq2txt(Base):
- def __init__(self, key, model_name, lang="Chinese", **kwargs):
- self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
+class OllamaSeq2txt(Base):
+    def __init__(self, key, model_name="whisper-small", base_url="http://localhost:11434", **kwargs):
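+        # Normalize the endpoint so all requests hit <host>/api.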
+ if base_url.split("/")[-1] != "api":
+ base_url = os.path.join(base_url, "api")
+ self.base_url = base_url
self.model_name = model_name
- self.lang = lang
+
+ def transcription(self, audio, format="wav", **kwargs):
+ b64_audio = self.audio2base64(audio)
+
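+        # JSON payload: base64-encoded audio plus its container format.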
+ payload = {
+ "model": self.model_name,
+ "audio_data": b64_audio,
+ "format": format,
+ }
+
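+        # POST to the transcription endpoint (path assumed by this patch) and unpack the JSON reply.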
+ try:
+ response = requests.post(
+ f"{self.base_url}/transcription",
+ json=payload
+ )
+ response.raise_for_status()
+ result = response.json()
+
+ if 'text' in result:
+ transcription_text = result['text'].strip()
+ return transcription_text, num_tokens_from_string(transcription_text)
+ else:
+ return "**ERROR**: Failed to retrieve transcription.", 0
+
+ except requests.exceptions.RequestException as e:
+ return f"**ERROR**: {str(e)}", 0
+
class XinferenceSeq2txt(Base):
- def __init__(self, key, model_name="", base_url=""):
- if base_url.split("/")[-1] != "v1":
- base_url = os.path.join(base_url, "v1")
- self.client = OpenAI(api_key="xxx", base_url=base_url)
+    def __init__(self, key, model_name="whisper-small", lang="Chinese", **kwargs):
+ self.base_url = kwargs.get('base_url', None)
self.model_name = model_name
+
+    def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
+        if isinstance(audio, str):
+            # Read from disk without leaking the file handle.
+            with open(audio, "rb") as audio_file:
+                audio_data = audio_file.read()
+            audio_file_name = audio.split("/")[-1]
+ else:
+ audio_data = audio
+ audio_file_name = "audio.wav"
+
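+        # Form fields mirroring OpenAI's /v1/audio/transcriptions API, which Xinference exposes.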
+ payload = {
+ "model": self.model_name,
+ "language": language,
+ "prompt": prompt,
+ "response_format": response_format,
+ "temperature": temperature
+ }
+
+ files = {
+ "file": (audio_file_name, audio_data, 'audio/wav')
+ }
+
+ try:
+ response = requests.post(
+ f"{self.base_url}/v1/audio/transcriptions",
+ files=files,
+ data=payload
+ )
+ response.raise_for_status()
+ result = response.json()
+
+ if 'text' in result:
+ transcription_text = result['text'].strip()
+ return transcription_text, num_tokens_from_string(transcription_text)
+ else:
+ return "**ERROR**: Failed to retrieve transcription.", 0
+
+ except requests.exceptions.RequestException as e:
+ return f"**ERROR**: {str(e)}", 0
+
+
class TencentCloudSeq2txt(Base):
def __init__(
- self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
+ self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
):
from tencentcloud.common import credential
from tencentcloud.asr.v20190614 import asr_client
diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py
index bfdb8762c8..7131068a48 100644
--- a/rag/llm/tts_model.py
+++ b/rag/llm/tts_model.py
@@ -297,3 +297,63 @@ def run(*args):
break
status_code = 1
yield audio_chunk
+
+
+class XinferenceTTS:
+ def __init__(self, key, model_name, **kwargs):
+ self.base_url = kwargs.get("base_url", None)
+ self.model_name = model_name
+ self.headers = {
+ "accept": "application/json",
+ "Content-Type": "application/json"
+ }
+
+ def tts(self, text, voice="中文女", stream=True):
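+        # Request body follows OpenAI's /v1/audio/speech schema, which Xinference mirrors.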
+ payload = {
+ "model": self.model_name,
+ "input": text,
+ "voice": voice
+ }
+
+ response = requests.post(
+ f"{self.base_url}/v1/audio/speech",
+ headers=self.headers,
+ json=payload,
+ stream=stream
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"**Error**: {response.status_code}, {response.text}")
+
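+        # Stream raw audio back to the caller in 1 KB chunks.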
+ for chunk in response.iter_content(chunk_size=1024):
+ if chunk:
+ yield chunk
+
+
+class OllamaTTS(Base):
+ def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
+ if not base_url: base_url = "https://api.ollama.ai/v1"
+ self.model_name = model_name
+ self.base_url = base_url
+ self.headers = {
+ "Content-Type": "application/json"
+ }
+
+ def tts(self, text, voice="standard-voice"):
+ payload = {
+ "model": self.model_name,
+ "voice": voice,
+ "input": text
+ }
+
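+        # NOTE: this endpoint/path is this patch's assumption; Ollama's public API may differ.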
+ response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)
+
+ if response.status_code != 200:
+ raise Exception(f"**Error**: {response.status_code}, {response.text}")
+
+        for chunk in response.iter_content(chunk_size=1024):
+ if chunk:
+ yield chunk
diff --git a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
index cb9bd2546e..d4bb385b60 100644
--- a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
@@ -93,8 +93,8 @@ const OllamaModal = ({
-
-
+
+
>
)}