From d2fb3f5130fab1276b6a390d533465eed132b522 Mon Sep 17 00:00:00 2001
From: JobSmithManipulation <143315462+JobSmithManipulation@users.noreply.github.com>
Date: Mon, 30 Sep 2024 18:14:20 +0800
Subject: [PATCH] Support sequence2txt and TTS models in Xinference

---
 api/db/services/llm_service.py                     |   2 +-
 rag/llm/__init__.py                                |  18 ++-
 rag/llm/sequence2txt_model.py                      | 106 +++++++++++++++---
 rag/llm/tts_model.py                               |  60 ++++++++++
 .../setting-model/ollama-modal/index.tsx           |   4 +-
 5 files changed, 164 insertions(+), 26 deletions(-)

diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 89e5593b36..74515f9951 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -195,7 +195,7 @@ def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
         self.llm_name = llm_name
         self.mdl = TenantLLMService.model_instance(
             tenant_id, llm_type, llm_name, lang=lang)
-        assert self.mdl, "Can't find mole for {}/{}/{}".format(
+        assert self.mdl, "Can't find model for {}/{}/{}".format(
             tenant_id, llm_type, llm_name)
         self.max_length = 8192
         for lm in LLMService.query(llm_name=llm_name):
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 441a2a553b..ac732d021a 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -47,10 +47,9 @@
     "Replicate": ReplicateEmbed,
     "BaiduYiyan": BaiduYiyanEmbed,
     "Voyage AI": VoyageEmbed,
-    "HuggingFace":HuggingFaceEmbed,
+    "HuggingFace": HuggingFaceEmbed,
 }
-
 
 CvModel = {
     "OpenAI": GptV4,
     "Azure-OpenAI": AzureGptV4,
@@ -64,14 +63,13 @@
     "LocalAI": LocalAICV,
     "NVIDIA": NvidiaCV,
     "LM-Studio": LmStudioCV,
-    "StepFun":StepFunCV,
+    "StepFun": StepFunCV,
     "OpenAI-API-Compatible": OpenAI_APICV,
     "TogetherAI": TogetherAICV,
     "01.AI": YiCV,
     "Tencent Hunyuan": HunyuanCV
 }
-
 
 ChatModel = {
     "OpenAI": GptTurbo,
     "Azure-OpenAI": AzureChat,
@@ -99,7 +97,7 @@
     "LeptonAI": LeptonAIChat,
     "TogetherAI": TogetherAIChat,
     "PerfXCloud": PerfXCloudChat,
-    "Upstage":UpstageChat,
+    "Upstage": UpstageChat,
     "novita.ai": NovitaAIChat,
     "SILICONFLOW": SILICONFLOWChat,
     "01.AI": YiChat,
@@ -111,7 +109,6 @@
     "Google Cloud": GoogleChat,
 }
-
 
 RerankModel = {
     "BAAI": DefaultRerank,
     "Jina": JinaRerank,
@@ -127,7 +124,6 @@
     "Voyage AI": VoyageRerank
 }
-
 
 Seq2txtModel = {
     "OpenAI": GPTSeq2txt,
     "Tongyi-Qianwen": QWenSeq2txt,
@@ -140,6 +136,8 @@
 TTSModel = {
     "Fish Audio": FishAudioTTS,
     "Tongyi-Qianwen": QwenTTS,
-    "OpenAI":OpenAITTS,
-    "XunFei Spark":SparkTTS
-}
\ No newline at end of file
+    "OpenAI": OpenAITTS,
+    "XunFei Spark": SparkTTS,
+    "Xinference": XinferenceTTS,
+    "Ollama": OllamaTTS
+}
diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py
index a3f7f5af11..c5fca223da 100644
--- a/rag/llm/sequence2txt_model.py
+++ b/rag/llm/sequence2txt_model.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import requests
 from openai.lib.azure import AzureOpenAI
 from zhipuai import ZhipuAI
 import io
@@ -25,6 +26,7 @@
 import base64
 import re
 
+
 class Base(ABC):
     def __init__(self, key, model_name):
         pass
@@ -36,8 +38,8 @@ def transcription(self, audio, **kwargs):
             response_format="text"
         )
         return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
-    
-    def audio2base64(self,audio):
+
+    def audio2base64(self, audio):
         if isinstance(audio, bytes):
             return base64.b64encode(audio).decode("utf-8")
         if isinstance(audio, io.BytesIO):
@@ -77,31 +79,109 @@ def transcription(self, audio, format):
         return "**ERROR**: " + result.message, 0
 
 
-class OllamaSeq2txt(Base):
+class AzureSeq2txt(Base):
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
-        self.client = Client(host=kwargs["base_url"])
+        self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
         self.model_name = model_name
         self.lang = lang
 
 
-class AzureSeq2txt(Base):
-    def __init__(self, key, model_name, lang="Chinese", **kwargs):
-        self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
+class OllamaSeq2txt(Base):
+    def __init__(self, key, model_name="whisper-small", lang="Chinese", **kwargs):
+        # Keep the (key, model_name, lang, **kwargs) signature used by the
+        # other Seq2txt classes so the factory can instantiate every entry
+        # of Seq2txtModel uniformly; Ollama itself needs no API key.
+        base_url = kwargs.get("base_url", "http://localhost:11434")
+        if base_url.split("/")[-1] != "api":
+            base_url = os.path.join(base_url, "api")
+        self.base_url = base_url
         self.model_name = model_name
-        self.lang = lang
+        self.lang = lang
+
+    def transcription(self, audio, format="wav", **kwargs):
+        b64_audio = self.audio2base64(audio)
+
+        payload = {
+            "model": self.model_name,
+            "audio_data": b64_audio,
+            "format": format,
+        }
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/transcription",
+                json=payload
+            )
+            response.raise_for_status()
+            result = response.json()
+
+            if 'text' in result:
+                transcription_text = result['text'].strip()
+                return transcription_text, num_tokens_from_string(transcription_text)
+            else:
+                return "**ERROR**: Failed to retrieve transcription.", 0
+
+        except requests.exceptions.RequestException as e:
+            return f"**ERROR**: {str(e)}", 0
 
 
 class XinferenceSeq2txt(Base):
-    def __init__(self, key, model_name="", base_url=""):
-        if base_url.split("/")[-1] != "v1":
-            base_url = os.path.join(base_url, "v1")
-        self.client = OpenAI(api_key="xxx", base_url=base_url)
+    def __init__(self, key, model_name="whisper-small", lang="Chinese", **kwargs):
+        self.base_url = kwargs.get('base_url', None)
         self.model_name = model_name
 
+    def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
+        # Accept either a file path or raw bytes; close the file handle
+        # promptly instead of leaking it.
+        if isinstance(audio, str):
+            with open(audio, 'rb') as audio_file:
+                audio_data = audio_file.read()
+            audio_file_name = audio.split("/")[-1]
+        else:
+            audio_data = audio
+            audio_file_name = "audio.wav"
+
+        payload = {
+            "model": self.model_name,
+            "language": language,
+            "prompt": prompt,
+            "response_format": response_format,
+            "temperature": temperature
+        }
+
+        files = {
+            "file": (audio_file_name, audio_data, 'audio/wav')
+        }
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/v1/audio/transcriptions",
+                files=files,
+                data=payload
+            )
+            response.raise_for_status()
+            result = response.json()
+
+            if 'text' in result:
+                transcription_text = result['text'].strip()
+                return transcription_text, num_tokens_from_string(transcription_text)
+            else:
+                return "**ERROR**: Failed to retrieve transcription.", 0
+
+        except requests.exceptions.RequestException as e:
+            return f"**ERROR**: {str(e)}", 0
 
 
 class TencentCloudSeq2txt(Base):
     def __init__(
-        self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
+            self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
     ):
         from tencentcloud.common import credential
         from tencentcloud.asr.v20190614 import asr_client
diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py
index bfdb8762c8..7131068a48 100644
--- a/rag/llm/tts_model.py
+++ b/rag/llm/tts_model.py
@@ -297,3 +297,63 @@ def run(*args):
             break
         status_code = 1
         yield audio_chunk
+
+
+class XinferenceTTS:
+    def __init__(self, key, model_name, **kwargs):
+        self.base_url = kwargs.get("base_url", None)
+        self.model_name = model_name
+        self.headers = {
+            "accept": "application/json",
+            "Content-Type": "application/json"
+        }
+
+    def tts(self, text, voice="中文女", stream=True):
+        payload = {
+            "model": self.model_name,
+            "input": text,
+            "voice": voice
+        }
+
+        response = requests.post(
+            f"{self.base_url}/v1/audio/speech",
+            headers=self.headers,
+            json=payload,
+            stream=stream
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"**Error**: {response.status_code}, {response.text}")
+
+        for chunk in response.iter_content(chunk_size=1024):
+            if chunk:
+                yield chunk
+
+
+class OllamaTTS(Base):
+    def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
+        if not base_url:
+            base_url = "https://api.ollama.ai/v1"
+        self.model_name = model_name
+        self.base_url = base_url
+        self.headers = {
+            "Content-Type": "application/json"
+        }
+
+    def tts(self, text, voice="standard-voice"):
+        payload = {
+            "model": self.model_name,
+            "voice": voice,
+            "input": text
+        }
+
+        response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)
+
+        if response.status_code != 200:
+            raise Exception(f"**Error**: {response.status_code}, {response.text}")
+
+        # Stream in 1 KiB chunks rather than requests' one-byte default.
+        for chunk in response.iter_content(chunk_size=1024):
+            if chunk:
+                yield chunk
+
diff --git a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
index cb9bd2546e..d4bb385b60 100644
--- a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
@@ -93,8 +93,8 @@ const OllamaModal = ({
-
-
+
+
         )}
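
Usage sketch (not part of the patch): a minimal smoke test for the new
XinferenceSeq2txt class as patched above. The server URL and model name are
illustrative assumptions; the model must already be launched on your
Xinference server, and the API key is unused by Xinference.

    from rag.llm.sequence2txt_model import XinferenceSeq2txt

    stt = XinferenceSeq2txt(
        key="none",                         # unused by Xinference
        model_name="whisper-small",         # assumed: a launched audio model
        base_url="http://localhost:9997",   # assumed: local Xinference server
    )
    # Accepts a file path or raw bytes; returns (text, token_count).
    text, num_tokens = stt.transcription("sample.wav", language="zh")
    print(text, num_tokens)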
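And the matching sketch for the new XinferenceTTS class, which streams raw
audio bytes; again, the model name and URL are placeholders for your own
deployment.

    from rag.llm.tts_model import XinferenceTTS

    tts = XinferenceTTS(
        key="none",                         # unused by Xinference
        model_name="ChatTTS",               # assumed: a launched speech model
        base_url="http://localhost:9997",   # assumed: local Xinference server
    )
    with open("out.wav", "wb") as f:
        for chunk in tts.tts("你好，世界", voice="中文女"):
            f.write(chunk)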