From 5832e23d7b38abf6843394e6b35a5347c29c0c10 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 31 Aug 2023 13:14:04 -0700 Subject: [PATCH] ready for FE --- .../secondary_llm_flows/query_validation.py | 53 +++++++++++++++---- backend/danswer/server/search_backend.py | 12 ++--- backend/danswer/server/utils.py | 5 ++ 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py index e7fa09ec665..4b536b2f2b6 100644 --- a/backend/danswer/secondary_llm_flows/query_validation.py +++ b/backend/danswer/secondary_llm_flows/query_validation.py @@ -1,8 +1,12 @@ import re from collections.abc import Iterator +from dataclasses import asdict +from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt from danswer.llm.build import get_default_llm +from danswer.server.models import QueryValidationResponse +from danswer.server.utils import get_json_line REASONING_PAT = "REASONING: " ANSWERABLE_PAT = "ANSWERABLE: " @@ -49,24 +53,55 @@ def get_query_validation_messages(user_query: str) -> list[dict[str, str]]: return messages -def get_query_answerability(user_query: str) -> tuple[str, bool]: - messages = get_query_validation_messages(user_query) - filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) - model_output = get_default_llm().invoke(filled_llm_prompt) - +def extract_answerability_reasoning(model_raw: str) -> str: reasoning_match = re.search( - f"{REASONING_PAT}(.*?){ANSWERABLE_PAT}", model_output, re.DOTALL + f"{REASONING_PAT}(.*?){ANSWERABLE_PAT}", model_raw, re.DOTALL ) reasoning_text = reasoning_match.group(1).strip() if reasoning_match else "" + return reasoning_text - answerable_match = re.search(f"{ANSWERABLE_PAT}(.+)", model_output) + +def extract_answerability_bool(model_raw: str) -> bool: + answerable_match = re.search(f"{ANSWERABLE_PAT}(.+)", model_raw) answerable_text = answerable_match.group(1).strip() if answerable_match else "" answerable = True if answerable_text.strip().lower() in ["true", "yes"] else False + return answerable + + +def get_query_answerability(user_query: str) -> tuple[str, bool]: + messages = get_query_validation_messages(user_query) + filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) + model_output = get_default_llm().invoke(filled_llm_prompt) - return reasoning_text, answerable + reasoning = extract_answerability_reasoning(model_output) + answerable = extract_answerability_bool(model_output) + + return reasoning, answerable def stream_query_answerability(user_query: str) -> Iterator[str]: messages = get_query_validation_messages(user_query) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) - return get_default_llm().stream(filled_llm_prompt) + tokens = get_default_llm().stream(filled_llm_prompt) + reasoning_pat_found = False + model_output = "" + for token in tokens: + model_output = model_output + token + + if not reasoning_pat_found and REASONING_PAT in model_output: + reasoning_pat_found = True + remaining = model_output[len(REASONING_PAT) :] + if remaining: + yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=remaining))) + continue + + if reasoning_pat_found: + yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token))) + + reasoning = extract_answerability_reasoning(model_output) + answerable = extract_answerability_bool(model_output) + + yield get_json_line( + QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict() + ) + return diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index 8bf0c01d55e..37950735517 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -1,4 +1,3 @@ -import json from collections.abc import Generator from dataclasses import asdict @@ -30,6 +29,7 @@ from danswer.search.semantic_search import chunks_to_search_docs from danswer.search.semantic_search import retrieve_ranked_documents from danswer.secondary_llm_flows.query_validation import get_query_answerability +from danswer.secondary_llm_flows.query_validation import stream_query_answerability from danswer.server.models import HelperResponse from danswer.server.models import QAFeedbackRequest from danswer.server.models import QAResponse @@ -37,6 +37,7 @@ from danswer.server.models import QuestionRequest from danswer.server.models import SearchFeedbackRequest from danswer.server.models import SearchResponse +from danswer.server.utils import get_json_line from danswer.utils.logger import setup_logger from danswer.utils.timing import log_generator_function_time @@ -45,10 +46,6 @@ router = APIRouter() -def get_json_line(json_dict: dict) -> str: - return json.dumps(json_dict) + "\n" - - @router.get("/search-intent") def get_search_type( question: QuestionRequest = Depends(), _: User = Depends(current_user) @@ -71,7 +68,10 @@ def query_validation( def stream_query_validation( question: QuestionRequest = Depends(), _: User = Depends(current_user) ) -> StreamingResponse: - pass + query = question.query + return StreamingResponse( + stream_query_answerability(query), media_type="application/json" + ) @router.post("/semantic-search") diff --git a/backend/danswer/server/utils.py b/backend/danswer/server/utils.py index f18db93a74f..bf535661878 100644 --- a/backend/danswer/server/utils.py +++ b/backend/danswer/server/utils.py @@ -1,6 +1,11 @@ +import json from typing import Any +def get_json_line(json_dict: dict) -> str: + return json.dumps(json_dict) + "\n" + + def mask_string(sensitive_str: str) -> str: return "****...**" + sensitive_str[-4:]