Skip to content

Commit

Permalink
ready for FE
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhongsun96 committed Aug 31, 2023
1 parent 402cd50 commit 5832e23
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 15 deletions.
53 changes: 44 additions & 9 deletions backend/danswer/secondary_llm_flows/query_validation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import re
from collections.abc import Iterator
from dataclasses import asdict

from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
from danswer.llm.build import get_default_llm
from danswer.server.models import QueryValidationResponse
from danswer.server.utils import get_json_line

REASONING_PAT = "REASONING: "
ANSWERABLE_PAT = "ANSWERABLE: "
Expand Down Expand Up @@ -49,24 +53,55 @@ def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
return messages


def get_query_answerability(user_query: str) -> tuple[str, bool]:
messages = get_query_validation_messages(user_query)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = get_default_llm().invoke(filled_llm_prompt)

def extract_answerability_reasoning(model_raw: str) -> str:
    """Pull the free-text reasoning out of a raw query-validation response.

    The reasoning is the text located between ``REASONING_PAT`` and the
    first following ``ANSWERABLE_PAT``, with surrounding whitespace removed.

    Args:
        model_raw: Complete text produced by the LLM.

    Returns:
        The stripped reasoning text, or ``""`` when the two markers are not
        both present in that order.
    """
    # The markers are literal text, so escape them before building the
    # pattern. DOTALL lets the reasoning span several lines, and the lazy
    # group stops at the first ANSWERABLE marker.
    reasoning_match = re.search(
        f"{re.escape(REASONING_PAT)}(.*?){re.escape(ANSWERABLE_PAT)}",
        model_raw,
        re.DOTALL,
    )
    if reasoning_match is None:
        return ""
    return reasoning_match.group(1).strip()

answerable_match = re.search(f"{ANSWERABLE_PAT}(.+)", model_output)

def extract_answerability_bool(model_raw: str) -> bool:
    """Parse the answerable verdict out of a raw query-validation response.

    Looks at the text following ``ANSWERABLE_PAT`` up to the end of that
    line and treats ``true`` / ``yes`` (case-insensitive, surrounding
    whitespace ignored) as a positive verdict.

    Args:
        model_raw: Complete text produced by the LLM.

    Returns:
        ``True`` only when the verdict text is ``true`` or ``yes``; ``False``
        for any other text or when the marker is missing.
    """
    # The marker is literal text, so escape it. Without DOTALL, ``.+`` stops
    # at the end of the line carrying the verdict.
    answerable_match = re.search(f"{re.escape(ANSWERABLE_PAT)}(.+)", model_raw)
    if answerable_match is None:
        return False
    return answerable_match.group(1).strip().lower() in ("true", "yes")


def get_query_answerability(user_query: str) -> tuple[str, bool]:
    """Ask the default LLM whether a query is answerable and parse its reply.

    Args:
        user_query: The user's question to validate.

    Returns:
        A ``(reasoning, answerable)`` pair extracted from the model output by
        ``extract_answerability_reasoning`` and ``extract_answerability_bool``.
    """
    messages = get_query_validation_messages(user_query)
    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
    model_output = get_default_llm().invoke(filled_llm_prompt)

    # Both values are parsed from the same complete response.
    reasoning = extract_answerability_reasoning(model_output)
    answerable = extract_answerability_bool(model_output)
    return reasoning, answerable


def stream_query_answerability(user_query: str) -> Iterator[str]:
    """Stream the LLM's answerability reasoning, then the parsed verdict.

    Args:
        user_query: The user's question to validate.

    Yields:
        JSON lines built with ``get_json_line``. First come
        ``DanswerAnswerPiece`` payloads carrying the text that follows
        ``REASONING_PAT`` as it arrives. The last line is a
        ``QueryValidationResponse`` payload with the reasoning and verdict
        parsed from the full output.
    """
    messages = get_query_validation_messages(user_query)
    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
    tokens = get_default_llm().stream(filled_llm_prompt)

    reasoning_pat_found = False
    model_output = ""
    for token in tokens:
        model_output += token

        if not reasoning_pat_found:
            # Locate the marker instead of assuming it starts at offset 0, so
            # any preamble or leading whitespace does not shift the slice.
            marker_start = model_output.find(REASONING_PAT)
            if marker_start == -1:
                # Marker not complete yet; keep buffering silently.
                continue
            reasoning_pat_found = True
            remaining = model_output[marker_start + len(REASONING_PAT) :]
            if remaining:
                yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=remaining)))
            continue

        # NOTE(review): every later token is forwarded, including those that
        # spell out the ANSWERABLE marker and verdict; confirm the client
        # expects to receive them.
        yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))

    reasoning = extract_answerability_reasoning(model_output)
    answerable = extract_answerability_bool(model_output)

    yield get_json_line(
        QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict()
    )
12 changes: 6 additions & 6 deletions backend/danswer/server/search_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from collections.abc import Generator
from dataclasses import asdict

Expand Down Expand Up @@ -30,13 +29,15 @@
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
from danswer.server.models import HelperResponse
from danswer.server.models import QAFeedbackRequest
from danswer.server.models import QAResponse
from danswer.server.models import QueryValidationResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time

Expand All @@ -45,10 +46,6 @@
router = APIRouter()


def get_json_line(json_dict: dict) -> str:
    """Serialize *json_dict* to JSON text terminated by a single newline."""
    serialized = json.dumps(json_dict)
    return f"{serialized}\n"


@router.get("/search-intent")
def get_search_type(
question: QuestionRequest = Depends(), _: User = Depends(current_user)
Expand All @@ -71,7 +68,10 @@ def query_validation(
def stream_query_validation(
    question: QuestionRequest = Depends(), _: User = Depends(current_user)
) -> StreamingResponse:
    # Hand the generator straight to the response; it yields
    # newline-terminated JSON strings as the LLM produces output.
    return StreamingResponse(
        stream_query_answerability(question.query), media_type="application/json"
    )


@router.post("/semantic-search")
Expand Down
5 changes: 5 additions & 0 deletions backend/danswer/server/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import json
from typing import Any


def get_json_line(json_dict: dict) -> str:
    """Return *json_dict* encoded as one JSON document followed by a newline."""
    encoded = json.dumps(json_dict)
    return encoded + "\n"


def mask_string(sensitive_str: str) -> str:
    """Return a masked form of a secret that exposes only a short tail.

    At most the last four characters are shown, and never more than half of
    the input, so short secrets are not revealed in full. Inputs of eight or
    more characters give the same result as before.

    Args:
        sensitive_str: The secret value to mask.

    Returns:
        The fixed mask prefix followed by the visible tail (possibly empty).
    """
    visible = min(4, len(sensitive_str) // 2)
    # Slice from an explicit start index: ``s[-0:]`` would return the whole
    # string when nothing should be visible.
    tail = sensitive_str[len(sensitive_str) - visible :]
    return "****...**" + tail

Expand Down

0 comments on commit 5832e23

Please sign in to comment.