diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py
index c72be9e4ac3..791bd40f08f 100644
--- a/backend/model_server/encoders.py
+++ b/backend/model_server/encoders.py
@@ -1,4 +1,6 @@
+import asyncio
 import json
+from types import TracebackType
 from typing import cast
 from typing import Optional
 
@@ -6,7 +8,7 @@
 import openai
 import vertexai  # type: ignore
 import voyageai  # type: ignore
-from cohere import Client as CohereClient
+from cohere import AsyncClient as CohereAsyncClient
 from fastapi import APIRouter
 from fastapi import HTTPException
 from google.oauth2 import service_account  # type: ignore
@@ -68,17 +70,22 @@ def __init__(
         self.api_key = api_key
         self.api_url = api_url
         self.api_version = api_version
+        self.http_client = httpx.AsyncClient(timeout=API_BASED_EMBEDDING_TIMEOUT)
 
-    def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
+    async def _embed_openai(
+        self, texts: list[str], model: str | None
+    ) -> list[Embedding]:
         if not model:
             model = DEFAULT_OPENAI_MODEL
 
-        client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
+        client = openai.AsyncOpenAI(
+            api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
+        )
 
         final_embeddings: list[Embedding] = []
         try:
             for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
-                response = client.embeddings.create(input=text_batch, model=model)
+                response = await client.embeddings.create(input=text_batch, model=model)
                 final_embeddings.extend(
                     [embedding.embedding for embedding in response.data]
                 )
@@ -93,19 +100,17 @@ def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
             logger.error(error_string)
             raise RuntimeError(error_string)
 
-    def _embed_cohere(
+    async def _embed_cohere(
        self, texts: list[str], model: str | None, embedding_type: str
     ) -> list[Embedding]:
         if not model:
             model = DEFAULT_COHERE_MODEL
 
-        client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT)
+        client = CohereAsyncClient(api_key=self.api_key)
 
         final_embeddings: list[Embedding] = []
         for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
-            # Does not use the same tokenizer as the Danswer API server but it's approximately the same
-            # empirically it's only off by a very few tokens so it's not a big deal
-            response = client.embed(
+            response = await client.embed(
                 texts=text_batch,
                 model=model,
                 input_type=embedding_type,
@@ -114,7 +119,7 @@
             final_embeddings.extend(cast(list[Embedding], response.embeddings))
         return final_embeddings
 
-    def _embed_voyage(
+    async def _embed_voyage(
         self, texts: list[str], model: str | None, embedding_type: str
     ) -> list[Embedding]:
         if not model:
@@ -124,28 +129,37 @@
             api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
         )
 
-        response = client.embed(
-            texts,
-            model=model,
-            input_type=embedding_type,
-            truncation=True,
+        # Note: Voyage doesn't have an async client yet, we'll need to run this in a thread pool
+        response = await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: client.embed(
+                texts,
+                model=model,
+                input_type=embedding_type,
+                truncation=True,
+            ),
         )
         return response.embeddings
 
-    def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]:
-        response = embedding(
-            model=model,
-            input=texts,
-            timeout=API_BASED_EMBEDDING_TIMEOUT,
-            api_key=self.api_key,
-            api_base=self.api_url,
-            api_version=self.api_version,
+    async def _embed_azure(
+        self, texts: list[str], model: str | None
+    ) -> list[Embedding]:
+        # Note: litellm doesn't have async support yet, we'll need to run this in a thread pool
+        response = await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: embedding(
+                model=model,
+                input=texts,
+                timeout=API_BASED_EMBEDDING_TIMEOUT,
+                api_key=self.api_key,
+                api_base=self.api_url,
+                api_version=self.api_version,
+            ),
         )
         embeddings = [embedding["embedding"] for embedding in response.data]
-
         return embeddings
 
-    def _embed_vertex(
+    async def _embed_vertex(
         self, texts: list[str], model: str | None, embedding_type: str
     ) -> list[Embedding]:
         if not model:
@@ -158,19 +172,23 @@
         vertexai.init(project=project_id, credentials=credentials)
         client = TextEmbeddingModel.from_pretrained(model)
 
-        embeddings = client.get_embeddings(
-            [
-                TextEmbeddingInput(
-                    text,
-                    embedding_type,
-                )
-                for text in texts
-            ],
-            auto_truncate=True,  # Also this is default
+        # Note: Vertex AI doesn't have async client yet, we'll need to run this in a thread pool
+        embeddings = await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: client.get_embeddings(
+                [
+                    TextEmbeddingInput(
+                        text,
+                        embedding_type,
+                    )
+                    for text in texts
+                ],
+                auto_truncate=True,
+            ),
        )
         return [embedding.values for embedding in embeddings]
 
-    def _embed_litellm_proxy(
+    async def _embed_litellm_proxy(
         self, texts: list[str], model_name: str | None
     ) -> list[Embedding]:
         if not model_name:
@@ -183,22 +201,20 @@
             {} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
         )
 
-        with httpx.Client() as client:
-            response = client.post(
-                self.api_url,
-                json={
-                    "model": model_name,
-                    "input": texts,
-                },
-                headers=headers,
-                timeout=API_BASED_EMBEDDING_TIMEOUT,
-            )
-            response.raise_for_status()
-            result = response.json()
-            return [embedding["embedding"] for embedding in result["data"]]
+        response = await self.http_client.post(
+            self.api_url,
+            json={
+                "model": model_name,
+                "input": texts,
+            },
+            headers=headers,
+        )
+        response.raise_for_status()
+        result = response.json()
+        return [embedding["embedding"] for embedding in result["data"]]
 
     @retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
-    def embed(
+    async def embed(
         self,
         *,
         texts: list[str],
@@ -207,19 +223,19 @@
         deployment_name: str | None = None,
     ) -> list[Embedding]:
         if self.provider == EmbeddingProvider.OPENAI:
-            return self._embed_openai(texts, model_name)
+            return await self._embed_openai(texts, model_name)
         elif self.provider == EmbeddingProvider.AZURE:
-            return self._embed_azure(texts, f"azure/{deployment_name}")
+            return await self._embed_azure(texts, f"azure/{deployment_name}")
         elif self.provider == EmbeddingProvider.LITELLM:
-            return self._embed_litellm_proxy(texts, model_name)
+            return await self._embed_litellm_proxy(texts, model_name)
 
         embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
 
         if self.provider == EmbeddingProvider.COHERE:
-            return self._embed_cohere(texts, model_name, embedding_type)
+            return await self._embed_cohere(texts, model_name, embedding_type)
         elif self.provider == EmbeddingProvider.VOYAGE:
-            return self._embed_voyage(texts, model_name, embedding_type)
+            return await self._embed_voyage(texts, model_name, embedding_type)
         elif self.provider == EmbeddingProvider.GOOGLE:
-            return self._embed_vertex(texts, model_name, embedding_type)
+            return await self._embed_vertex(texts, model_name, embedding_type)
         else:
             raise ValueError(f"Unsupported provider: {self.provider}")
@@ -233,6 +249,17 @@ def create(
         logger.debug(f"Creating Embedding instance for provider: {provider}")
         return CloudEmbedding(api_key, provider, api_url, api_version)
 
+    async def __aenter__(self) -> "CloudEmbedding":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        await self.http_client.aclose()
+
 
 def get_embedding_model(
     model_name: str,
@@ -242,9 +269,6 @@
     global _GLOBAL_MODELS_DICT  # A dictionary to store models
 
-    if _GLOBAL_MODELS_DICT is None:
-        _GLOBAL_MODELS_DICT = {}
-
     if model_name not in _GLOBAL_MODELS_DICT:
         logger.notice(f"Loading {model_name}")
         # Some model architectures that aren't built into the Transformers or Sentence
@@ -275,7 +299,7 @@ def get_local_reranking_model(
 
 
 @simple_log_function_time()
-def embed_text(
+async def embed_text(
     texts: list[str],
     text_type: EmbedTextType,
     model_name: str | None,
@@ -311,18 +335,18 @@
                 "Cloud models take an explicit text type instead."
             )
 
-        cloud_model = CloudEmbedding(
+        async with CloudEmbedding(
             api_key=api_key,
             provider=provider_type,
             api_url=api_url,
             api_version=api_version,
-        )
-        embeddings = cloud_model.embed(
-            texts=texts,
-            model_name=model_name,
-            deployment_name=deployment_name,
-            text_type=text_type,
-        )
+        ) as cloud_model:
+            embeddings = await cloud_model.embed(
+                texts=texts,
+                model_name=model_name,
+                deployment_name=deployment_name,
+                text_type=text_type,
+            )
 
         if any(embedding is None for embedding in embeddings):
             error_message = "Embeddings contain None values\n"
@@ -338,8 +362,12 @@
         local_model = get_embedding_model(
             model_name=model_name, max_context_length=max_context_length
         )
-        embeddings_vectors = local_model.encode(
-            prefixed_texts, normalize_embeddings=normalize_embeddings
+        # Run CPU-bound embedding in a thread pool
+        embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: local_model.encode(
+                prefixed_texts, normalize_embeddings=normalize_embeddings
+            ),
         )
         embeddings = [
             embedding if isinstance(embedding, list) else embedding.tolist()
@@ -357,27 +385,31 @@
 
 
 @simple_log_function_time()
-def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
+async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
     cross_encoder = get_local_reranking_model(model_name)
-    return cross_encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore
+    # Run CPU-bound reranking in a thread pool
+    return await asyncio.get_event_loop().run_in_executor(
+        None,
+        lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(),  # type: ignore
+    )
 
 
-def cohere_rerank(
+async def cohere_rerank(
     query: str, docs: list[str], model_name: str, api_key: str
 ) -> list[float]:
-    cohere_client = CohereClient(api_key=api_key)
-    response = cohere_client.rerank(query=query, documents=docs, model=model_name)
+    cohere_client = CohereAsyncClient(api_key=api_key)
+    response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
     results = response.results
     sorted_results = sorted(results, key=lambda item: item.index)
     return [result.relevance_score for result in sorted_results]
 
 
-def litellm_rerank(
+async def litellm_rerank(
     query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
 ) -> list[float]:
     headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
-    with httpx.Client() as client:
-        response = client.post(
+    async with httpx.AsyncClient() as client:
+        response = await client.post(
             api_url,
             json={
                 "model": model_name,
@@ -411,7 +443,7 @@ async def process_embed_request(
         else:
             prefix = None
 
-        embeddings = embed_text(
+        embeddings = await embed_text(
            texts=embed_request.texts,
            model_name=embed_request.model_name,
            deployment_name=embed_request.deployment_name,
@@ -451,7 +483,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
 
     try:
         if rerank_request.provider_type is None:
-            sim_scores = local_rerank(
+            sim_scores = await local_rerank(
                 query=rerank_request.query,
                 docs=rerank_request.documents,
                 model_name=rerank_request.model_name,
@@ -461,7 +493,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
             if rerank_request.api_url is None:
                 raise ValueError("API URL is required for LiteLLM reranking.")
 
-            sim_scores = litellm_rerank(
+            sim_scores = await litellm_rerank(
                 query=rerank_request.query,
                 docs=rerank_request.documents,
                 api_url=rerank_request.api_url,
@@ -474,7 +506,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
         elif rerank_request.provider_type == RerankerProvider.COHERE:
             if rerank_request.api_key is None:
                 raise RuntimeError("Cohere Rerank Requires an API Key")
-            sim_scores = cohere_rerank(
+            sim_scores = await cohere_rerank(
                 query=rerank_request.query,
                 docs=rerank_request.documents,
                 model_name=rerank_request.model_name,
diff --git a/backend/model_server/management_endpoints.py b/backend/model_server/management_endpoints.py
index 56640a2fa73..4c6387e0708 100644
--- a/backend/model_server/management_endpoints.py
+++ b/backend/model_server/management_endpoints.py
@@ -6,12 +6,12 @@
 
 
 @router.get("/health")
-def healthcheck() -> Response:
+async def healthcheck() -> Response:
     return Response(status_code=200)
 
 
 @router.get("/gpu-status")
-def gpu_status() -> dict[str, bool | str]:
+async def gpu_status() -> dict[str, bool | str]:
     if torch.cuda.is_available():
         return {"gpu_available": True, "type": "cuda"}
     elif torch.backends.mps.is_available():
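Reviewer note: the recurring pattern in this diff is to wrap SDKs without async support (Voyage, Vertex AI, litellm's `embedding`, the local SentenceTransformer/cross-encoder calls) in `asyncio.get_event_loop().run_in_executor(None, ...)` so the blocking call runs on the default thread pool instead of stalling the event loop. A minimal standalone sketch of that pattern is below; `slow_embed` and `embed_without_blocking` are hypothetical stand-ins for illustration only and are not part of this change.

```python
import asyncio
import time


def slow_embed(texts: list[str]) -> list[list[float]]:
    # Stand-in for a blocking SDK call such as voyageai's client.embed()
    time.sleep(1)
    return [[0.0] * 4 for _ in texts]


async def embed_without_blocking(texts: list[str]) -> list[list[float]]:
    # Offload the blocking call to the default executor (a thread pool),
    # mirroring how the diff handles providers without async clients.
    return await asyncio.get_event_loop().run_in_executor(
        None, lambda: slow_embed(texts)
    )


async def main() -> None:
    # Both calls overlap because neither blocks the event loop:
    # total runtime is roughly one second, not two.
    results = await asyncio.gather(
        embed_without_blocking(["hello"]),
        embed_without_blocking(["world"]),
    )
    print(results)


if __name__ == "__main__":
    asyncio.run(main())
```

Relatedly, `embed_text` now enters `CloudEmbedding` via `async with`, so the new `__aexit__` guarantees `http_client.aclose()` runs even when a provider call raises.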