Skip to content

Commit

Permalink
refine xinference (#2521)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

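Related issue: #1588. This change makes the four Xinference wrappers (CV, embedding, rerank, and sequence-to-text) append `v1` to `base_url` when its last path segment is not already `v1`.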

### Type of change

- [x] Refactoring
  • Loading branch information
KevinHuSh authored Sep 20, 2024
1 parent 9bbef82 commit 4a6a2a0
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 0 deletions.
2 changes: 2 additions & 0 deletions rag/llm/cv_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,8 @@ def __init__(self, key, model_name, base_url, lang="Chinese"):

class XinferenceCV(Base):
def __init__(self, key, model_name="", lang="Chinese", base_url=""):
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key="xxx", base_url=base_url)
self.model_name = model_name
self.lang = lang
Expand Down
2 changes: 2 additions & 0 deletions rag/llm/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ def encode_queries(self, text: str):

class XinferenceEmbed(Base):
def __init__(self, key, model_name="", base_url=""):
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key="xxx", base_url=base_url)
self.model_name = model_name

Expand Down
2 changes: 2 additions & 0 deletions rag/llm/rerank_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def similarity(self, query: str, texts: list):

class XInferenceRerank(Base):
def __init__(self, key="xxxxxxx", model_name="", base_url=""):
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
self.model_name = model_name
self.base_url = base_url
self.headers = {
Expand Down
2 changes: 2 additions & 0 deletions rag/llm/sequence2txt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def __init__(self, key, model_name, lang="Chinese", **kwargs):

class XinferenceSeq2txt(Base):
def __init__(self, key, model_name="", base_url=""):
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key="xxx", base_url=base_url)
self.model_name = model_name

Expand Down

0 comments on commit 4a6a2a0

Please sign in to comment.