Commit 4d029f9

Merge branch 'main' into feature/elasticsearch

kkrishTa authored Dec 9, 2024
2 parents 36bdd51 + fbf3017
Showing 2 changed files with 33 additions and 9 deletions.
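
In brief: the orchestrator change reorders the streaming-token metrics update relative to the yield in both streaming paths, and the mosec embedding change routes the async path through `async_embed_with_retry`, adds a synchronous `_get_len_safe_embeddings` override, and converts the `embedding` microservice handler and its `get_embeddings` helper from async to sync.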
4 changes: 2 additions & 2 deletions comps/cores/mega/orchestrator.py
@@ -237,8 +237,8 @@ def generate():
                     )
                     token_start = time.time()
                 else:
-                    yield chunk
                     token_start = self.metrics.token_update(token_start, is_first)
+                    yield chunk
                 is_first = False
             self.metrics.request_update(req_start)
             self.metrics.pending_update(False)
@@ -306,7 +306,7 @@ def token_generator(self, sentence: str, token_start: float, is_first: bool, is_
         suffix = "\n\n"
         tokens = re.findall(r"\s?\S+\s?", sentence, re.UNICODE)
         for token in tokens:
-            yield prefix + repr(token.replace("\\n", "\n").encode("utf-8")) + suffix
             token_start = self.metrics.token_update(token_start, is_first)
+            yield prefix + repr(token.replace("\\n", "\n").encode("utf-8")) + suffix
         if is_last:
             yield "data: [DONE]\n\n"
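
Both hunks swap the order of the token yield and the `self.metrics.token_update(...)` call, so the latency timestamp is taken when a chunk is ready rather than after the consumer has pulled it. The sketch below is hypothetical code, not from this repo; it only demonstrates the generator behavior at stake: `yield` suspends the producer until the consumer asks for the next item, so a timestamp taken after `yield` also charges consumer-side time to the token latency.

    # Hypothetical sketch: the position of a timestamp relative to `yield`
    # changes what it measures, because `yield` suspends the producer.
    import time


    def produce(update_after_yield: bool):
        token_start = time.time()
        time.sleep(0.01)  # pretend the first token takes ~10 ms to generate
        chunk = "hello"
        if update_after_yield:
            yield chunk  # suspends here until the consumer pulls again
            latency = time.time() - token_start  # includes consumer work
        else:
            latency = time.time() - token_start  # producer-side time only
            yield chunk
        print(f"update_after_yield={update_after_yield}: {latency * 1000:.0f} ms")


    for flag in (False, True):
        for _ in produce(flag):
            time.sleep(0.05)  # slow consumer: ~10 ms vs ~60 ms reported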
38 changes: 31 additions & 7 deletions comps/embeddings/mosec/langchain/embedding_mosec.py
@@ -7,6 +7,7 @@
 from typing import List, Optional, Union
 
 from langchain_community.embeddings import OpenAIEmbeddings
+from langchain_community.embeddings.openai import async_embed_with_retry
 
 from comps import (
     CustomLogger,
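
`async_embed_with_retry` is imported from the same `langchain_community.embeddings.openai` module that provides `OpenAIEmbeddings`; as the call sites below show, it is awaited and takes the embeddings instance as its first argument, wrapping the request in the library's retry logic.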
@@ -35,7 +36,7 @@ async def _aget_len_safe_embeddings(
     ) -> List[List[float]]:
         _chunk_size = chunk_size or self.chunk_size
         batched_embeddings: List[List[float]] = []
-        response = self.client.create(input=texts, **self._invocation_params)
+        response = await async_embed_with_retry(self, input=texts, **self._invocation_params)
         if not isinstance(response, dict):
             response = response.model_dump()
         batched_embeddings.extend(r["embedding"] for r in response["data"])
@@ -45,7 +46,7 @@ async def empty_embedding() -> List[float]:
         async def empty_embedding() -> List[float]:
             nonlocal _cached_empty_embedding
             if _cached_empty_embedding is None:
-                average_embedded = self.client.create(input="", **self._invocation_params)
+                average_embedded = await async_embed_with_retry(self, input="", **self._invocation_params)
                 if not isinstance(average_embedded, dict):
                     average_embedded = average_embedded.model_dump()
                 _cached_empty_embedding = average_embedded["data"][0]["embedding"]
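
Both `self.client.create` call sites in the async path now go through `async_embed_with_retry`, so transient failures against the mosec endpoint are retried instead of immediately failing the request. For readers without the langchain source at hand, the wrapper pattern looks roughly like the sketch below; this is illustrative code under assumed attribute names (`async_client`, `max_retries`), not the library's implementation:

    # Illustrative only: the general shape of an async retry wrapper such
    # as langchain's async_embed_with_retry. Attribute names on
    # `embeddings` are assumptions made for the sketch.
    from tenacity import AsyncRetrying, stop_after_attempt, wait_exponential


    async def async_embed_with_retry_sketch(embeddings, **kwargs):
        async for attempt in AsyncRetrying(
            stop=stop_after_attempt(embeddings.max_retries),
            wait=wait_exponential(multiplier=1, min=1, max=10),
            reraise=True,
        ):
            with attempt:
                # Each attempt re-issues the embeddings request; tenacity
                # backs off exponentially between failures.
                return await embeddings.async_client.create(**kwargs)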
@@ -57,6 +58,29 @@ async def get_embedding(e: Optional[List[float]]) -> List[float]:
         embeddings = await asyncio.gather(*[get_embedding(e) for e in batched_embeddings])
         return embeddings
 
+    def _get_len_safe_embeddings(
+        self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
+    ) -> List[List[float]]:
+        _chunk_size = chunk_size or self.chunk_size
+        batched_embeddings: List[List[float]] = []
+        response = self.client.create(input=texts, **self._invocation_params)
+        if not isinstance(response, dict):
+            response = response.model_dump()
+        batched_embeddings.extend(r["embedding"] for r in response["data"])
+
+        _cached_empty_embedding: Optional[List[float]] = None
+
+        def empty_embedding() -> List[float]:
+            nonlocal _cached_empty_embedding
+            if _cached_empty_embedding is None:
+                average_embedded = self.client.create(input="", **self._invocation_params)
+                if not isinstance(average_embedded, dict):
+                    average_embedded = average_embedded.model_dump()
+                _cached_empty_embedding = average_embedded["data"][0]["embedding"]
+            return _cached_empty_embedding
+
+        return [e if e is not None else empty_embedding() for e in batched_embeddings]
+
 
 @register_microservice(
     name="opea_service@embedding_mosec",
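
The new `_get_len_safe_embeddings` is a synchronous mirror of the `_aget_len_safe_embeddings` override above it: one `self.client.create` call for the whole batch, plus the cached `empty_embedding` fallback. Overriding it rather than inheriting the `OpenAIEmbeddings` version presumably keeps the base class's token-counting and chunking machinery out of the path to the mosec server. Note that `_chunk_size` and `engine` are accepted but unused, and the `e is not None` fallback is defensive: every entry in `batched_embeddings` here comes straight from the response.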
@@ -68,18 +92,18 @@ async def get_embedding(e: Optional[List[float]]) -> List[float]:
     output_datatype=EmbedDoc,
 )
 @register_statistics(names=["opea_service@embedding_mosec"])
-async def embedding(
+def embedding(
     input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest]
 ) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]:
     if logflag:
         logger.info(input)
     start = time.time()
     if isinstance(input, TextDoc):
-        embed_vector = await get_embeddings(input.text)
+        embed_vector = get_embeddings(input.text)
         embedding_res = embed_vector[0] if isinstance(input.text, str) else embed_vector
         res = EmbedDoc(text=input.text, embedding=embedding_res)
     else:
-        embed_vector = await get_embeddings(input.input)
+        embed_vector = get_embeddings(input.input)
         if input.dimensions is not None:
             embed_vector = [embed_vector[i][: input.dimensions] for i in range(len(embed_vector))]
@@ -99,9 +123,9 @@ async def embedding(
     return res
 
 
-async def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]:
+def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]:
     texts = [text] if isinstance(text, str) else text
-    embed_vector = await embeddings.aembed_documents(texts)
+    embed_vector = embeddings.embed_documents(texts)
     return embed_vector

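After this change the module's helper can be exercised without an event loop. A hypothetical smoke test, assuming the module's `embeddings` client can reach a live mosec embedding server (module name assumed for the import):

    # Hypothetical driver for the now-synchronous helper defined above.
    from embedding_mosec import get_embeddings

    vecs = get_embeddings(["what is deep learning?", "hello"])
    print(len(vecs), len(vecs[0]))  # two vectors, one per input text

    single = get_embeddings("a bare string is wrapped into a list")
    print(len(single))  # still a list containing a single vector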