From 26530af3658a6c5999df2b6beaff42977874e0fa Mon Sep 17 00:00:00 2001 From: lkk <33276950+lkk12014402@users.noreply.github.com> Date: Mon, 12 Aug 2024 23:24:15 +0800 Subject: [PATCH] refact embedding/ranking/llm request/response by referring to openai format (#405) Co-authored-by: sys-lpot-val Co-authored-by: lvliang-intel --- comps/cores/mega/gateway.py | 1 + comps/cores/proto/api_protocol.py | 268 +++++++++++++++--- comps/cores/proto/docarray.py | 38 ++- comps/embeddings/langchain/embedding_tei.py | 32 ++- comps/llms/text-generation/tgi/README.md | 6 + comps/llms/text-generation/tgi/llm.py | 180 +++++++++--- .../llms/text-generation/tgi/requirements.txt | 1 + comps/llms/text-generation/tgi/template.py | 29 ++ comps/reranks/tei/reranking_tei.py | 65 +++-- .../langchain/redis/retriever_redis.py | 80 ++++-- tests/test_reranks_tei.sh | 2 +- 11 files changed, 563 insertions(+), 139 deletions(-) create mode 100644 comps/llms/text-generation/tgi/template.py diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 8ad31c8417..6eb069e6ec 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -163,6 +163,7 @@ async def handle_request(self, request: Request): temperature=chat_request.temperature if chat_request.temperature else 0.01, repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03, streaming=stream_opt, + chat_template=chat_request.chat_template if chat_request.chat_template else None, ) result_dict, runtime_graph = await self.megaservice.schedule( initial_inputs={"text": prompt}, llm_parameters=parameters diff --git a/comps/cores/proto/api_protocol.py b/comps/cores/proto/api_protocol.py index 957fc9d951..bd52d72742 100644 --- a/comps/cores/proto/api_protocol.py +++ b/comps/cores/proto/api_protocol.py @@ -30,24 +30,243 @@ class UsageInfo(BaseModel): completion_tokens: Optional[int] = 0 +class ResponseFormat(BaseModel): + # type must be "json_object" or "text" + type: Literal["text", "json_object"] + + +class StreamOptions(BaseModel): + # refer https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py#L105 + include_usage: Optional[bool] + + +class FunctionDefinition(BaseModel): + name: str + description: Optional[str] = None + parameters: Optional[Dict[str, Any]] = None + + +class ChatCompletionToolsParam(BaseModel): + type: Literal["function"] = "function" + function: FunctionDefinition + + +class ChatCompletionNamedFunction(BaseModel): + name: str + + +class ChatCompletionNamedToolChoiceParam(BaseModel): + function: ChatCompletionNamedFunction + type: Literal["function"] = "function" + + +class TokenCheckRequestItem(BaseModel): + model: str + prompt: str + max_tokens: int + + +class TokenCheckRequest(BaseModel): + prompts: List[TokenCheckRequestItem] + + +class TokenCheckResponseItem(BaseModel): + fits: bool + tokenCount: int + contextLength: int + + +class TokenCheckResponse(BaseModel): + prompts: List[TokenCheckResponseItem] + + +class EmbeddingRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings + model: Optional[str] = None + input: Union[List[int], List[List[int]], str, List[str]] + encoding_format: Optional[str] = Field("float", pattern="^(float|base64)$") + dimensions: Optional[int] = None + user: Optional[str] = None + + # define + request_type: Literal["embedding"] = "embedding" + + +class EmbeddingResponseData(BaseModel): + index: int + object: str = "embedding" + embedding: Union[List[float], str] 
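
As a rough illustration of how the new OpenAI-style embedding models above fit together (mirroring the `embedding_tei.py` change later in this patch), a sketch along these lines could exercise them. It is not part of the commit: the model name and the vector values are placeholders, and it assumes the `comps` package from this repository is importable.

```python
from comps.cores.proto.api_protocol import (
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingResponseData,
)

# Build an OpenAI-style embedding request; the model name is illustrative only.
request = EmbeddingRequest(model="some-embedding-model", input="What is Deep Learning?")

# A real service would call its embedding backend (e.g. TEI) here;
# a small placeholder vector stands in for the actual embedding.
embed_vector = [0.012, -0.034, 0.056]
if request.dimensions is not None:
    embed_vector = embed_vector[: request.dimensions]

# Wrap the vector in the response shape defined in this hunk.
response = EmbeddingResponse(
    model=request.model,
    data=[EmbeddingResponseData(index=0, embedding=embed_vector)],
)
print(response.model_dump_json())
```
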
+ + +class EmbeddingResponse(BaseModel): + object: str = "list" + model: Optional[str] = None + data: List[EmbeddingResponseData] + usage: Optional[UsageInfo] = None + + +class RetrievalRequest(BaseModel): + embedding: Union[EmbeddingResponse, List[float]] = None + input: Optional[str] = None # search_type maybe need, like "mmr" + search_type: str = "similarity" + k: int = 4 + distance_threshold: Optional[float] = None + fetch_k: int = 20 + lambda_mult: float = 0.5 + score_threshold: float = 0.2 + + # define + request_type: Literal["retrieval"] = "retrieval" + + +class RetrievalResponseData(BaseModel): + text: str + metadata: Optional[Dict[str, Any]] = None + + +class RetrievalResponse(BaseModel): + retrieved_docs: List[RetrievalResponseData] + + +class RerankingRequest(BaseModel): + input: str + retrieved_docs: Union[List[RetrievalResponseData], List[Dict[str, Any]], List[str]] + top_n: int = 1 + + # define + request_type: Literal["reranking"] = "reranking" + + +class RerankingResponseData(BaseModel): + text: str + score: Optional[float] = 0.0 + + +class RerankingResponse(BaseModel): + reranked_docs: List[RerankingResponseData] + + class ChatCompletionRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/chat/create messages: Union[ str, List[Dict[str, str]], List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]], ] model: Optional[str] = "Intel/neural-chat-7b-v3-3" - temperature: Optional[float] = 0.01 - top_p: Optional[float] = 0.95 - top_k: Optional[int] = 10 + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = 0 + max_tokens: Optional[int] = 16 # use https://platform.openai.com/docs/api-reference/completions/create n: Optional[int] = 1 - max_tokens: Optional[int] = 1024 - stop: Optional[Union[str, List[str]]] = None + presence_penalty: Optional[float] = 0.0 + response_format: Optional[ResponseFormat] = None + seed: Optional[int] = None + service_tier: Optional[str] = None + stop: Union[str, List[str], None] = Field(default_factory=list) stream: Optional[bool] = False - presence_penalty: Optional[float] = 1.03 - frequency_penalty: Optional[float] = 0.0 + stream_options: Optional[StreamOptions] = None + temperature: Optional[float] = 1.0 # vllm default 0.7 + top_p: Optional[float] = None # openai default 1.0, but tgi needs `top_p` must be > 0.0 and < 1.0, set None + tools: Optional[List[ChatCompletionToolsParam]] = None + tool_choice: Optional[Union[Literal["none"], ChatCompletionNamedToolChoiceParam]] = "none" + parallel_tool_calls: Optional[bool] = True user: Optional[str] = None + # Ordered by official OpenAI API documentation + # default values are same with + # https://platform.openai.com/docs/api-reference/completions/create + best_of: Optional[int] = 1 + suffix: Optional[str] = None + + # vllm reference: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py#L130 + repetition_penalty: Optional[float] = 1.0 + + # tgi reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/generate + # some tgi parameters in use + # default values are same with + # https://github.com/huggingface/text-generation-inference/blob/main/router/src/lib.rs#L190 + # max_new_tokens: Optional[int] = 100 # Priority use openai + top_k: Optional[int] = None + # top_p: Optional[float] = None # Priority use openai + typical_p: Optional[float] = None + # 
repetition_penalty: Optional[float] = None + + # doc: begin-chat-completion-extra-params + echo: Optional[bool] = Field( + default=False, + description=( + "If true, the new message will be prepended with the last message " "if they belong to the same role." + ), + ) + add_generation_prompt: Optional[bool] = Field( + default=True, + description=( + "If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model." + ), + ) + add_special_tokens: Optional[bool] = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to False (as is the " + "default)." + ), + ) + documents: Optional[Union[List[Dict[str, str]], List[str]]] = Field( + default=None, + description=( + "A list of dicts representing documents that will be accessible to " + "the model if it is performing RAG (retrieval-augmented generation)." + " If the template does not support RAG, this argument will have no " + "effect. We recommend that each document should be a dict containing " + '"title" and "text" keys.' + ), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A template to use for this conversion. " + "If this is not passed, the model's default chat template will be " + "used instead. We recommend that the template contains {context} and {question} for rag," + "or only contains {question} for chat completion without rag." + ), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. 
" "Will be accessible by the chat template."), + ) + # doc: end-chat-completion-extra-params + + # embedding + input: Union[List[int], List[List[int]], str, List[str]] = None # user query/question from messages[-] + encoding_format: Optional[str] = Field("float", pattern="^(float|base64)$") + dimensions: Optional[int] = None + embedding: Union[EmbeddingResponse, List[float]] = Field(default_factory=list) + + # retrieval + search_type: str = "similarity" + k: int = 4 + distance_threshold: Optional[float] = None + fetch_k: int = 20 + lambda_mult: float = 0.5 + score_threshold: float = 0.2 + retrieved_docs: Union[List[RetrievalResponseData], List[Dict[str, Any]]] = Field(default_factory=list) + + # reranking + top_n: int = 1 + reranked_docs: Union[List[RerankingResponseData], List[Dict[str, Any]]] = Field(default_factory=list) + + # define + request_type: Literal["chat"] = "chat" + class AudioChatCompletionRequest(BaseModel): audio: str @@ -110,41 +329,6 @@ class ChatCompletionStreamResponse(BaseModel): choices: List[ChatCompletionResponseStreamChoice] -class TokenCheckRequestItem(BaseModel): - model: str - prompt: str - max_tokens: int - - -class TokenCheckRequest(BaseModel): - prompts: List[TokenCheckRequestItem] - - -class TokenCheckResponseItem(BaseModel): - fits: bool - tokenCount: int - contextLength: int - - -class TokenCheckResponse(BaseModel): - prompts: List[TokenCheckResponseItem] - - -class EmbeddingsRequest(BaseModel): - model: Optional[str] = None - engine: Optional[str] = None - input: Union[str, List[Any]] - user: Optional[str] = None - encoding_format: Optional[str] = None - - -class EmbeddingsResponse(BaseModel): - object: str = "list" - data: List[Dict[str, Any]] - model: str - usage: UsageInfo - - class CompletionRequest(BaseModel): model: str prompt: Union[str, List[Any]] diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py index 9e07d618d1..9760d7d3ec 100644 --- a/comps/cores/proto/docarray.py +++ b/comps/cores/proto/docarray.py @@ -1,13 +1,13 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from typing import Optional +from typing import Dict, List, Optional, Union import numpy as np from docarray import BaseDoc, DocList from docarray.documents import AudioDoc from docarray.typing import AudioUrl -from pydantic import Field, conint, conlist +from pydantic import Field, conint, conlist, field_validator class TopologyInfo: @@ -88,6 +88,30 @@ class LLMParamsDoc(BaseDoc): repetition_penalty: float = 1.03 streaming: bool = True + chat_template: Optional[str] = Field( + default=None, + description=( + "A template to use for this conversion. " + "If this is not passed, the model's default chat template will be " + "used instead. We recommend that the template contains {context} and {question} for rag," + "or only contains {question} for chat completion without rag." + ), + ) + documents: Optional[Union[List[Dict[str, str]], List[str]]] = Field( + default=[], + description=( + "A list of dicts representing documents that will be accessible to " + "the model if it is performing RAG (retrieval-augmented generation)." + " If the template does not support RAG, this argument will have no " + "effect. We recommend that each document should be a dict containing " + '"title" and "text" keys.' 
+ ), + ) + + @field_validator("chat_template") + def chat_template_must_contain_variables(cls, v): + return v + class LLMParams(BaseDoc): max_new_tokens: int = 1024 @@ -98,6 +122,16 @@ class LLMParams(BaseDoc): repetition_penalty: float = 1.03 streaming: bool = True + chat_template: Optional[str] = Field( + default=None, + description=( + "A template to use for this conversion. " + "If this is not passed, the model's default chat template will be " + "used instead. We recommend that the template contains {context} and {question} for rag," + "or only contains {question} for chat completion without rag." + ), + ) + class RAGASParams(BaseDoc): questions: DocList[TextDoc] diff --git a/comps/embeddings/langchain/embedding_tei.py b/comps/embeddings/langchain/embedding_tei.py index 4c482db51c..583e24d5a0 100644 --- a/comps/embeddings/langchain/embedding_tei.py +++ b/comps/embeddings/langchain/embedding_tei.py @@ -3,6 +3,7 @@ import os import time +from typing import Union from langchain_community.embeddings import HuggingFaceHubEmbeddings from langsmith import traceable @@ -16,6 +17,12 @@ register_statistics, statistics_dict, ) +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, +) @register_microservice( @@ -24,15 +31,30 @@ endpoint="/v1/embeddings", host="0.0.0.0", port=6000, - input_datatype=TextDoc, - output_datatype=EmbedDoc, ) @traceable(run_type="embedding") @register_statistics(names=["opea_service@embedding_tei_langchain"]) -def embedding(input: TextDoc) -> EmbedDoc: +def embedding( + input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest] +) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]: start = time.time() - embed_vector = embeddings.embed_query(input.text) - res = EmbedDoc(text=input.text, embedding=embed_vector) + + if isinstance(input, TextDoc): + embed_vector = embeddings.embed_query(input.text) + res = EmbedDoc(text=input.text, embedding=embed_vector) + else: + embed_vector = embeddings.embed_query(input.input) + if input.dimensions is not None: + embed_vector = embed_vector[: input.dimensions] + + if isinstance(input, ChatCompletionRequest): + input.embedding = embed_vector + # keep + res = input + if isinstance(input, EmbeddingRequest): + # for standard openai embedding format + res = EmbeddingResponse(data=[EmbeddingResponseData(index=0, embedding=embed_vector)]) + statistics_dict["opea_service@embedding_tei_langchain"].append_latency(time.time() - start, None) return res diff --git a/comps/llms/text-generation/tgi/README.md b/comps/llms/text-generation/tgi/README.md index 6c9607ca96..57f4767208 100644 --- a/comps/llms/text-generation/tgi/README.md +++ b/comps/llms/text-generation/tgi/README.md @@ -110,6 +110,12 @@ curl http://${your_ip}:9000/v1/chat/completions \ -X POST \ -d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' \ -H 'Content-Type: application/json' + +# custom chat template +curl http://${your_ip}:9000/v1/chat/completions \ + -X POST \ + -d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true, "chat_template":"### You are a helpful, respectful and honest assistant to help the user with questions.\n### Context: {context}\n### Question: {question}\n### Answer:"}' \ + -H 'Content-Type: application/json' ``` ## 4. 
Validated Model diff --git a/comps/llms/text-generation/tgi/llm.py b/comps/llms/text-generation/tgi/llm.py index e267c21dc2..c202aede7a 100644 --- a/comps/llms/text-generation/tgi/llm.py +++ b/comps/llms/text-generation/tgi/llm.py @@ -3,10 +3,14 @@ import os import time +from typing import Union from fastapi.responses import StreamingResponse from huggingface_hub import AsyncInferenceClient +from langchain_core.prompts import PromptTemplate from langsmith import traceable +from openai import OpenAI +from template import ChatTemplate from comps import ( GeneratedDoc, @@ -17,6 +21,13 @@ register_statistics, statistics_dict, ) +from comps.cores.proto.api_protocol import ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse + +llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080") +llm = AsyncInferenceClient( + model=llm_endpoint, + timeout=600, +) @register_microservice( @@ -28,36 +39,32 @@ ) @traceable(run_type="llm") @register_statistics(names=["opea_service@llm_tgi"]) -async def llm_generate(input: LLMParamsDoc): +async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]): + + prompt_template = None + if input.chat_template: + prompt_template = PromptTemplate.from_template(input.chat_template) + input_variables = prompt_template.input_variables + stream_gen_time = [] start = time.time() - if input.streaming: - - async def stream_generator(): - chat_response = "" - text_generation = await llm.text_generation( - prompt=input.query, - stream=input.streaming, - max_new_tokens=input.max_new_tokens, - repetition_penalty=input.repetition_penalty, - temperature=input.temperature, - top_k=input.top_k, - top_p=input.top_p, - ) - async for text in text_generation: - stream_gen_time.append(time.time() - start) - chat_response += text - chunk_repr = repr(text.encode("utf-8")) - print(f"[llm - chat_stream] chunk:{chunk_repr}") - yield f"data: {chunk_repr}\n\n" - print(f"[llm - chat_stream] stream response: {chat_response}") - statistics_dict["opea_service@llm_tgi"].append_latency(stream_gen_time[-1], stream_gen_time[0]) - yield "data: [DONE]\n\n" - - return StreamingResponse(stream_generator(), media_type="text/event-stream") - else: - response = await llm.text_generation( - prompt=input.query, + + if isinstance(input, LLMParamsDoc): + prompt = input.query + if prompt_template: + if sorted(input_variables) == ["context", "question"]: + prompt = prompt_template.format(question=input.query, context="\n".join(input.documents)) + elif input_variables == ["question"]: + prompt = prompt_template.format(question=input.query) + else: + print(f"{prompt_template} not used, we only support 2 input variables ['question', 'context']") + else: + if input.documents: + # use rag default template + prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents) + + text_generation = await llm.text_generation( + prompt=prompt, stream=input.streaming, max_new_tokens=input.max_new_tokens, repetition_penalty=input.repetition_penalty, @@ -65,14 +72,117 @@ async def stream_generator(): top_k=input.top_k, top_p=input.top_p, ) - statistics_dict["opea_service@llm_tgi"].append_latency(time.time() - start, None) - return GeneratedDoc(text=response, prompt=input.query) + if input.streaming: + + async def stream_generator(): + chat_response = "" + async for text in text_generation: + stream_gen_time.append(time.time() - start) + chat_response += text + chunk_repr = repr(text.encode("utf-8")) + print(f"[llm - chat_stream] chunk:{chunk_repr}") + yield f"data: 
{chunk_repr}\n\n" + print(f"[llm - chat_stream] stream response: {chat_response}") + statistics_dict["opea_service@llm_tgi"].append_latency(stream_gen_time[-1], stream_gen_time[0]) + yield "data: [DONE]\n\n" + + return StreamingResponse(stream_generator(), media_type="text/event-stream") + else: + statistics_dict["opea_service@llm_tgi"].append_latency(time.time() - start, None) + return GeneratedDoc(text=text_generation, prompt=input.query) + + else: + client = OpenAI( + api_key="EMPTY", + base_url=llm_endpoint + "/v1", + ) + + if isinstance(input.messages, str): + prompt = input.messages + if prompt_template: + if sorted(input_variables) == ["context", "question"]: + prompt = prompt_template.format(question=input.messages, context="\n".join(input.documents)) + elif input_variables == ["question"]: + prompt = prompt_template.format(question=input.messages) + else: + print(f"{prompt_template} not used, we only support 2 input variables ['question', 'context']") + else: + if input.documents: + # use rag default template + prompt = ChatTemplate.generate_rag_prompt(input.messages, input.documents) + + chat_completion = client.completions.create( + model="tgi", + prompt=prompt, + best_of=input.best_of, + echo=input.echo, + frequency_penalty=input.frequency_penalty, + logit_bias=input.logit_bias, + logprobs=input.logprobs, + max_tokens=input.max_tokens, + n=input.n, + presence_penalty=input.presence_penalty, + seed=input.seed, + stop=input.stop, + stream=input.stream, + suffix=input.suffix, + temperature=input.temperature, + top_p=input.top_p, + user=input.user, + ) + else: + if input.messages[0]["role"] == "system": + if "{context}" in input.messages[0]["content"]: + if input.documents is None or input.documents == []: + input.messages[0]["content"].format(context="") + else: + input.messages[0]["content"].format(context="\n".join(input.documents)) + else: + if prompt_template: + system_prompt = prompt_template + if input_variables == ["context"]: + system_prompt = prompt_template.format(context="\n".join(input.documents)) + else: + print(f"{prompt_template} not used, only support 1 input variables ['context']") + + input.messages.insert(0, {"role": "system", "content": system_prompt}) + + chat_completion = client.chat.completions.create( + model="tgi", + messages=input.messages, + frequency_penalty=input.frequency_penalty, + logit_bias=input.logit_bias, + logprobs=input.logprobs, + top_logprobs=input.top_logprobs, + max_tokens=input.max_tokens, + n=input.n, + presence_penalty=input.presence_penalty, + response_format=input.response_format, + seed=input.seed, + service_tier=input.service_tier, + stop=input.stop, + stream=input.stream, + stream_options=input.stream_options, + temperature=input.temperature, + top_p=input.top_p, + tools=input.tools, + tool_choice=input.tool_choice, + parallel_tool_calls=input.parallel_tool_calls, + user=input.user, + ) + + if input.stream: + + def stream_generator(): + for c in chat_completion: + print(c) + yield f"data: {c.model_dump_json()}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse(stream_generator(), media_type="text/event-stream") + else: + return chat_completion if __name__ == "__main__": - llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080") - llm = AsyncInferenceClient( - model=llm_endpoint, - timeout=600, - ) opea_microservices["opea_service@llm_tgi"].start() diff --git a/comps/llms/text-generation/tgi/requirements.txt b/comps/llms/text-generation/tgi/requirements.txt index 1e62d477d6..9670813d60 100644 --- 
a/comps/llms/text-generation/tgi/requirements.txt +++ b/comps/llms/text-generation/tgi/requirements.txt @@ -4,6 +4,7 @@ fastapi httpx huggingface_hub langsmith +openai==1.35.13 opentelemetry-api opentelemetry-exporter-otlp opentelemetry-sdk diff --git a/comps/llms/text-generation/tgi/template.py b/comps/llms/text-generation/tgi/template.py new file mode 100644 index 0000000000..447efcc673 --- /dev/null +++ b/comps/llms/text-generation/tgi/template.py @@ -0,0 +1,29 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import re + + +class ChatTemplate: + @staticmethod + def generate_rag_prompt(question, documents): + context_str = "\n".join(documents) + if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3: + # chinese context + template = """ +### 你将扮演一个乐于助人、尊重他人并诚实的助手,你的目标是帮助用户解答问题。有效地利用来自本地知识库的搜索结果。确保你的回答中只包含相关信息。如果你不确定问题的答案,请避免分享不准确的信息。 +### 搜索结果:{context} +### 问题:{question} +### 回答: +""" + else: + template = """ +### You are a helpful, respectful and honest assistant to help the user with questions. \ +Please refer to the search results obtained from the local knowledge base. \ +But be careful to not incorporate the information that you think is not relevant to the question. \ +If you don't know the answer to a question, please don't share false information. \n +### Search results: {context} \n +### Question: {question} \n +### Answer: +""" + return template.format(context=context_str, question=question) diff --git a/comps/reranks/tei/reranking_tei.py b/comps/reranks/tei/reranking_tei.py index 1beaa83f72..2440f800a7 100644 --- a/comps/reranks/tei/reranking_tei.py +++ b/comps/reranks/tei/reranking_tei.py @@ -6,6 +6,7 @@ import os import re import time +from typing import Union import requests from langsmith import traceable @@ -19,6 +20,12 @@ register_statistics, statistics_dict, ) +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + RerankingRequest, + RerankingResponse, + RerankingResponseData, +) @register_microservice( @@ -32,42 +39,44 @@ ) @traceable(run_type="llm") @register_statistics(names=["opea_service@reranking_tgi_gaudi"]) -def reranking(input: SearchedDoc) -> LLMParamsDoc: +def reranking( + input: Union[SearchedDoc, RerankingRequest, ChatCompletionRequest] +) -> Union[LLMParamsDoc, RerankingResponse, ChatCompletionRequest]: + start = time.time() + reranking_results = [] if input.retrieved_docs: docs = [doc.text for doc in input.retrieved_docs] url = tei_reranking_endpoint + "/rerank" - data = {"query": input.initial_query, "texts": docs} + if isinstance(input, SearchedDoc): + query = input.initial_query + else: + # for RerankingRequest, ChatCompletionRequest + query = input.input + data = {"query": query, "texts": docs} headers = {"Content-Type": "application/json"} response = requests.post(url, data=json.dumps(data), headers=headers) response_data = response.json() - best_response_list = heapq.nlargest(input.top_n, response_data, key=lambda x: x["score"]) - context_str = "" - for best_response in best_response_list: - context_str = context_str + " " + input.retrieved_docs[best_response["index"]].text - if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3: - # chinese context - template = """ -### 你将扮演一个乐于助人、尊重他人并诚实的助手,你的目标是帮助用户解答问题。有效地利用来自本地知识库的搜索结果。确保你的回答中只包含相关信息。如果你不确定问题的答案,请避免分享不准确的信息。 -### 搜索结果:{context} -### 问题:{question} -### 回答: -""" - else: - template = """ -### You are a helpful, respectful and honest assistant to help the user with 
questions. \ -Please refer to the search results obtained from the local knowledge base. \ -But be careful to not incorporate the information that you think is not relevant to the question. \ -If you don't know the answer to a question, please don't share false information. \ -### Search results: {context} \n -### Question: {question} \n -### Answer: -""" - final_prompt = template.format(context=context_str, question=input.initial_query) - statistics_dict["opea_service@reranking_tgi_gaudi"].append_latency(time.time() - start, None) - return LLMParamsDoc(query=final_prompt.strip()) + + for best_response in response_data[: input.top_n]: + reranking_results.append( + {"text": input.retrieved_docs[best_response["index"]].text, "score": best_response["score"]} + ) + + statistics_dict["opea_service@reranking_tgi_gaudi"].append_latency(time.time() - start, None) + if isinstance(input, SearchedDoc): + return LLMParamsDoc(query=input.initial_query, documents=[doc["text"] for doc in reranking_results]) else: - return LLMParamsDoc(query=input.initial_query) + reranking_docs = [] + for doc in reranking_results: + reranking_docs.append(RerankingResponseData(text=doc["text"], score=doc["score"])) + if isinstance(input, RerankingRequest): + return RerankingResponse(reranked_docs=reranking_docs) + + if isinstance(input, ChatCompletionRequest): + input.reranked_docs = reranking_docs + input.documents = [doc["text"] for doc in reranking_results] + return input if __name__ == "__main__": diff --git a/comps/retrievers/langchain/redis/retriever_redis.py b/comps/retrievers/langchain/redis/retriever_redis.py index dc4ed01d4c..43f3e0c053 100644 --- a/comps/retrievers/langchain/redis/retriever_redis.py +++ b/comps/retrievers/langchain/redis/retriever_redis.py @@ -3,6 +3,7 @@ import os import time +from typing import Union from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceHubEmbeddings from langchain_community.vectorstores import Redis @@ -19,6 +20,12 @@ register_statistics, statistics_dict, ) +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + RetrievalRequest, + RetrievalResponse, + RetrievalResponseData, +) tei_embedding_endpoint = os.getenv("TEI_EMBEDDING_ENDPOINT") @@ -32,36 +39,57 @@ ) @traceable(run_type="retriever") @register_statistics(names=["opea_service@retriever_redis"]) -def retrieve(input: EmbedDoc) -> SearchedDoc: +def retrieve( + input: Union[EmbedDoc, RetrievalRequest, ChatCompletionRequest] +) -> Union[SearchedDoc, RetrievalResponse, ChatCompletionRequest]: + start = time.time() # check if the Redis index has data if vector_db.client.keys() == []: - result = SearchedDoc(retrieved_docs=[], initial_query=input.text) - statistics_dict["opea_service@retriever_redis"].append_latency(time.time() - start, None) - return result + search_res = [] + else: + if isinstance(input, EmbedDoc): + query = input.text + else: + # for RetrievalRequest, ChatCompletionRequest + query = input.input + # if the Redis index has data, perform the search + if input.search_type == "similarity": + search_res = vector_db.similarity_search_by_vector(embedding=input.embedding, k=input.k) + elif input.search_type == "similarity_distance_threshold": + if input.distance_threshold is None: + raise ValueError("distance_threshold must be provided for " + "similarity_distance_threshold retriever") + search_res = vector_db.similarity_search_by_vector( + embedding=input.embedding, k=input.k, distance_threshold=input.distance_threshold + ) + elif input.search_type == 
"similarity_score_threshold": + docs_and_similarities = vector_db.similarity_search_with_relevance_scores( + query=input.text, k=input.k, score_threshold=input.score_threshold + ) + search_res = [doc for doc, _ in docs_and_similarities] + elif input.search_type == "mmr": + search_res = vector_db.max_marginal_relevance_search( + query=input.text, k=input.k, fetch_k=input.fetch_k, lambda_mult=input.lambda_mult + ) + else: + raise ValueError(f"{input.search_type} not valid") + + # return different response format + retrieved_docs = [] + if isinstance(input, EmbedDoc): + for r in search_res: + retrieved_docs.append(TextDoc(text=r.page_content)) + result = SearchedDoc(retrieved_docs=retrieved_docs, initial_query=input.text) + else: + for r in search_res: + retrieved_docs.append(RetrievalResponseData(text=r.page_content, metadata=r.metadata)) + if isinstance(input, RetrievalRequest): + result = RetrievalResponse(retrieved_docs=retrieved_docs) + elif isinstance(input, ChatCompletionRequest): + input.retrieved_docs = retrieved_docs + input.documents = [doc.text for doc in retrieved_docs] + result = input - # if the Redis index has data, perform the search - if input.search_type == "similarity": - search_res = vector_db.similarity_search_by_vector(embedding=input.embedding, k=input.k) - elif input.search_type == "similarity_distance_threshold": - if input.distance_threshold is None: - raise ValueError("distance_threshold must be provided for " + "similarity_distance_threshold retriever") - search_res = vector_db.similarity_search_by_vector( - embedding=input.embedding, k=input.k, distance_threshold=input.distance_threshold - ) - elif input.search_type == "similarity_score_threshold": - docs_and_similarities = vector_db.similarity_search_with_relevance_scores( - query=input.text, k=input.k, score_threshold=input.score_threshold - ) - search_res = [doc for doc, _ in docs_and_similarities] - elif input.search_type == "mmr": - search_res = vector_db.max_marginal_relevance_search( - query=input.text, k=input.k, fetch_k=input.fetch_k, lambda_mult=input.lambda_mult - ) - searched_docs = [] - for r in search_res: - searched_docs.append(TextDoc(text=r.page_content)) - result = SearchedDoc(retrieved_docs=searched_docs, initial_query=input.text) statistics_dict["opea_service@retriever_redis"].append_latency(time.time() - start, None) return result diff --git a/tests/test_reranks_tei.sh b/tests/test_reranks_tei.sh index 0777e7e4dd..4a8c77aadc 100644 --- a/tests/test_reranks_tei.sh +++ b/tests/test_reranks_tei.sh @@ -34,7 +34,7 @@ function validate_microservice() { -d '{"initial_query":"What is Deep Learning?", "retrieved_docs": [{"text":"Deep Learning is not..."}, {"text":"Deep learning is..."}]}' \ -H 'Content-Type: application/json') - if echo "$CONTENT" | grep -q "### Search results:"; then + if echo "$CONTENT" | grep -q "documents"; then echo "Content is as expected." else echo "Content does not match the expected result: $CONTENT"