Skip to content

Commit

Permalink
Update retriever container image
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Morin <[email protected]>
  • Loading branch information
cameronmorin committed Nov 28, 2024
1 parent 969cd79 commit 242b53e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
2 changes: 2 additions & 0 deletions comps/retrievers/opensearch/langchain/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ pymupdf
sentence_transformers
shortuuid
uvicorn
pydantic
numpy
44 changes: 40 additions & 4 deletions comps/retrievers/opensearch/langchain/retriever_opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

import os
import time
from typing import Union
import numpy as np
from typing import Union, List, Callable

from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import OpenSearchVectorSearch
from langchain_huggingface import HuggingFaceEndpointEmbeddings
from opensearch_config import EMBED_MODEL, INDEX_NAME, OPENSEARCH_INITIAL_ADMIN_PASSWORD, OPENSEARCH_URL
from pydantic import conlist

from comps import (
CustomLogger,
Expand All @@ -34,6 +36,33 @@
tei_embedding_endpoint = os.getenv("TEI_EMBEDDING_ENDPOINT", None)


async def search_all_embeddings_vectors(
    embeddings: Union[List[float], List[List[float]]],
    func: Callable,
    *args,
    **kwargs,
):
    """Run an async vector search for one embedding or for a batch of embeddings.

    Args:
        embeddings: Either a single embedding (1-D sequence of floats) or a
            batch of embeddings (2-D sequence). A ``numpy.ndarray`` is also
            accepted and used as-is.
        func: Awaitable search callable. It is invoked as
            ``func(*args, embedding=<list of floats>, **kwargs)``.
        *args: Extra positional arguments forwarded to ``func``.
        **kwargs: Extra keyword arguments forwarded to ``func`` (e.g. ``k``).

    Returns:
        For a 1-D input, whatever ``func`` returns. For a 2-D input, a single
        flat list built by extending with the result of each per-row call, in
        row order.

    Raises:
        ValueError: If ``embeddings`` cannot be converted to an array, is not
            of a floating-point dtype, or is not one- or two-dimensional.
        Exception: Anything raised by ``func`` propagates unchanged, so that
            backend failures are not reported as invalid embedding data.
    """
    # Only the input validation is wrapped; the search call itself stays
    # outside so its errors keep their original type and message.
    try:
        vectors = embeddings if isinstance(embeddings, np.ndarray) else np.array(embeddings)
    except Exception as e:
        raise ValueError(f"Embedding data is not valid: {e}") from e

    if not np.issubdtype(vectors.dtype, np.floating):
        raise ValueError("Embedding data is not valid: All embeddings values must be floating point numbers")

    if vectors.ndim not in (1, 2):
        raise ValueError("Embedding data is not valid: Embeddings must be one or two dimensional")

    # Hand plain Python lists to the search callable: they serialize to JSON
    # without depending on the client knowing about NumPy types.
    if vectors.ndim == 1:
        return await func(*args, embedding=vectors.tolist(), **kwargs)

    responses = []
    for vector in vectors:
        responses.extend(await func(*args, embedding=vector.tolist(), **kwargs))
    return responses


@register_microservice(
name="opea_service@retriever_opensearch",
service_type=ServiceType.RETRIEVER,
Expand Down Expand Up @@ -65,12 +94,19 @@ async def retrieve(
query = input.input
# if the OpenSearch index has data, perform the search
if input.search_type == "similarity":
search_res = await vector_db.asimilarity_search_by_vector(embedding=input.embedding, k=input.k)
search_res = await search_all_embeddings_vectors(
embeddings=input.embedding,
func=vector_db.asimilarity_search_by_vector,
k=input.k,
)
elif input.search_type == "similarity_distance_threshold":
if input.distance_threshold is None:
raise ValueError("distance_threshold must be provided for " + "similarity_distance_threshold retriever")
search_res = await vector_db.asimilarity_search_by_vector(
embedding=input.embedding, k=input.k, distance_threshold=input.distance_threshold
search_res = await search_all_embeddings_vectors(
embeddings=input.embedding,
func=vector_db.asimilarity_search_by_vector,
k=input.k,
distance_threshold=input.distance_threshold,
)
elif input.search_type == "similarity_score_threshold":
doc_and_similarities = await vector_db.asimilarity_search_with_relevance_scores(
Expand Down

0 comments on commit 242b53e

Please sign in to comment.