Skip to content

Commit

Permalink
support sequence2txt and tts model in Xinference
Browse files Browse the repository at this point in the history
  • Loading branch information
JobSmithManipulation committed Sep 30, 2024
1 parent 5a8ae4a commit d2fb3f5
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 26 deletions.
2 changes: 1 addition & 1 deletion api/db/services/llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
self.llm_name = llm_name
self.mdl = TenantLLMService.model_instance(
tenant_id, llm_type, llm_name, lang=lang)
assert self.mdl, "Can't find mole for {}/{}/{}".format(
assert self.mdl, "Can't find model for {}/{}/{}".format(
tenant_id, llm_type, llm_name)
self.max_length = 8192
for lm in LLMService.query(llm_name=llm_name):
Expand Down
18 changes: 8 additions & 10 deletions rag/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,9 @@
"Replicate": ReplicateEmbed,
"BaiduYiyan": BaiduYiyanEmbed,
"Voyage AI": VoyageEmbed,
"HuggingFace":HuggingFaceEmbed,
"HuggingFace": HuggingFaceEmbed,
}


CvModel = {
"OpenAI": GptV4,
"Azure-OpenAI": AzureGptV4,
Expand All @@ -64,14 +63,13 @@
"LocalAI": LocalAICV,
"NVIDIA": NvidiaCV,
"LM-Studio": LmStudioCV,
"StepFun":StepFunCV,
"StepFun": StepFunCV,
"OpenAI-API-Compatible": OpenAI_APICV,
"TogetherAI": TogetherAICV,
"01.AI": YiCV,
"Tencent Hunyuan": HunyuanCV
}


ChatModel = {
"OpenAI": GptTurbo,
"Azure-OpenAI": AzureChat,
Expand Down Expand Up @@ -99,7 +97,7 @@
"LeptonAI": LeptonAIChat,
"TogetherAI": TogetherAIChat,
"PerfXCloud": PerfXCloudChat,
"Upstage":UpstageChat,
"Upstage": UpstageChat,
"novita.ai": NovitaAIChat,
"SILICONFLOW": SILICONFLOWChat,
"01.AI": YiChat,
Expand All @@ -111,7 +109,6 @@
"Google Cloud": GoogleChat,
}


RerankModel = {
"BAAI": DefaultRerank,
"Jina": JinaRerank,
Expand All @@ -127,7 +124,6 @@
"Voyage AI": VoyageRerank
}


Seq2txtModel = {
"OpenAI": GPTSeq2txt,
"Tongyi-Qianwen": QWenSeq2txt,
Expand All @@ -140,6 +136,8 @@
# Registry mapping provider display names to their text-to-speech model classes.
# Keys must match the factory names used by TenantLLMService.model_instance.
TTSModel = {
    "Fish Audio": FishAudioTTS,
    "Tongyi-Qianwen": QwenTTS,
    "OpenAI": OpenAITTS,
    "XunFei Spark": SparkTTS,
    "Xinference": XinferenceTTS,
    "Ollama": OllamaTTS,
}
106 changes: 93 additions & 13 deletions rag/llm/sequence2txt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import requests
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
import io
Expand All @@ -25,6 +26,7 @@
import base64
import re


class Base(ABC):
def __init__(self, key, model_name):
pass
Expand All @@ -36,8 +38,8 @@ def transcription(self, audio, **kwargs):
response_format="text"
)
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
def audio2base64(self,audio):

def audio2base64(self, audio):
if isinstance(audio, bytes):
return base64.b64encode(audio).decode("utf-8")
if isinstance(audio, io.BytesIO):
Expand Down Expand Up @@ -77,31 +79,109 @@ def transcription(self, audio, format):
return "**ERROR**: " + result.message, 0


class OllamaSeq2txt(Base):
class AzureSeq2txt(Base):
    """Speech-to-text backed by an Azure OpenAI deployment.

    Inherits ``transcription`` from ``Base``; only the client construction
    differs (Azure endpoint + API version instead of the public OpenAI API).
    """

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        # kwargs["base_url"] is required: the Azure resource endpoint.
        # KeyError here signals a misconfigured provider entry.
        self.client = AzureOpenAI(
            api_key=key,
            azure_endpoint=kwargs["base_url"],
            api_version="2024-02-01",
        )
        self.model_name = model_name
        self.lang = lang
class OllamaSeq2txt(Base):
    """Speech-to-text via an Ollama-style HTTP API.

    Sends base64-encoded audio to ``<base_url>/transcription`` and returns
    the transcribed text together with its token count.
    """

    def __init__(self, base_url="http://localhost:11434", model_name="whisper-small", lang="Chinese"):
        # Normalize the endpoint so it always ends with the "/api" prefix.
        # Plain string handling is used instead of os.path.join: this is a
        # URL, not a filesystem path, so the separator must always be "/".
        if base_url.split("/")[-1] != "api":
            base_url = base_url.rstrip("/") + "/api"
        self.base_url = base_url
        self.model_name = model_name
        # `lang` was previously referenced without being a parameter
        # (NameError); accept it with a default for backward compatibility.
        self.lang = lang

    def transcription(self, audio, format="wav", **kwargs):
        """Transcribe `audio` (bytes or io.BytesIO).

        Returns a tuple ``(text, token_count)`` on success, or
        ``("**ERROR**: ...", 0)`` on any request/decoding failure.
        """
        b64_audio = self.audio2base64(audio)

        payload = {
            "model": self.model_name,
            "audio_data": b64_audio,
            "format": format,
        }

        try:
            response = requests.post(
                f"{self.base_url}/transcription",
                json=payload
            )
            response.raise_for_status()
            result = response.json()

            if 'text' in result:
                transcription_text = result['text'].strip()
                return transcription_text, num_tokens_from_string(transcription_text)
            return "**ERROR**: Failed to retrieve transcription.", 0

        except requests.exceptions.RequestException as e:
            return f"**ERROR**: {str(e)}", 0

    def audio2base64(self, audio):
        """Encode raw audio (bytes or io.BytesIO) as a base64 string."""
        if isinstance(audio, bytes):
            return base64.b64encode(audio).decode("utf-8")
        if isinstance(audio, io.BytesIO):
            return base64.b64encode(audio.getvalue()).decode("utf-8")
        raise TypeError("The input audio file should be in binary format.")


class XinferenceSeq2txt(Base):
    """Speech-to-text via a Xinference server's OpenAI-compatible ASR endpoint.

    Posts the audio as a multipart upload to
    ``<base_url>/v1/audio/transcriptions``.
    """

    def __init__(self, key, model_name="whisper-small", lang="Chinese", **kwargs):
        # `key` is accepted for interface uniformity; Xinference deployments
        # are typically unauthenticated, so it is not used here.
        self.base_url = kwargs.get("base_url", None)
        self.model_name = model_name

    def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
        """Transcribe `audio` (a file path or raw bytes).

        Returns a tuple ``(text, token_count)`` on success, or
        ``("**ERROR**: ...", 0)`` on any request failure.
        """
        if isinstance(audio, str):
            # A path was given: read the file (closing it promptly) and keep
            # its basename as the upload filename.
            with open(audio, "rb") as audio_file:
                audio_data = audio_file.read()
            audio_file_name = audio.split("/")[-1]
        else:
            audio_data = audio
            audio_file_name = "audio.wav"

        payload = {
            "model": self.model_name,
            "language": language,
            "prompt": prompt,
            "response_format": response_format,
            "temperature": temperature
        }

        files = {
            "file": (audio_file_name, audio_data, 'audio/wav')
        }

        try:
            response = requests.post(
                f"{self.base_url}/v1/audio/transcriptions",
                files=files,
                data=payload
            )
            response.raise_for_status()
            result = response.json()

            if 'text' in result:
                transcription_text = result['text'].strip()
                return transcription_text, num_tokens_from_string(transcription_text)
            return "**ERROR**: Failed to retrieve transcription.", 0

        except requests.exceptions.RequestException as e:
            return f"**ERROR**: {str(e)}", 0

    def audio2base64(self, audio):
        """Encode raw audio (bytes or io.BytesIO) as a base64 string."""
        if isinstance(audio, bytes):
            return base64.b64encode(audio).decode("utf-8")
        if isinstance(audio, io.BytesIO):
            return base64.b64encode(audio.getvalue()).decode("utf-8")
        raise TypeError("The input audio file should be in binary format.")


class TencentCloudSeq2txt(Base):
def __init__(
self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
):
from tencentcloud.common import credential
from tencentcloud.asr.v20190614 import asr_client
Expand Down
60 changes: 60 additions & 0 deletions rag/llm/tts_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,63 @@ def run(*args):
break
status_code = 1
yield audio_chunk




class XinferenceTTS:
    """Streaming text-to-speech via a Xinference server.

    Posts synthesis requests to ``<base_url>/v1/audio/speech`` and yields the
    audio response in 1 KiB chunks.
    """

    def __init__(self, key, model_name, **kwargs):
        self.base_url = kwargs.get("base_url", None)
        self.model_name = model_name
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json"
        }

    def tts(self, text, voice="中文女", stream=True):
        """Yield audio chunks for `text`; raises on a non-200 response."""
        request_body = {
            "model": self.model_name,
            "input": text,
            "voice": voice
        }

        resp = requests.post(
            f"{self.base_url}/v1/audio/speech",
            headers=self.headers,
            json=request_body,
            stream=stream
        )

        if resp.status_code != 200:
            raise Exception(f"**Error**: {resp.status_code}, {resp.text}")

        for audio_chunk in resp.iter_content(chunk_size=1024):
            if not audio_chunk:
                continue
            yield audio_chunk


class OllamaTTS(Base):
    """Text-to-speech via an Ollama-style REST endpoint.

    Streams the synthesized audio back to the caller as raw chunks.
    """

    def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
        # Guard against callers explicitly passing None or "" for base_url
        # (the parameter default only covers the argument being omitted).
        if not base_url:
            base_url = "https://api.ollama.ai/v1"
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {
            "Content-Type": "application/json"
        }

    def tts(self, text, voice="standard-voice"):
        """Yield audio chunks for `text`; raises on a non-200 response."""
        payload = {
            "model": self.model_name,
            "voice": voice,
            "input": text
        }

        response = requests.post(
            f"{self.base_url}/audio/tts",
            headers=self.headers,
            json=payload,
            stream=True
        )

        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")

        # Read in 1 KiB chunks: iter_content() with no chunk_size iterates
        # one byte at a time, which is needlessly slow for audio streams.
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                yield chunk

Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ const OllamaModal = ({
<Option value="embedding">embedding</Option>
<Option value="rerank">rerank</Option>
<Option value="image2text">image2text</Option>
<Option value="audio2text">audio2text</Option>
<Option value="text2andio">text2andio</Option>
<Option value="speech2text">sequence2text</Option>
<Option value="tts">tts</Option>
</>
)}
</Select>
Expand Down

0 comments on commit d2fb3f5

Please sign in to comment.