LLM to validate user Query #365

Merged · 4 commits · Aug 31, 2023

Changes from all commits
15 changes: 1 addition & 14 deletions backend/danswer/direct_qa/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
Expand Down Expand Up @@ -62,12 +60,10 @@ def get_default_qa_handler(model: str) -> QAHandler:

def get_default_qa_model(
internal_model: str = INTERNAL_MODEL_VERSION,
model_version: str = GEN_AI_MODEL_VERSION,
endpoint: str | None = GEN_AI_ENDPOINT,
model_host_type: str | None = GEN_AI_HOST_TYPE,
api_key: str | None = GEN_AI_API_KEY,
timeout: int = QA_TIMEOUT,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
**kwargs: Any,
) -> QAModel:
if not api_key:
Expand All @@ -79,16 +75,7 @@ def get_default_qa_model(
try:
# un-used arguments will be ignored by the underlying `LLM` class
# if any args are missing, a `TypeError` will be thrown
llm = get_default_llm(
model=internal_model,
api_key=api_key,
model_version=model_version,
endpoint=endpoint,
model_host_type=model_host_type,
timeout=timeout,
max_output_tokens=max_output_tokens,
**kwargs,
)
llm = get_default_llm()
qa_handler = get_default_qa_handler(model=internal_model)

return QABlock(
Expand Down
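Net effect: model version, endpoint, and credentials no longer have to be threaded through `get_default_qa_model`; they are resolved from config inside `get_default_llm` (see `backend/danswer/llm/build.py` below). A minimal usage sketch under that assumption:

```python
from danswer.direct_qa.llm_utils import get_default_qa_model

# Model version, endpoint, and API key are now picked up from
# danswer.configs by get_default_llm(), so the call site stays bare.
qa_model = get_default_qa_model()
```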
35 changes: 4 additions & 31 deletions backend/danswer/direct_qa/qa_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from copy import copy

import tiktoken
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.interfaces import AnswerQuestionReturn
Expand All @@ -19,32 +16,8 @@
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.llm.llm import LLM


def _dict_based_prompt_to_langchain_prompt(
messages: list[dict[str, str]]
) -> list[BaseMessage]:
prompt: list[BaseMessage] = []
for message in messages:
role = message.get("role")
content = message.get("content")
if not role:
raise ValueError(f"Message missing `role`: {message}")
if not content:
raise ValueError(f"Message missing `content`: {message}")
elif role == "user":
prompt.append(HumanMessage(content=content))
elif role == "system":
prompt.append(SystemMessage(content=content))
elif role == "assistant":
prompt.append(AIMessage(content=content))
else:
raise ValueError(f"Unknown role: {role}")
return prompt


def _str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
return [HumanMessage(content=message)]
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import str_prompt_to_langchain_prompt


class QAHandler(abc.ABC):
Expand All @@ -69,7 +42,7 @@ class JsonChatQAHandler(QAHandler):
def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
) -> list[BaseMessage]:
return _dict_based_prompt_to_langchain_prompt(
return dict_based_prompt_to_langchain_prompt(
JsonChatProcessor.fill_prompt(
question=query, chunks=context_chunks, include_metadata=False
)
Expand All @@ -91,7 +64,7 @@ class SimpleChatQAHandler(QAHandler):
def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
) -> list[BaseMessage]:
return _str_prompt_to_langchain_prompt(
return str_prompt_to_langchain_prompt(
WeakModelFreeformProcessor.fill_prompt(
question=query,
chunks=context_chunks,
Expand Down
22 changes: 21 additions & 1 deletion backend/danswer/llm/build.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,36 @@
from typing import Any

from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.llm.azure import AzureGPT
from danswer.llm.llm import LLM
from danswer.llm.openai import OpenAIGPT


def get_default_llm(model: str, **kwargs: Any) -> LLM:
def get_llm_from_model(model: str, **kwargs: Any) -> LLM:
if model == DanswerGenAIModel.OPENAI_CHAT.value:
if API_TYPE_OPENAI == "azure":
return AzureGPT(**kwargs)
return OpenAIGPT(**kwargs)

raise ValueError(f"Unknown LLM model: {model}")


def get_default_llm(**kwargs: Any) -> LLM:
return get_llm_from_model(
model=INTERNAL_MODEL_VERSION,
api_key=GEN_AI_API_KEY,
model_version=GEN_AI_MODEL_VERSION,
endpoint=GEN_AI_ENDPOINT,
model_host_type=GEN_AI_HOST_TYPE,
timeout=QA_TIMEOUT,
max_output_tokens=GEN_AI_MAX_OUTPUT_TOKENS,
**kwargs,
)
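With this split, `get_llm_from_model` stays a pure factory keyed on the model name, while `get_default_llm` is the single place the server defaults (`INTERNAL_MODEL_VERSION`, the `GEN_AI_*` settings, `QA_TIMEOUT`) get applied. A sketch of the intended call path; the prompt text is made up:

```python
from danswer.llm.build import get_default_llm
from danswer.llm.utils import str_prompt_to_langchain_prompt

# One call wires in the configured model, key, endpoint, timeout, and
# token limit; invoke() accepts the langchain message list built below.
llm = get_default_llm()
answer = llm.invoke(str_prompt_to_langchain_prompt("Is Danswer reachable?"))
print(answer)
```

Note that keys `get_default_llm` already fixes (e.g. `timeout`) cannot also be passed through `**kwargs`, since `get_llm_from_model` would then receive them twice and raise a `TypeError`.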
50 changes: 39 additions & 11 deletions backend/danswer/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,65 @@

from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import (
PromptValue,
)
from langchain.schema import PromptValue
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import BaseMessageChunk
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.configs.app_configs import LOG_LEVEL


def dict_based_prompt_to_langchain_prompt(
messages: list[dict[str, str]]
) -> list[BaseMessage]:
prompt: list[BaseMessage] = []
for message in messages:
role = message.get("role")
content = message.get("content")
if not role:
raise ValueError(f"Message missing `role`: {message}")
if not content:
raise ValueError(f"Message missing `content`: {message}")
elif role == "user":
prompt.append(HumanMessage(content=content))
elif role == "system":
prompt.append(SystemMessage(content=content))
elif role == "assistant":
prompt.append(AIMessage(content=content))
else:
raise ValueError(f"Unknown role: {role}")
return prompt


def str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
return [HumanMessage(content=message)]


def message_generator_to_string_generator(
messages: Iterator[BaseMessageChunk],
) -> Iterator[str]:
for message in messages:
yield message.content


def convert_input(input: LanguageModelInput) -> str:
def convert_input(lm_input: LanguageModelInput) -> str:
"""Heavily inspired by:
https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
"""
prompt_value = None
if isinstance(input, PromptValue):
prompt_value = input
elif isinstance(input, str):
prompt_value = StringPromptValue(text=input)
elif isinstance(input, list):
prompt_value = ChatPromptValue(messages=input)
if isinstance(lm_input, PromptValue):
prompt_value = lm_input
elif isinstance(lm_input, str):
prompt_value = StringPromptValue(text=lm_input)
elif isinstance(lm_input, list):
prompt_value = ChatPromptValue(messages=lm_input)

if prompt_value is None:
raise ValueError(
f"Invalid input type {type(input)}. "
f"Invalid input type {type(lm_input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)

Expand Down
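The relocated helpers keep their contracts: role dicts map to `SystemMessage`/`HumanMessage`/`AIMessage`, and a bare string becomes a single `HumanMessage`. A quick sketch with made-up message contents:

```python
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt

messages = [
    {"role": "system", "content": "You answer questions about internal docs."},
    {"role": "user", "content": "Where are the deploy runbooks kept?"},
]
# -> [SystemMessage(...), HumanMessage(...)]; a missing role/content or an
#    unrecognized role raises ValueError.
prompt = dict_based_prompt_to_langchain_prompt(messages)
```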
Empty file (presumably the new package's backend/danswer/secondary_llm_flows/__init__.py).
107 changes: 107 additions & 0 deletions backend/danswer/secondary_llm_flows/query_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import re
from collections.abc import Iterator
from dataclasses import asdict

from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
from danswer.llm.build import get_default_llm
from danswer.server.models import QueryValidationResponse
from danswer.server.utils import get_json_line

REASONING_PAT = "REASONING: "
ANSWERABLE_PAT = "ANSWERABLE: "
COT_PAT = "\nLet's think step by step"


def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
messages = [
{
"role": "system",
"content": f"You are a helper tool to determine if a query is answerable using retrieval augmented "
f"generation. A system will try to answer the user query based on ONLY the top 5 most relevant "
f"documents found from search. Sources contain both up to date and proprietary information for "
f"the specific team. For named or unknown entities, assume the search will always find "
f"consistent knowledge about the entity. Determine if that system should attempt to answer. "
f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"',
},
{"role": "user", "content": "What is this Slack channel about?"},
{
"role": "assistant",
"content": f"{REASONING_PAT}First the system must determine which Slack channel is being referred to."
f"By fetching 5 documents related to Slack channel contents, it is not possible to determine"
f"which Slack channel the user is referring to.\n{ANSWERABLE_PAT}False",
},
{
"role": "user",
"content": f"Danswer is unreachable.{COT_PAT}",
},
{
"role": "assistant",
"content": f"{REASONING_PAT}The system searches documents related to Danswer being "
f"unreachable. Assuming the documents from search contains situations where Danswer is not "
f"reachable and contains a fix, the query is answerable.\n{ANSWERABLE_PAT}True",
},
{"role": "user", "content": f"How many customers do we have?{COT_PAT}"},
{
"role": "assistant",
"content": f"{REASONING_PAT}Assuming the searched documents contains customer acquisition information"
f"including a list of customers, the query can be answered.\n{ANSWERABLE_PAT}True",
},
{"role": "user", "content": user_query + COT_PAT},
]

return messages


def extract_answerability_reasoning(model_raw: str) -> str:
reasoning_match = re.search(
f"{REASONING_PAT}(.*?){ANSWERABLE_PAT}", model_raw, re.DOTALL
)
reasoning_text = reasoning_match.group(1).strip() if reasoning_match else ""
return reasoning_text


def extract_answerability_bool(model_raw: str) -> bool:
answerable_match = re.search(f"{ANSWERABLE_PAT}(.+)", model_raw)
answerable_text = answerable_match.group(1).strip() if answerable_match else ""
answerable = True if answerable_text.strip().lower() in ["true", "yes"] else False
return answerable


def get_query_answerability(user_query: str) -> tuple[str, bool]:
messages = get_query_validation_messages(user_query)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = get_default_llm().invoke(filled_llm_prompt)

reasoning = extract_answerability_reasoning(model_output)
answerable = extract_answerability_bool(model_output)

return reasoning, answerable


def stream_query_answerability(user_query: str) -> Iterator[str]:
messages = get_query_validation_messages(user_query)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
tokens = get_default_llm().stream(filled_llm_prompt)
reasoning_pat_found = False
model_output = ""
for token in tokens:
model_output = model_output + token

if not reasoning_pat_found and REASONING_PAT in model_output:
reasoning_pat_found = True
remaining = model_output[len(REASONING_PAT) :]
if remaining:
yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=remaining)))
continue

if reasoning_pat_found:
yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))

reasoning = extract_answerability_reasoning(model_output)
answerable = extract_answerability_bool(model_output)

yield get_json_line(
QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict()
)
return
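How the pieces fit: the few-shot prompt asks the model to emit `REASONING: ...` followed by `ANSWERABLE: True|False`, and the two extractors pull those fields back out. A small sketch against a hand-written completion (the completion text is made up):

```python
from danswer.secondary_llm_flows.query_validation import (
    extract_answerability_bool,
    extract_answerability_reasoning,
)

model_raw = (
    "REASONING: Assuming the top 5 documents cover onboarding, "
    "the query can be answered.\n"
    "ANSWERABLE: True"
)
# The DOTALL search captures everything between the two patterns...
assert extract_answerability_reasoning(model_raw).startswith("Assuming")
# ...and the bool parse accepts "true"/"yes" case-insensitively.
assert extract_answerability_bool(model_raw) is True
```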
5 changes: 5 additions & 0 deletions backend/danswer/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ class SearchFeedbackRequest(BaseModel):
search_feedback: SearchFeedbackType


class QueryValidationResponse(BaseModel):
reasoning: str
answerable: bool


class SearchResponse(BaseModel):
# For semantic search, top docs are reranked, the remaining are as ordered from retrieval
top_ranked_docs: list[SearchDoc] | None
Expand Down
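The new model serializes to exactly the two fields the streaming flow emits last; e.g. (values illustrative):

```python
from danswer.server.models import QueryValidationResponse

resp = QueryValidationResponse(reasoning="No channel is specified.", answerable=False)
print(resp.json())  # {"reasoning": "No channel is specified.", "answerable": false}
```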
28 changes: 23 additions & 5 deletions backend/danswer/server/search_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from collections.abc import Generator
from dataclasses import asdict

Expand Down Expand Up @@ -29,12 +28,16 @@
from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
from danswer.server.models import HelperResponse
from danswer.server.models import QAFeedbackRequest
from danswer.server.models import QAResponse
from danswer.server.models import QueryValidationResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time

Expand All @@ -43,10 +46,6 @@
router = APIRouter()


def get_json_line(json_dict: dict) -> str:
return json.dumps(json_dict) + "\n"


@router.get("/search-intent")
def get_search_type(
question: QuestionRequest = Depends(), _: User = Depends(current_user)
Expand All @@ -56,6 +55,25 @@ def get_search_type(
return recommend_search_flow(query, use_keyword)


@router.get("/query-validation")
def query_validation(
question: QuestionRequest = Depends(), _: User = Depends(current_user)
) -> QueryValidationResponse:
query = question.query
reasoning, answerable = get_query_answerability(query)
return QueryValidationResponse(reasoning=reasoning, answerable=answerable)


@router.get("/stream-query-validation")
def stream_query_validation(
question: QuestionRequest = Depends(), _: User = Depends(current_user)
) -> StreamingResponse:
query = question.query
return StreamingResponse(
stream_query_answerability(query), media_type="application/json"
)


@router.post("/semantic-search")
def semantic_search(
question: QuestionRequest,
Expand Down