diff --git a/comps/retrievers/opensearch/langchain/requirements.txt b/comps/retrievers/opensearch/langchain/requirements.txt index 829c80e50..7fb987ccb 100644 --- a/comps/retrievers/opensearch/langchain/requirements.txt +++ b/comps/retrievers/opensearch/langchain/requirements.txt @@ -12,3 +12,5 @@ pymupdf sentence_transformers shortuuid uvicorn +pydantic +numpy diff --git a/comps/retrievers/opensearch/langchain/retriever_opensearch.py b/comps/retrievers/opensearch/langchain/retriever_opensearch.py index 3e7bbce94..f5b260230 100644 --- a/comps/retrievers/opensearch/langchain/retriever_opensearch.py +++ b/comps/retrievers/opensearch/langchain/retriever_opensearch.py @@ -3,12 +3,14 @@ import os import time -from typing import Union +import numpy as np +from typing import Union, List, Callable from langchain_community.embeddings import HuggingFaceBgeEmbeddings from langchain_community.vectorstores import OpenSearchVectorSearch from langchain_huggingface import HuggingFaceEndpointEmbeddings from opensearch_config import EMBED_MODEL, INDEX_NAME, OPENSEARCH_INITIAL_ADMIN_PASSWORD, OPENSEARCH_URL +from pydantic import conlist from comps import ( CustomLogger, @@ -34,6 +36,33 @@ tei_embedding_endpoint = os.getenv("TEI_EMBEDDING_ENDPOINT", None) +async def search_all_embeddings_vectors( + embeddings: Union[conlist(float, min_length=0), List[conlist(float, min_length=0)]], + func: Callable, + *args, + **kwargs + ): + try: + if not isinstance(embeddings, np.ndarray): + embeddings = np.array(embeddings) + + if not np.issubdtype(embeddings.dtype, np.floating): + raise ValueError("All embeddings values must be floating point numbers") + + if embeddings.ndim == 1: + return await func(embedding=embeddings, *args, **kwargs) + elif embeddings.ndim == 2: + responses = [] + for emb in embeddings: + response = await func(embedding=emb, *args, **kwargs) + responses.extend(response) + return responses + else: + raise ValueError("Embeddings must be one or two dimensional") + except Exception as e: + raise ValueError(f"Embedding data is not valid: {e}") + + @register_microservice( name="opea_service@retriever_opensearch", service_type=ServiceType.RETRIEVER, @@ -65,12 +94,19 @@ async def retrieve( query = input.input # if the OpenSearch index has data, perform the search if input.search_type == "similarity": - search_res = await vector_db.asimilarity_search_by_vector(embedding=input.embedding, k=input.k) + search_res = await search_all_embeddings_vectors( + embeddings=input.embedding, + func=vector_db.asimilarity_search_by_vector, + k=input.k, + ) elif input.search_type == "similarity_distance_threshold": if input.distance_threshold is None: raise ValueError("distance_threshold must be provided for " + "similarity_distance_threshold retriever") - search_res = await vector_db.asimilarity_search_by_vector( - embedding=input.embedding, k=input.k, distance_threshold=input.distance_threshold + search_res = await search_all_embeddings_vectors( + embeddings=input.embedding, + func=vector_db.asimilarity_search_by_vector, + k=input.k, + distance_threshold=input.distance_threshold, ) elif input.search_type == "similarity_score_threshold": doc_and_similarities = await vector_db.asimilarity_search_with_relevance_scores(