
Commit

need-verify
yuhongsun96 committed Dec 9, 2024
1 parent dc5d5df commit 860d77a
Showing 2 changed files with 115 additions and 83 deletions.
194 changes: 113 additions & 81 deletions backend/model_server/encoders.py
@@ -1,12 +1,14 @@
import asyncio
import json
from types import TracebackType
from typing import cast
from typing import Optional

import httpx
import openai
import vertexai # type: ignore
import voyageai # type: ignore
from cohere import Client as CohereClient
from cohere import AsyncClient as CohereAsyncClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
@@ -68,17 +70,22 @@ def __init__(
self.api_key = api_key
self.api_url = api_url
self.api_version = api_version
self.http_client = httpx.AsyncClient(timeout=API_BASED_EMBEDDING_TIMEOUT)

def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
async def _embed_openai(
self, texts: list[str], model: str | None
) -> list[Embedding]:
if not model:
model = DEFAULT_OPENAI_MODEL

client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
client = openai.AsyncOpenAI(
api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
)

final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = client.embeddings.create(input=text_batch, model=model)
response = await client.embeddings.create(input=text_batch, model=model)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
@@ -93,19 +100,17 @@ def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
logger.error(error_string)
raise RuntimeError(error_string)

def _embed_cohere(
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_COHERE_MODEL

client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT)
client = CohereAsyncClient(api_key=self.api_key)

final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
# Does not use the same tokenizer as the Danswer API server but it's approximately the same
# empirically it's only off by a very few tokens so it's not a big deal
response = client.embed(
response = await client.embed(
texts=text_batch,
model=model,
input_type=embedding_type,
@@ -114,7 +119,7 @@ def _embed_cohere(
final_embeddings.extend(cast(list[Embedding], response.embeddings))
return final_embeddings

def _embed_voyage(
async def _embed_voyage(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
@@ -124,28 +129,37 @@ def _embed_voyage(
api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
)

response = client.embed(
texts,
model=model,
input_type=embedding_type,
truncation=True,
# Note: Voyage doesn't have an async client yet, we'll need to run this in a thread pool
response = await asyncio.get_event_loop().run_in_executor(
None,
lambda: client.embed(
texts,
model=model,
input_type=embedding_type,
truncation=True,
),
)
return response.embeddings

def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]:
response = embedding(
model=model,
input=texts,
timeout=API_BASED_EMBEDDING_TIMEOUT,
api_key=self.api_key,
api_base=self.api_url,
api_version=self.api_version,
async def _embed_azure(
self, texts: list[str], model: str | None
) -> list[Embedding]:
# Note: litellm doesn't have async support yet, we'll need to run this in a thread pool
response = await asyncio.get_event_loop().run_in_executor(
None,
lambda: embedding(
model=model,
input=texts,
timeout=API_BASED_EMBEDDING_TIMEOUT,
api_key=self.api_key,
api_base=self.api_url,
api_version=self.api_version,
),
)
embeddings = [embedding["embedding"] for embedding in response.data]

return embeddings

def _embed_vertex(
async def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
@@ -158,19 +172,23 @@ def _embed_vertex(
vertexai.init(project=project_id, credentials=credentials)
client = TextEmbeddingModel.from_pretrained(model)

embeddings = client.get_embeddings(
[
TextEmbeddingInput(
text,
embedding_type,
)
for text in texts
],
auto_truncate=True, # Also this is default
# Note: Vertex AI doesn't have async client yet, we'll need to run this in a thread pool
embeddings = await asyncio.get_event_loop().run_in_executor(
None,
lambda: client.get_embeddings(
[
TextEmbeddingInput(
text,
embedding_type,
)
for text in texts
],
auto_truncate=True,
),
)
return [embedding.values for embedding in embeddings]

def _embed_litellm_proxy(
async def _embed_litellm_proxy(
self, texts: list[str], model_name: str | None
) -> list[Embedding]:
if not model_name:
@@ -183,22 +201,20 @@ def _embed_litellm_proxy(
{} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
)

with httpx.Client() as client:
response = client.post(
self.api_url,
json={
"model": model_name,
"input": texts,
},
headers=headers,
timeout=API_BASED_EMBEDDING_TIMEOUT,
)
response.raise_for_status()
result = response.json()
return [embedding["embedding"] for embedding in result["data"]]
response = await self.http_client.post(
self.api_url,
json={
"model": model_name,
"input": texts,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [embedding["embedding"] for embedding in result["data"]]

@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
def embed(
async def embed(
self,
*,
texts: list[str],
@@ -207,19 +223,19 @@ def embed(
deployment_name: str | None = None,
) -> list[Embedding]:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
return await self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return self._embed_litellm_proxy(texts, model_name)
return await self._embed_litellm_proxy(texts, model_name)

embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")

@@ -233,6 +249,17 @@ def create(
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, api_url, api_version)

async def __aenter__(self) -> "CloudEmbedding":
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.http_client.aclose()


def get_embedding_model(
model_name: str,
@@ -242,9 +269,6 @@ def get_embedding_model(

global _GLOBAL_MODELS_DICT # A dictionary to store models

if _GLOBAL_MODELS_DICT is None:
_GLOBAL_MODELS_DICT = {}

if model_name not in _GLOBAL_MODELS_DICT:
logger.notice(f"Loading {model_name}")
# Some model architectures that aren't built into the Transformers or Sentence
@@ -275,7 +299,7 @@ def get_local_reranking_model(


@simple_log_function_time()
def embed_text(
async def embed_text(
texts: list[str],
text_type: EmbedTextType,
model_name: str | None,
@@ -311,18 +335,18 @@ def embed_text(
"Cloud models take an explicit text type instead."
)

cloud_model = CloudEmbedding(
async with CloudEmbedding(
api_key=api_key,
provider=provider_type,
api_url=api_url,
api_version=api_version,
)
embeddings = cloud_model.embed(
texts=texts,
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
)
) as cloud_model:
embeddings = await cloud_model.embed(
texts=texts,
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
)

if any(embedding is None for embedding in embeddings):
error_message = "Embeddings contain None values\n"
@@ -338,8 +362,12 @@ def embed_text(
local_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings_vectors = local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
# Run CPU-bound embedding in a thread pool
embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
None,
lambda: local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
),
)
embeddings = [
embedding if isinstance(embedding, list) else embedding.tolist()
@@ -357,27 +385,31 @@ def embed_text(


@simple_log_function_time()
def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
cross_encoder = get_local_reranking_model(model_name)
return cross_encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
# Run CPU-bound reranking in a thread pool
return await asyncio.get_event_loop().run_in_executor(
None,
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore
)


def cohere_rerank(
async def cohere_rerank(
query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
cohere_client = CohereClient(api_key=api_key)
response = cohere_client.rerank(query=query, documents=docs, model=model_name)
cohere_client = CohereAsyncClient(api_key=api_key)
response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
results = response.results
sorted_results = sorted(results, key=lambda item: item.index)
return [result.relevance_score for result in sorted_results]


def litellm_rerank(
async def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
with httpx.Client() as client:
response = client.post(
async with httpx.AsyncClient() as client:
response = await client.post(
api_url,
json={
"model": model_name,
@@ -411,7 +443,7 @@ async def process_embed_request(
else:
prefix = None

embeddings = embed_text(
embeddings = await embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
deployment_name=embed_request.deployment_name,
@@ -451,7 +483,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons

try:
if rerank_request.provider_type is None:
sim_scores = local_rerank(
sim_scores = await local_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
@@ -461,7 +493,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
if rerank_request.api_url is None:
raise ValueError("API URL is required for LiteLLM reranking.")

sim_scores = litellm_rerank(
sim_scores = await litellm_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
api_url=rerank_request.api_url,
@@ -474,7 +506,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
sim_scores = cohere_rerank(
sim_scores = await cohere_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
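For the providers whose SDKs still lack async clients (Voyage, Vertex AI, and litellm's embedding helper), the new code offloads the blocking call to the default thread-pool executor so the event loop is not blocked while the SDK waits. Below is a minimal, self-contained sketch of that pattern; blocking_embed is an illustrative stand-in for the synchronous SDK call, not a function from this commit:

import asyncio


def blocking_embed(texts: list[str]) -> list[list[float]]:
    # Stand-in for a synchronous SDK call such as voyageai's embed()
    # or Vertex AI's TextEmbeddingModel.get_embeddings().
    return [[float(len(text))] for text in texts]


async def embed_in_threadpool(texts: list[str]) -> list[list[float]]:
    # Same pattern as encoders.py: hand the blocking call to the default
    # executor and await the result so the event loop stays responsive.
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, lambda: blocking_embed(texts))


if __name__ == "__main__":
    print(asyncio.run(embed_in_threadpool(["hello", "world"])))

The same offloading is applied in embed_text for the local SentenceTransformer encode call and in local_rerank for the cross-encoder, which are CPU-bound rather than I/O-bound.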
4 changes: 2 additions & 2 deletions backend/model_server/management_endpoints.py
@@ -6,12 +6,12 @@


@router.get("/health")
def healthcheck() -> Response:
async def healthcheck() -> Response:
return Response(status_code=200)


@router.get("/gpu-status")
def gpu_status() -> dict[str, bool | str]:
async def gpu_status() -> dict[str, bool | str]:
if torch.cuda.is_available():
return {"gpu_available": True, "type": "cuda"}
elif torch.backends.mps.is_available():
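With both endpoints declared as coroutines, FastAPI serves them directly on the event loop instead of dispatching each request to its internal thread pool. A quick way to exercise them concurrently is sketched below; the base URL and the assumption that the router is mounted at the application root are illustrative, not taken from this commit:

import asyncio

import httpx


async def check_model_server(base_url: str = "http://localhost:9000") -> None:
    # Hypothetical address; point this at wherever the model server listens.
    async with httpx.AsyncClient(base_url=base_url) as client:
        health, gpu = await asyncio.gather(
            client.get("/health"),
            client.get("/gpu-status"),
        )
        print(health.status_code, gpu.json())


if __name__ == "__main__":
    asyncio.run(check_model_server())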
