diff --git a/comps/lvms/lvm_tgi.py b/comps/lvms/lvm_tgi.py
index 9492b4eaf..5fb7d77d2 100644
--- a/comps/lvms/lvm_tgi.py
+++ b/comps/lvms/lvm_tgi.py
@@ -3,13 +3,17 @@
 
 import os
 import time
+from typing import Union
 
 from fastapi.responses import StreamingResponse
 from huggingface_hub import AsyncInferenceClient
+from langchain_core.prompts import PromptTemplate
+from template import ChatTemplate
 
 from comps import (
     CustomLogger,
     LVMDoc,
+    LVMSearchedMultimodalDoc,
     ServiceType,
     TextDoc,
     opea_microservices,
@@ -32,19 +36,49 @@
     output_datatype=TextDoc,
 )
 @register_statistics(names=["opea_service@lvm_tgi"])
-async def lvm(request: LVMDoc):
+async def lvm(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc:
     if logflag:
         logger.info(request)
     start = time.time()
     stream_gen_time = []
-    img_b64_str = request.image
-    prompt = request.prompt
-    max_new_tokens = request.max_new_tokens
-    streaming = request.streaming
-    repetition_penalty = request.repetition_penalty
-    temperature = request.temperature
-    top_k = request.top_k
-    top_p = request.top_p
+
+    if isinstance(request, LVMSearchedMultimodalDoc):
+        if logflag:
+            logger.info("[LVMSearchedMultimodalDoc ] input from retriever microservice")
+        retrieved_metadatas = request.metadata
+        img_b64_str = retrieved_metadatas[0]["b64_img_str"]
+        initial_query = request.initial_query
+        context = retrieved_metadatas[0]["transcript_for_inference"]
+        prompt = initial_query
+        if request.chat_template is None:
+            prompt = ChatTemplate.generate_multimodal_rag_on_videos_prompt(initial_query, context)
+        else:
+            prompt_template = PromptTemplate.from_template(request.chat_template)
+            input_variables = prompt_template.input_variables
+            if sorted(input_variables) == ["context", "question"]:
+                prompt = prompt_template.format(question=initial_query, context=context)
+            else:
+                logger.info(
+                    f"[ LVMSearchedMultimodalDoc ] {prompt_template} not used, we only support 2 input variables ['question', 'context']"
+                )
+        max_new_tokens = request.max_new_tokens
+        streaming = request.streaming
+        repetition_penalty = request.repetition_penalty
+        temperature = request.temperature
+        top_k = request.top_k
+        top_p = request.top_p
+        if logflag:
+            logger.info(f"prompt generated for [LVMSearchedMultimodalDoc ] input from retriever microservice: {prompt}")
+
+    else:
+        img_b64_str = request.image
+        prompt = request.prompt
+        max_new_tokens = request.max_new_tokens
+        streaming = request.streaming
+        repetition_penalty = request.repetition_penalty
+        temperature = request.temperature
+        top_k = request.top_k
+        top_p = request.top_p
 
     image = f"data:image/png;base64,{img_b64_str}"
     image_prompt = f"![]({image})\n{prompt}\nASSISTANT:"
diff --git a/comps/retrievers/langchain/redis_multimodal/multimodal_config.py b/comps/retrievers/langchain/redis_multimodal/multimodal_config.py
index 211851b7e..79ac5698f 100644
--- a/comps/retrievers/langchain/redis_multimodal/multimodal_config.py
+++ b/comps/retrievers/langchain/redis_multimodal/multimodal_config.py
@@ -75,3 +75,9 @@
 
 # Vector Index Configuration
 INDEX_NAME = os.getenv("INDEX_NAME", "test-index")
+
+current_file_path = os.path.abspath(__file__)
+parent_dir = os.path.dirname(current_file_path)
+REDIS_SCHEMA = os.getenv("REDIS_SCHEMA", "schema.yml")
+schema_path = os.path.join(parent_dir, REDIS_SCHEMA)
+INDEX_SCHEMA = schema_path
diff --git a/comps/retrievers/langchain/redis_multimodal/retriever_redis.py b/comps/retrievers/langchain/redis_multimodal/retriever_redis.py
index f998ee28e..c2187c1c7 100644
--- a/comps/retrievers/langchain/redis_multimodal/retriever_redis.py
+++ b/comps/retrievers/langchain/redis_multimodal/retriever_redis.py
@@ -5,7 +5,7 @@
 from typing import Union
 
 from langchain_community.vectorstores import Redis
-from multimodal_config import INDEX_NAME, REDIS_URL
+from multimodal_config import INDEX_NAME, INDEX_SCHEMA, REDIS_URL
 
 from comps import (
     EmbedMultimodalDoc,
@@ -89,5 +89,5 @@
 
 if __name__ == "__main__":
     embeddings = BridgeTowerEmbedding()
-    vector_db = Redis(embedding=embeddings, index_name=INDEX_NAME, redis_url=REDIS_URL)
+    vector_db = Redis(embedding=embeddings, index_name=INDEX_NAME, index_schema=INDEX_SCHEMA, redis_url=REDIS_URL)
     opea_microservices["opea_service@multimodal_retriever_redis"].start()
diff --git a/comps/retrievers/langchain/redis_multimodal/schema.yml b/comps/retrievers/langchain/redis_multimodal/schema.yml
new file mode 100644
index 000000000..32f4a79ae
--- /dev/null
+++ b/comps/retrievers/langchain/redis_multimodal/schema.yml
@@ -0,0 +1,19 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+text:
+  - name: content
+  - name: b64_img_str
+  - name: video_id
+  - name: source_video
+  - name: embedding_type
+  - name: title
+  - name: transcript_for_inference
+numeric:
+  - name: time_of_frame_ms
+vector:
+  - name: content_vector
+    algorithm: HNSW
+    datatype: FLOAT32
+    dims: 512
+    distance_metric: COSINE