Revert mosec embedding microservice to use synchronous interface. (#971)

* Revert mosec embedding microservice to use synchronous interface.

Signed-off-by: Yao, Qing <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add dependency.

Signed-off-by: Yao, Qing <[email protected]>

---------

Signed-off-by: Yao, Qing <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
yao531441 and pre-commit-ci[bot] authored Dec 6, 2024
1 parent 5663e16 commit fbf3017
Showing 1 changed file with 31 additions and 7 deletions.
comps/embeddings/mosec/langchain/embedding_mosec.py (31 additions, 7 deletions)

@@ -7,6 +7,7 @@
 from typing import List, Optional, Union
 
 from langchain_community.embeddings import OpenAIEmbeddings
+from langchain_community.embeddings.openai import async_embed_with_retry
 
 from comps import (
     CustomLogger,
@@ -35,7 +36,7 @@ async def _aget_len_safe_embeddings(
     ) -> List[List[float]]:
         _chunk_size = chunk_size or self.chunk_size
         batched_embeddings: List[List[float]] = []
-        response = self.client.create(input=texts, **self._invocation_params)
+        response = await async_embed_with_retry(self, input=texts, **self._invocation_params)
         if not isinstance(response, dict):
             response = response.model_dump()
         batched_embeddings.extend(r["embedding"] for r in response["data"])
@@ -45,7 +46,7 @@ async def _aget_len_safe_embeddings(
         async def empty_embedding() -> List[float]:
             nonlocal _cached_empty_embedding
             if _cached_empty_embedding is None:
-                average_embedded = self.client.create(input="", **self._invocation_params)
+                average_embedded = await async_embed_with_retry(self, input="", **self._invocation_params)
                 if not isinstance(average_embedded, dict):
                     average_embedded = average_embedded.model_dump()
                 _cached_empty_embedding = average_embedded["data"][0]["embedding"]
@@ -57,6 +58,29 @@ async def get_embedding(e: Optional[List[float]]) -> List[float]:
         embeddings = await asyncio.gather(*[get_embedding(e) for e in batched_embeddings])
         return embeddings
 
+    def _get_len_safe_embeddings(
+        self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
+    ) -> List[List[float]]:
+        _chunk_size = chunk_size or self.chunk_size
+        batched_embeddings: List[List[float]] = []
+        response = self.client.create(input=texts, **self._invocation_params)
+        if not isinstance(response, dict):
+            response = response.model_dump()
+        batched_embeddings.extend(r["embedding"] for r in response["data"])
+
+        _cached_empty_embedding: Optional[List[float]] = None
+
+        def empty_embedding() -> List[float]:
+            nonlocal _cached_empty_embedding
+            if _cached_empty_embedding is None:
+                average_embedded = self.client.create(input="", **self._invocation_params)
+                if not isinstance(average_embedded, dict):
+                    average_embedded = average_embedded.model_dump()
+                _cached_empty_embedding = average_embedded["data"][0]["embedding"]
+            return _cached_empty_embedding
+
+        return [e if e is not None else empty_embedding() for e in batched_embeddings]
+
 
 @register_microservice(
     name="opea_service@embedding_mosec",
@@ -68,18 +92,18 @@ async def get_embedding(e: Optional[List[float]]) -> List[float]:
     output_datatype=EmbedDoc,
 )
 @register_statistics(names=["opea_service@embedding_mosec"])
-async def embedding(
+def embedding(
     input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest]
 ) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]:
     if logflag:
         logger.info(input)
     start = time.time()
     if isinstance(input, TextDoc):
-        embed_vector = await get_embeddings(input.text)
+        embed_vector = get_embeddings(input.text)
         embedding_res = embed_vector[0] if isinstance(input.text, str) else embed_vector
         res = EmbedDoc(text=input.text, embedding=embedding_res)
     else:
-        embed_vector = await get_embeddings(input.input)
+        embed_vector = get_embeddings(input.input)
         if input.dimensions is not None:
             embed_vector = [embed_vector[i][: input.dimensions] for i in range(len(embed_vector))]
 
@@ -99,9 +123,9 @@ async def embedding(
     return res
 
 
-async def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]:
+def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]:
     texts = [text] if isinstance(text, str) else text
-    embed_vector = await embeddings.aembed_documents(texts)
+    embed_vector = embeddings.embed_documents(texts)
     return embed_vector
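
For context, a minimal, self-contained sketch of the synchronous call path this revert restores. It is not part of the commit: the real service wires get_embeddings to a MOSEC-backed OpenAIEmbeddings subclass, so a hypothetical _FakeEmbeddings stand-in is used here to keep the snippet runnable on its own.

from typing import List, Union


class _FakeEmbeddings:
    """Hypothetical stand-in for the MOSEC-backed embeddings client; returns fixed 3-d vectors."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[0.0, 1.0, 2.0] for _ in texts]


embeddings = _FakeEmbeddings()


def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]:
    # Mirrors the helper in the diff: wrap a bare string in a list, then call the
    # synchronous embed_documents() instead of awaiting aembed_documents().
    texts = [text] if isinstance(text, str) else text
    return embeddings.embed_documents(texts)


# No event loop or `await` is needed to serve a request, which is the point of the
# revert; only the retained _aget_len_safe_embeddings path stays async, now going
# through the newly imported async_embed_with_retry.
vectors = get_embeddings("hello world")
print(len(vectors), len(vectors[0]))  # -> 1 3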


