diff --git a/comps/embeddings/mosec/langchain/embedding_mosec.py b/comps/embeddings/mosec/langchain/embedding_mosec.py index 38e92b5a7..e422d92b6 100644 --- a/comps/embeddings/mosec/langchain/embedding_mosec.py +++ b/comps/embeddings/mosec/langchain/embedding_mosec.py @@ -7,6 +7,7 @@ from typing import List, Optional, Union from langchain_community.embeddings import OpenAIEmbeddings +from langchain_community.embeddings.openai import async_embed_with_retry from comps import ( CustomLogger, @@ -35,7 +36,7 @@ async def _aget_len_safe_embeddings( ) -> List[List[float]]: _chunk_size = chunk_size or self.chunk_size batched_embeddings: List[List[float]] = [] - response = self.client.create(input=texts, **self._invocation_params) + response = await async_embed_with_retry(self, input=texts, **self._invocation_params) if not isinstance(response, dict): response = response.model_dump() batched_embeddings.extend(r["embedding"] for r in response["data"]) @@ -45,7 +46,7 @@ async def _aget_len_safe_embeddings( async def empty_embedding() -> List[float]: nonlocal _cached_empty_embedding if _cached_empty_embedding is None: - average_embedded = self.client.create(input="", **self._invocation_params) + average_embedded = await async_embed_with_retry(self, input="", **self._invocation_params) if not isinstance(average_embedded, dict): average_embedded = average_embedded.model_dump() _cached_empty_embedding = average_embedded["data"][0]["embedding"] @@ -57,6 +58,29 @@ async def get_embedding(e: Optional[List[float]]) -> List[float]: embeddings = await asyncio.gather(*[get_embedding(e) for e in batched_embeddings]) return embeddings + def _get_len_safe_embeddings( + self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None + ) -> List[List[float]]: + _chunk_size = chunk_size or self.chunk_size + batched_embeddings: List[List[float]] = [] + response = self.client.create(input=texts, **self._invocation_params) + if not isinstance(response, dict): + response = response.model_dump() + batched_embeddings.extend(r["embedding"] for r in response["data"]) + + _cached_empty_embedding: Optional[List[float]] = None + + def empty_embedding() -> List[float]: + nonlocal _cached_empty_embedding + if _cached_empty_embedding is None: + average_embedded = self.client.create(input="", **self._invocation_params) + if not isinstance(average_embedded, dict): + average_embedded = average_embedded.model_dump() + _cached_empty_embedding = average_embedded["data"][0]["embedding"] + return _cached_empty_embedding + + return [e if e is not None else empty_embedding() for e in batched_embeddings] + @register_microservice( name="opea_service@embedding_mosec", @@ -68,18 +92,18 @@ async def get_embedding(e: Optional[List[float]]) -> List[float]: output_datatype=EmbedDoc, ) @register_statistics(names=["opea_service@embedding_mosec"]) -async def embedding( +def embedding( input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest] ) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]: if logflag: logger.info(input) start = time.time() if isinstance(input, TextDoc): - embed_vector = await get_embeddings(input.text) + embed_vector = get_embeddings(input.text) embedding_res = embed_vector[0] if isinstance(input.text, str) else embed_vector res = EmbedDoc(text=input.text, embedding=embedding_res) else: - embed_vector = await get_embeddings(input.input) + embed_vector = get_embeddings(input.input) if input.dimensions is not None: embed_vector = [embed_vector[i][: input.dimensions] for i in range(len(embed_vector))] @@ -99,9 +123,9 @@ async def embedding( return res -async def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]: +def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]: texts = [text] if isinstance(text, str) else text - embed_vector = await embeddings.aembed_documents(texts) + embed_vector = embeddings.embed_documents(texts) return embed_vector